fix(checkpoints): preserve model type on persisted load

This commit is contained in:
pixelpaws
2025-10-21 22:55:00 +08:00
parent e63ef8d031
commit c5175bb870
3 changed files with 201 additions and 8 deletions

View File

@@ -1,5 +1,5 @@
import logging
from typing import List
from typing import Any, Dict, List, Optional
from ..utils.models import CheckpointMetadata
from ..config import config
@@ -21,14 +21,33 @@ class CheckpointScanner(ModelScanner):
hash_index=ModelHashIndex()
)
def _resolve_model_type(self, root_path: Optional[str]) -> Optional[str]:
if not root_path:
return None
if config.checkpoints_roots and root_path in config.checkpoints_roots:
return "checkpoint"
if config.unet_roots and root_path in config.unet_roots:
return "diffusion_model"
return None
def adjust_metadata(self, metadata, file_path, root_path):
if hasattr(metadata, "model_type"):
if root_path in config.checkpoints_roots:
metadata.model_type = "checkpoint"
elif root_path in config.unet_roots:
metadata.model_type = "diffusion_model"
model_type = self._resolve_model_type(root_path)
if model_type:
metadata.model_type = model_type
return metadata
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
model_type = self._resolve_model_type(
self._find_root_for_file(entry.get("file_path"))
)
if model_type:
entry["model_type"] = model_type
return entry
def get_model_roots(self) -> List[str]:
"""Get checkpoint root directories"""
return config.base_models_roots
return config.base_models_roots