mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-06-09 20:39:25 -03:00
fix(recipe): use resources type field to identify checkpoint instead of modelVersionIds[0]
When importing a CivitAI image as a recipe, modelVersionIds[0] was blindly used as the checkpoint version ID. This array mixes checkpoints and LoRAs without ordering guarantees, causing LoRAs to be saved as the recipe checkpoint. Fix by: 1. Removing the modelVersionIds[0] fallback in _download_remote_media 2. Parsing resources entries with type:"model" as the checkpoint 3. Adding model type validation in populate_checkpoint_from_civitai Also add 2 tests for the new behavior and fix 3 tests whose mocks lacked the required model.type field.
This commit is contained in:
@@ -298,3 +298,113 @@ async def test_parse_metadata_handles_modelVersionIds(monkeypatch):
|
||||
assert lora2["type"] == "lora"
|
||||
assert lora2["hash"] == "aabbccdd0022"
|
||||
assert lora2["baseModel"] == "SDXL"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_metadata_extracts_checkpoint_from_resources_model_type(monkeypatch):
|
||||
"""resources entries with type:"model" should be captured as the checkpoint,
|
||||
not skipped (which was the old buggy behavior), and not mixed into loras."""
|
||||
captured_hashes = []
|
||||
|
||||
async def fake_metadata_provider():
|
||||
class Provider:
|
||||
async def get_model_by_hash(self, model_hash):
|
||||
captured_hashes.append(model_hash)
|
||||
if model_hash == "a1b2c3d4e5":
|
||||
return ({
|
||||
"id": 999,
|
||||
"modelId": 888,
|
||||
"name": "v1.0",
|
||||
"model": {"name": "Real Checkpoint", "type": "Checkpoint"},
|
||||
"baseModel": "SDXL 1.0",
|
||||
"images": [{"url": "https://image.civitai.com/cp/original=true"}],
|
||||
"files": [{"type": "Model", "primary": True, "sizeKB": 1024, "name": "cp.safetensors"}]
|
||||
}, None)
|
||||
return None, "Model not found"
|
||||
|
||||
return Provider()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"py.recipes.parsers.civitai_image.get_default_metadata_provider",
|
||||
fake_metadata_provider,
|
||||
)
|
||||
|
||||
parser = CivitaiApiMetadataParser()
|
||||
|
||||
metadata = {
|
||||
"prompt": "test",
|
||||
"resources": [
|
||||
{"hash": "a1b2c3d4e5", "name": "Real Checkpoint", "type": "model"},
|
||||
{"hash": "f6g7h8i9j0", "name": "Some LoRA", "type": "lora", "weight": 0.8},
|
||||
],
|
||||
"Model hash": "a1b2c3d4e5",
|
||||
}
|
||||
|
||||
result = await parser.parse_metadata(metadata)
|
||||
|
||||
# The type:"model" resource should be in result["model"], not in result["loras"]
|
||||
assert result["model"] is not None, "checkpoint model should be extracted"
|
||||
assert result["model"]["name"] == "Real Checkpoint"
|
||||
assert result["model"]["hash"] == "a1b2c3d4e5"
|
||||
assert result["model"]["type"] == "model"
|
||||
|
||||
# The LoRA resource should be in result["loras"]
|
||||
assert len(result["loras"]) == 1
|
||||
assert result["loras"][0]["name"] == "Some LoRA"
|
||||
|
||||
# The checkpoint hash should have triggered a lookup
|
||||
assert "a1b2c3d4e5" in captured_hashes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_metadata_resources_model_type_does_not_duplicate_checkpoint_in_loras(monkeypatch):
|
||||
"""When a resources entry has type:"model", it should NOT also appear in loras.
|
||||
Regression test for the bug where the checkpoint model appeared in both places."""
|
||||
async def fake_metadata_provider():
|
||||
class Provider:
|
||||
async def get_model_by_hash(self, model_hash):
|
||||
if model_hash == "cp123hash":
|
||||
return ({
|
||||
"id": 100,
|
||||
"modelId": 200,
|
||||
"name": "v2",
|
||||
"model": {"name": "My Checkpoint", "type": "Checkpoint"},
|
||||
"baseModel": "SDXL",
|
||||
"files": [{"type": "Model", "primary": True, "sizeKB": 1024, "name": "cp.safetensors"}]
|
||||
}, None)
|
||||
if model_hash == "lora1hash":
|
||||
return ({
|
||||
"id": 300,
|
||||
"modelId": 400,
|
||||
"name": "v1",
|
||||
"model": {"name": "Style LoRA", "type": "LORA"},
|
||||
"baseModel": "SDXL",
|
||||
"files": [{"type": "Model", "primary": True, "sizeKB": 512, "name": "style.safetensors"}]
|
||||
}, None)
|
||||
return None, "Model not found"
|
||||
|
||||
return Provider()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"py.recipes.parsers.civitai_image.get_default_metadata_provider",
|
||||
fake_metadata_provider,
|
||||
)
|
||||
|
||||
parser = CivitaiApiMetadataParser()
|
||||
metadata = {
|
||||
"resources": [
|
||||
{"hash": "cp123hash", "name": "My Checkpoint", "type": "model"},
|
||||
{"hash": "lora1hash", "name": "Style LoRA", "type": "lora", "weight": 0.5},
|
||||
],
|
||||
}
|
||||
|
||||
result = await parser.parse_metadata(metadata)
|
||||
|
||||
# Checkpoint must NOT appear in loras
|
||||
lora_names = {l["name"] for l in result["loras"]}
|
||||
assert "My Checkpoint" not in lora_names
|
||||
assert "Style LoRA" in lora_names
|
||||
|
||||
# Checkpoint must be in result["model"]
|
||||
assert result["model"] is not None
|
||||
assert result["model"]["name"] == "My Checkpoint"
|
||||
|
||||
Reference in New Issue
Block a user