mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
fix(checkpoints): preserve model type on persisted load
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
@@ -21,14 +21,33 @@ class CheckpointScanner(ModelScanner):
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
|
||||
def _resolve_model_type(self, root_path: Optional[str]) -> Optional[str]:
|
||||
if not root_path:
|
||||
return None
|
||||
|
||||
if config.checkpoints_roots and root_path in config.checkpoints_roots:
|
||||
return "checkpoint"
|
||||
|
||||
if config.unet_roots and root_path in config.unet_roots:
|
||||
return "diffusion_model"
|
||||
|
||||
return None
|
||||
|
||||
def adjust_metadata(self, metadata, file_path, root_path):
|
||||
if hasattr(metadata, "model_type"):
|
||||
if root_path in config.checkpoints_roots:
|
||||
metadata.model_type = "checkpoint"
|
||||
elif root_path in config.unet_roots:
|
||||
metadata.model_type = "diffusion_model"
|
||||
model_type = self._resolve_model_type(root_path)
|
||||
if model_type:
|
||||
metadata.model_type = model_type
|
||||
return metadata
|
||||
|
||||
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
model_type = self._resolve_model_type(
|
||||
self._find_root_for_file(entry.get("file_path"))
|
||||
)
|
||||
if model_type:
|
||||
entry["model_type"] = model_type
|
||||
return entry
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get checkpoint root directories"""
|
||||
return config.base_models_roots
|
||||
return config.base_models_roots
|
||||
|
||||
Reference in New Issue
Block a user