mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-06-09 12:39:23 -03:00
fix(recipe): use resources type field to identify checkpoint instead of modelVersionIds[0]
When importing a CivitAI image as a recipe, modelVersionIds[0] was blindly used as the checkpoint version ID. This array mixes checkpoints and LoRAs without ordering guarantees, causing LoRAs to be saved as the recipe checkpoint. Fix by: 1. Removing the modelVersionIds[0] fallback in _download_remote_media 2. Parsing resources entries with type:"model" as the checkpoint 3. Adding model type validation in populate_checkpoint_from_civitai Also add 2 tests for the new behavior and fix 3 tests whose mocks lacked the required model.type field.
This commit is contained in:
@@ -7,7 +7,7 @@ import re
|
|||||||
from typing import Dict, List, Any, Optional, Tuple
|
from typing import Dict, List, Any, Optional, Tuple
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..utils.constants import VALID_LORA_TYPES
|
from ..utils.constants import VALID_LORA_TYPES, VALID_CHECKPOINT_SUB_TYPES
|
||||||
from ..utils.civitai_utils import rewrite_preview_url
|
from ..utils.civitai_utils import rewrite_preview_url
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -173,6 +173,20 @@ class RecipeMetadataParser(ABC):
|
|||||||
checkpoint['isDeleted'] = True
|
checkpoint['isDeleted'] = True
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
# Validate that the model type is actually a checkpoint.
|
||||||
|
# Unlike populate_lora_from_civitai which has this check,
|
||||||
|
# this function was missing type validation — allowing LoRA
|
||||||
|
# version data to be saved as the recipe's checkpoint when the
|
||||||
|
# wrong version ID was passed downstream (fixed in v2.7+).
|
||||||
|
model_type = civitai_data.get('model', {}).get('type', '').lower()
|
||||||
|
if model_type not in VALID_CHECKPOINT_SUB_TYPES:
|
||||||
|
logger.warning(
|
||||||
|
f"Cannot populate checkpoint: model version {civitai_data.get('id')} "
|
||||||
|
f"has type '{model_type}', expected one of {VALID_CHECKPOINT_SUB_TYPES}. "
|
||||||
|
f"Skipping checkpoint enrichment."
|
||||||
|
)
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
if 'model' in civitai_data and 'name' in civitai_data['model']:
|
if 'model' in civitai_data and 'name' in civitai_data['model']:
|
||||||
checkpoint['name'] = civitai_data['model']['name']
|
checkpoint['name'] = civitai_data['model']['name']
|
||||||
|
|
||||||
|
|||||||
@@ -185,8 +185,67 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
# Process standard resources array
|
# Process standard resources array
|
||||||
if "resources" in metadata and isinstance(metadata["resources"], list):
|
if "resources" in metadata and isinstance(metadata["resources"], list):
|
||||||
for resource in metadata["resources"]:
|
for resource in metadata["resources"]:
|
||||||
|
resource_type = resource.get("type", "lora")
|
||||||
|
|
||||||
|
# Track resources with type "model" — these are checkpoint models.
|
||||||
|
# The resources array is the most reliable source for checkpoint
|
||||||
|
# identification because it has an explicit type field and hash,
|
||||||
|
# unlike modelVersionIds which is a flat list with no type info.
|
||||||
|
if resource_type == "model":
|
||||||
|
checkpoint_entry = {
|
||||||
|
"id": 0,
|
||||||
|
"modelId": 0,
|
||||||
|
"name": resource.get("name", "Unknown Model"),
|
||||||
|
"version": "",
|
||||||
|
"type": resource.get("type", "model"),
|
||||||
|
"existsLocally": False,
|
||||||
|
"localPath": None,
|
||||||
|
"file_name": resource.get("name", ""),
|
||||||
|
"hash": resource.get("hash", "") or "",
|
||||||
|
"thumbnailUrl": "/loras_static/images/no-preview.png",
|
||||||
|
"baseModel": "",
|
||||||
|
"size": 0,
|
||||||
|
"downloadUrl": "",
|
||||||
|
"isDeleted": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Try to look up base model from the checkpoint hash
|
||||||
|
if checkpoint_entry["hash"] and metadata_provider:
|
||||||
|
try:
|
||||||
|
civitai_info = (
|
||||||
|
await metadata_provider.get_model_by_hash(
|
||||||
|
checkpoint_entry["hash"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
civitai_data, error_msg = (
|
||||||
|
(civitai_info, None)
|
||||||
|
if not isinstance(civitai_info, tuple)
|
||||||
|
else civitai_info
|
||||||
|
)
|
||||||
|
if civitai_data and error_msg != "Model not found":
|
||||||
|
if 'model' in civitai_data and 'name' in civitai_data['model']:
|
||||||
|
checkpoint_entry['name'] = civitai_data['model']['name']
|
||||||
|
checkpoint_entry['id'] = civitai_data.get('id', 0)
|
||||||
|
checkpoint_entry['modelId'] = civitai_data.get('modelId', 0)
|
||||||
|
if 'name' in civitai_data:
|
||||||
|
checkpoint_entry['version'] = civitai_data['name']
|
||||||
|
base_model = civitai_data.get('baseModel', '')
|
||||||
|
if base_model:
|
||||||
|
checkpoint_entry['baseModel'] = base_model
|
||||||
|
if not result['base_model']:
|
||||||
|
result['base_model'] = base_model
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error fetching checkpoint info for hash "
|
||||||
|
f"{checkpoint_entry['hash']}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if result["model"] is None:
|
||||||
|
result["model"] = checkpoint_entry
|
||||||
|
continue
|
||||||
|
|
||||||
# Modified to process resources without a type field as potential LoRAs
|
# Modified to process resources without a type field as potential LoRAs
|
||||||
if resource.get("type", "lora") == "lora":
|
if resource_type == "lora":
|
||||||
lora_hash = resource.get("hash", "")
|
lora_hash = resource.get("hash", "")
|
||||||
|
|
||||||
# Try to get hash from the hashes field if not present in resource
|
# Try to get hash from the hashes field if not present in resource
|
||||||
|
|||||||
@@ -1293,11 +1293,18 @@ class RecipeManagementHandler:
|
|||||||
image_info.get("meta") if civitai_image_id and image_info else None
|
image_info.get("meta") if civitai_image_id and image_info else None
|
||||||
)
|
)
|
||||||
if civitai_image_id and image_info:
|
if civitai_image_id and image_info:
|
||||||
|
# modelVersionId (singular) — the primary version for this
|
||||||
|
# image on CivitAI. May be absent, or may *not* be the
|
||||||
|
# checkpoint (e.g. when the image was generated with a LoRA
|
||||||
|
# as the primary subject). When absent, DO NOT fall back to
|
||||||
|
# modelVersionIds[0] — that array mixes checkpoints, LoRAs,
|
||||||
|
# and other model version IDs without ordering guarantees.
|
||||||
|
# The downstream enrichment flow will find the real
|
||||||
|
# checkpoint via meta.resources (type:"model" hash) or
|
||||||
|
# meta.civitaiResources (type:"checkpoint" version ID), so
|
||||||
|
# leaving model_ver_id as None is safe and avoids the bug
|
||||||
|
# where a LoRA version ID was treated as the checkpoint.
|
||||||
model_ver_id = image_info.get("modelVersionId")
|
model_ver_id = image_info.get("modelVersionId")
|
||||||
if not model_ver_id:
|
|
||||||
ids = image_info.get("modelVersionIds")
|
|
||||||
if isinstance(ids, list) and ids:
|
|
||||||
model_ver_id = ids[0]
|
|
||||||
|
|
||||||
# Inject root-level modelVersionIds into meta so downstream
|
# Inject root-level modelVersionIds into meta so downstream
|
||||||
# parsers (CivitaiApiMetadataParser) can discover ALL resources
|
# parsers (CivitaiApiMetadataParser) can discover ALL resources
|
||||||
|
|||||||
@@ -467,7 +467,10 @@ async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None:
|
|||||||
class Provider:
|
class Provider:
|
||||||
async def get_model_version_info(self, model_version_id):
|
async def get_model_version_info(self, model_version_id):
|
||||||
provider_calls.append(model_version_id)
|
provider_calls.append(model_version_id)
|
||||||
return {"baseModel": "Flux Provider"}, None
|
return {
|
||||||
|
"baseModel": "Flux Provider",
|
||||||
|
"model": {"type": "Checkpoint", "name": "Flux"},
|
||||||
|
}, None
|
||||||
|
|
||||||
async def fake_get_default_metadata_provider():
|
async def fake_get_default_metadata_provider():
|
||||||
return Provider()
|
return Provider()
|
||||||
|
|||||||
@@ -298,3 +298,113 @@ async def test_parse_metadata_handles_modelVersionIds(monkeypatch):
|
|||||||
assert lora2["type"] == "lora"
|
assert lora2["type"] == "lora"
|
||||||
assert lora2["hash"] == "aabbccdd0022"
|
assert lora2["hash"] == "aabbccdd0022"
|
||||||
assert lora2["baseModel"] == "SDXL"
|
assert lora2["baseModel"] == "SDXL"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_metadata_extracts_checkpoint_from_resources_model_type(monkeypatch):
|
||||||
|
"""resources entries with type:"model" should be captured as the checkpoint,
|
||||||
|
not skipped (which was the old buggy behavior), and not mixed into loras."""
|
||||||
|
captured_hashes = []
|
||||||
|
|
||||||
|
async def fake_metadata_provider():
|
||||||
|
class Provider:
|
||||||
|
async def get_model_by_hash(self, model_hash):
|
||||||
|
captured_hashes.append(model_hash)
|
||||||
|
if model_hash == "a1b2c3d4e5":
|
||||||
|
return ({
|
||||||
|
"id": 999,
|
||||||
|
"modelId": 888,
|
||||||
|
"name": "v1.0",
|
||||||
|
"model": {"name": "Real Checkpoint", "type": "Checkpoint"},
|
||||||
|
"baseModel": "SDXL 1.0",
|
||||||
|
"images": [{"url": "https://image.civitai.com/cp/original=true"}],
|
||||||
|
"files": [{"type": "Model", "primary": True, "sizeKB": 1024, "name": "cp.safetensors"}]
|
||||||
|
}, None)
|
||||||
|
return None, "Model not found"
|
||||||
|
|
||||||
|
return Provider()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"py.recipes.parsers.civitai_image.get_default_metadata_provider",
|
||||||
|
fake_metadata_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser = CivitaiApiMetadataParser()
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"prompt": "test",
|
||||||
|
"resources": [
|
||||||
|
{"hash": "a1b2c3d4e5", "name": "Real Checkpoint", "type": "model"},
|
||||||
|
{"hash": "f6g7h8i9j0", "name": "Some LoRA", "type": "lora", "weight": 0.8},
|
||||||
|
],
|
||||||
|
"Model hash": "a1b2c3d4e5",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await parser.parse_metadata(metadata)
|
||||||
|
|
||||||
|
# The type:"model" resource should be in result["model"], not in result["loras"]
|
||||||
|
assert result["model"] is not None, "checkpoint model should be extracted"
|
||||||
|
assert result["model"]["name"] == "Real Checkpoint"
|
||||||
|
assert result["model"]["hash"] == "a1b2c3d4e5"
|
||||||
|
assert result["model"]["type"] == "model"
|
||||||
|
|
||||||
|
# The LoRA resource should be in result["loras"]
|
||||||
|
assert len(result["loras"]) == 1
|
||||||
|
assert result["loras"][0]["name"] == "Some LoRA"
|
||||||
|
|
||||||
|
# The checkpoint hash should have triggered a lookup
|
||||||
|
assert "a1b2c3d4e5" in captured_hashes
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_metadata_resources_model_type_does_not_duplicate_checkpoint_in_loras(monkeypatch):
|
||||||
|
"""When a resources entry has type:"model", it should NOT also appear in loras.
|
||||||
|
Regression test for the bug where the checkpoint model appeared in both places."""
|
||||||
|
async def fake_metadata_provider():
|
||||||
|
class Provider:
|
||||||
|
async def get_model_by_hash(self, model_hash):
|
||||||
|
if model_hash == "cp123hash":
|
||||||
|
return ({
|
||||||
|
"id": 100,
|
||||||
|
"modelId": 200,
|
||||||
|
"name": "v2",
|
||||||
|
"model": {"name": "My Checkpoint", "type": "Checkpoint"},
|
||||||
|
"baseModel": "SDXL",
|
||||||
|
"files": [{"type": "Model", "primary": True, "sizeKB": 1024, "name": "cp.safetensors"}]
|
||||||
|
}, None)
|
||||||
|
if model_hash == "lora1hash":
|
||||||
|
return ({
|
||||||
|
"id": 300,
|
||||||
|
"modelId": 400,
|
||||||
|
"name": "v1",
|
||||||
|
"model": {"name": "Style LoRA", "type": "LORA"},
|
||||||
|
"baseModel": "SDXL",
|
||||||
|
"files": [{"type": "Model", "primary": True, "sizeKB": 512, "name": "style.safetensors"}]
|
||||||
|
}, None)
|
||||||
|
return None, "Model not found"
|
||||||
|
|
||||||
|
return Provider()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"py.recipes.parsers.civitai_image.get_default_metadata_provider",
|
||||||
|
fake_metadata_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser = CivitaiApiMetadataParser()
|
||||||
|
metadata = {
|
||||||
|
"resources": [
|
||||||
|
{"hash": "cp123hash", "name": "My Checkpoint", "type": "model"},
|
||||||
|
{"hash": "lora1hash", "name": "Style LoRA", "type": "lora", "weight": 0.5},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await parser.parse_metadata(metadata)
|
||||||
|
|
||||||
|
# Checkpoint must NOT appear in loras
|
||||||
|
lora_names = {l["name"] for l in result["loras"]}
|
||||||
|
assert "My Checkpoint" not in lora_names
|
||||||
|
assert "Style LoRA" in lora_names
|
||||||
|
|
||||||
|
# Checkpoint must be in result["model"]
|
||||||
|
assert result["model"] is not None
|
||||||
|
assert result["model"]["name"] == "My Checkpoint"
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ async def test_repair_all_recipes_with_enriched_checkpoint_id(setup_scanner):
|
|||||||
"id": 5678,
|
"id": 5678,
|
||||||
"modelId": 1234,
|
"modelId": 1234,
|
||||||
"name": "v1.0",
|
"name": "v1.0",
|
||||||
"model": {"name": "Full Model Name"},
|
"model": {"name": "Full Model Name", "type": "Checkpoint"},
|
||||||
"baseModel": "SDXL 1.0",
|
"baseModel": "SDXL 1.0",
|
||||||
"images": [{"url": "https://image.url/thumb.jpg"}],
|
"images": [{"url": "https://image.url/thumb.jpg"}],
|
||||||
"files": [{"type": "Model", "hashes": {"SHA256": "ABCDEF"}, "name": "full_filename.safetensors"}]
|
"files": [{"type": "Model", "hashes": {"SHA256": "ABCDEF"}, "name": "full_filename.safetensors"}]
|
||||||
@@ -142,7 +142,7 @@ async def test_repair_all_recipes_supports_civitai_red_source_url(setup_scanner)
|
|||||||
"id": 5678,
|
"id": 5678,
|
||||||
"modelId": 1234,
|
"modelId": 1234,
|
||||||
"name": "v1.0",
|
"name": "v1.0",
|
||||||
"model": {"name": "Full Model Name"},
|
"model": {"name": "Full Model Name", "type": "Checkpoint"},
|
||||||
"baseModel": "SDXL 1.0",
|
"baseModel": "SDXL 1.0",
|
||||||
"images": [{"url": "https://image.url/thumb.jpg"}],
|
"images": [{"url": "https://image.url/thumb.jpg"}],
|
||||||
"files": [
|
"files": [
|
||||||
@@ -183,7 +183,7 @@ async def test_repair_all_recipes_with_enriched_checkpoint_hash(setup_scanner):
|
|||||||
"id": 999,
|
"id": 999,
|
||||||
"modelId": 888,
|
"modelId": 888,
|
||||||
"name": "v2.0",
|
"name": "v2.0",
|
||||||
"model": {"name": "Hashed Model"},
|
"model": {"name": "Hashed Model", "type": "Checkpoint"},
|
||||||
"baseModel": "SD 1.5",
|
"baseModel": "SD 1.5",
|
||||||
"files": [{"type": "Model", "hashes": {"SHA256": "hash123"}, "name": "hashed.safetensors"}]
|
"files": [{"type": "Model", "hashes": {"SHA256": "hash123"}, "name": "hashed.safetensors"}]
|
||||||
}, None)
|
}, None)
|
||||||
|
|||||||
Reference in New Issue
Block a user