feat: add checkpoint scanner integration to recipe scanner

- Add CheckpointScanner dependency to RecipeScanner singleton
- Implement checkpoint enrichment in recipe data processing
- Add _enrich_checkpoint_entry method to enhance checkpoint metadata
- Update recipe formatting to include checkpoint information
- Extend get_instance, __new__, and __init__ methods to support checkpoint scanner
- Add _get_checkpoint_from_version_index method for cache lookup

This enables recipe scanner to handle checkpoint models alongside existing LoRA support, providing complete model metadata for recipes.
This commit is contained in:
Will Miao
2025-11-21 15:36:54 +08:00
parent 4eb46a8d3e
commit 1971881537
4 changed files with 396 additions and 15 deletions

View File

@@ -9,6 +9,7 @@ from .recipe_cache import RecipeCache
from .service_registry import ServiceRegistry
from .lora_scanner import LoraScanner
from .metadata_service import get_default_metadata_provider
from .checkpoint_scanner import CheckpointScanner
from .recipes.errors import RecipeNotFoundError
from ..utils.utils import calculate_recipe_fingerprint, fuzzy_match
from natsort import natsorted
@@ -23,24 +24,39 @@ class RecipeScanner:
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls, lora_scanner: Optional[LoraScanner] = None):
async def get_instance(
cls,
lora_scanner: Optional[LoraScanner] = None,
checkpoint_scanner: Optional[CheckpointScanner] = None,
):
"""Get singleton instance of RecipeScanner"""
async with cls._lock:
if cls._instance is None:
if not lora_scanner:
# Get lora scanner from service registry if not provided
lora_scanner = await ServiceRegistry.get_lora_scanner()
cls._instance = cls(lora_scanner)
if not checkpoint_scanner:
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
cls._instance = cls(lora_scanner, checkpoint_scanner)
return cls._instance
def __new__(cls, lora_scanner: Optional[LoraScanner] = None):
def __new__(
cls,
lora_scanner: Optional[LoraScanner] = None,
checkpoint_scanner: Optional[CheckpointScanner] = None,
):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._lora_scanner = lora_scanner
cls._instance._checkpoint_scanner = checkpoint_scanner
cls._instance._civitai_client = None # Will be lazily initialized
return cls._instance
def __init__(self, lora_scanner: Optional[LoraScanner] = None):
def __init__(
self,
lora_scanner: Optional[LoraScanner] = None,
checkpoint_scanner: Optional[CheckpointScanner] = None,
):
# Ensure initialization only happens once
if not hasattr(self, '_initialized'):
self._cache: Optional[RecipeCache] = None
@@ -51,6 +67,8 @@ class RecipeScanner:
self._resort_tasks: Set[asyncio.Task] = set()
if lora_scanner:
self._lora_scanner = lora_scanner
if checkpoint_scanner:
self._checkpoint_scanner = checkpoint_scanner
self._initialized = True
def on_library_changed(self) -> None:
@@ -422,6 +440,9 @@ class RecipeScanner:
# Update lora information with local paths and availability
await self._update_lora_information(recipe_data)
if recipe_data.get('checkpoint'):
recipe_data['checkpoint'] = self._enrich_checkpoint_entry(dict(recipe_data['checkpoint']))
# Calculate and update fingerprint if missing
if 'loras' in recipe_data and 'fingerprint' not in recipe_data:
fingerprint = calculate_recipe_fingerprint(recipe_data['loras'])
@@ -585,6 +606,27 @@ class RecipeScanner:
return version_index.get(normalized_id)
def _get_checkpoint_from_version_index(self, model_version_id: Any) -> Optional[Dict[str, Any]]:
"""Fetch a cached checkpoint entry by version id."""
if not self._checkpoint_scanner:
return None
cache = getattr(self._checkpoint_scanner, "_cache", None)
if cache is None:
return None
version_index = getattr(cache, "version_index", None)
if not version_index:
return None
try:
normalized_id = int(model_version_id)
except (TypeError, ValueError):
return None
return version_index.get(normalized_id)
async def _determine_base_model(self, loras: List[Dict]) -> Optional[str]:
"""Determine the most common base model among LoRAs"""
base_models = {}
@@ -623,6 +665,57 @@ class RecipeScanner:
logger.error(f"Error getting base model for lora: {e}")
return None
def _enrich_checkpoint_entry(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
"""Populate convenience fields for a checkpoint entry."""
if not checkpoint or not isinstance(checkpoint, dict) or not self._checkpoint_scanner:
return checkpoint
hash_value = (checkpoint.get('hash') or '').lower()
version_entry = None
model_version_id = checkpoint.get('id') or checkpoint.get('modelVersionId')
if not hash_value and model_version_id is not None:
version_entry = self._get_checkpoint_from_version_index(model_version_id)
try:
preview_url = checkpoint.get('preview_url') or checkpoint.get('thumbnailUrl')
if preview_url:
checkpoint['preview_url'] = self._normalize_preview_url(preview_url)
if hash_value:
checkpoint['inLibrary'] = self._checkpoint_scanner.has_hash(hash_value)
checkpoint['preview_url'] = self._normalize_preview_url(
checkpoint.get('preview_url')
or self._checkpoint_scanner.get_preview_url_by_hash(hash_value)
)
checkpoint['localPath'] = self._checkpoint_scanner.get_path_by_hash(hash_value)
elif version_entry:
checkpoint['inLibrary'] = True
cached_path = version_entry.get('file_path') or version_entry.get('path')
if cached_path:
checkpoint.setdefault('localPath', cached_path)
if not checkpoint.get('file_name'):
checkpoint['file_name'] = os.path.splitext(os.path.basename(cached_path))[0]
if version_entry.get('sha256') and not checkpoint.get('hash'):
checkpoint['hash'] = version_entry.get('sha256')
preview_url = self._normalize_preview_url(version_entry.get('preview_url'))
if preview_url:
checkpoint.setdefault('preview_url', preview_url)
if version_entry.get('model_type'):
checkpoint.setdefault('model_type', version_entry.get('model_type'))
else:
checkpoint.setdefault('inLibrary', False)
if checkpoint.get('preview_url'):
checkpoint['preview_url'] = self._normalize_preview_url(checkpoint['preview_url'])
except Exception as exc: # pragma: no cover - defensive logging
logger.debug("Error enriching checkpoint entry %s: %s", hash_value or model_version_id, exc)
return checkpoint
def _enrich_lora_entry(self, lora: Dict[str, Any]) -> Dict[str, Any]:
"""Populate convenience fields for a LoRA entry."""
@@ -827,6 +920,8 @@ class RecipeScanner:
for item in paginated_items:
if 'loras' in item:
item['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in item['loras']]
if item.get('checkpoint'):
item['checkpoint'] = self._enrich_checkpoint_entry(dict(item['checkpoint']))
result = {
'items': paginated_items,
@@ -874,6 +969,8 @@ class RecipeScanner:
# Add lora metadata
if 'loras' in formatted_recipe:
formatted_recipe['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in formatted_recipe['loras']]
if formatted_recipe.get('checkpoint'):
formatted_recipe['checkpoint'] = self._enrich_checkpoint_entry(dict(formatted_recipe['checkpoint']))
return formatted_recipe