mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
refactor: remove ModelRouteUtils usage and implement filtering directly in services
This commit is contained in:
@@ -31,7 +31,6 @@ from ..services.websocket_progress_callback import (
|
|||||||
)
|
)
|
||||||
from ..utils.exif_utils import ExifUtils
|
from ..utils.exif_utils import ExifUtils
|
||||||
from ..utils.metadata_manager import MetadataManager
|
from ..utils.metadata_manager import MetadataManager
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
|
||||||
from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar
|
from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar
|
||||||
from .handlers.model_handlers import (
|
from .handlers.model_handlers import (
|
||||||
ModelAutoOrganizeHandler,
|
ModelAutoOrganizeHandler,
|
||||||
@@ -236,10 +235,6 @@ class BaseModelRoutes(ABC):
|
|||||||
"""Expose handlers for subclasses or tests."""
|
"""Expose handlers for subclasses or tests."""
|
||||||
return self._ensure_handler_mapping()[name]
|
return self._ensure_handler_mapping()[name]
|
||||||
|
|
||||||
@property
|
|
||||||
def utils(self) -> ModelRouteUtils: # pragma: no cover - compatibility shim
|
|
||||||
return ModelRouteUtils
|
|
||||||
|
|
||||||
def _ensure_service(self):
|
def _ensure_service(self):
|
||||||
if self.service is None:
|
if self.service is None:
|
||||||
raise RuntimeError("Model service has not been attached")
|
raise RuntimeError("Model service has not been attached")
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import logging
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from ..utils.models import BaseModelMetadata
|
from ..utils.models import BaseModelMetadata
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
|
||||||
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
|
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
|
||||||
from .settings_manager import settings as default_settings
|
from .settings_manager import settings as default_settings
|
||||||
|
|
||||||
@@ -197,6 +196,18 @@ class BaseModelService(ABC):
|
|||||||
"""Get model root directories"""
|
"""Get model root directories"""
|
||||||
return self.scanner.get_model_roots()
|
return self.scanner.get_model_roots()
|
||||||
|
|
||||||
|
def filter_civitai_data(self, data: Dict, minimal: bool = False) -> Dict:
|
||||||
|
"""Filter relevant fields from CivitAI data"""
|
||||||
|
if not data:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
fields = ["id", "modelId", "name", "trainedWords"] if minimal else [
|
||||||
|
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||||
|
"publishedAt", "trainedWords", "baseModel", "description",
|
||||||
|
"model", "images", "customImages", "creator"
|
||||||
|
]
|
||||||
|
return {k: data[k] for k in fields if k in data}
|
||||||
|
|
||||||
async def get_folder_tree(self, model_root: str) -> Dict:
|
async def get_folder_tree(self, model_root: str) -> Dict:
|
||||||
"""Get hierarchical folder tree for a specific model root"""
|
"""Get hierarchical folder tree for a specific model root"""
|
||||||
cache = await self.scanner.get_cached_data()
|
cache = await self.scanner.get_cached_data()
|
||||||
@@ -307,7 +318,7 @@ class BaseModelService(ABC):
|
|||||||
|
|
||||||
for model in cache.raw_data:
|
for model in cache.raw_data:
|
||||||
if model.get('file_path') == file_path:
|
if model.get('file_path') == file_path:
|
||||||
return ModelRouteUtils.filter_civitai_data(model.get("civitai", {}))
|
return self.filter_civitai_data(model.get("civitai", {}))
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict
|
||||||
|
|
||||||
from .base_model_service import BaseModelService
|
from .base_model_service import BaseModelService
|
||||||
from ..utils.models import CheckpointMetadata
|
from ..utils.models import CheckpointMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,7 +37,7 @@ class CheckpointService(BaseModelService):
|
|||||||
"notes": checkpoint_data.get("notes", ""),
|
"notes": checkpoint_data.get("notes", ""),
|
||||||
"model_type": checkpoint_data.get("model_type", "checkpoint"),
|
"model_type": checkpoint_data.get("model_type", "checkpoint"),
|
||||||
"favorite": checkpoint_data.get("favorite", False),
|
"favorite": checkpoint_data.get("favorite", False),
|
||||||
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
||||||
}
|
}
|
||||||
|
|
||||||
def find_duplicate_hashes(self) -> Dict:
|
def find_duplicate_hashes(self) -> Dict:
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict
|
||||||
|
|
||||||
from .base_model_service import BaseModelService
|
from .base_model_service import BaseModelService
|
||||||
from ..utils.models import EmbeddingMetadata
|
from ..utils.models import EmbeddingMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,7 +37,7 @@ class EmbeddingService(BaseModelService):
|
|||||||
"notes": embedding_data.get("notes", ""),
|
"notes": embedding_data.get("notes", ""),
|
||||||
"model_type": embedding_data.get("model_type", "embedding"),
|
"model_type": embedding_data.get("model_type", "embedding"),
|
||||||
"favorite": embedding_data.get("favorite", False),
|
"favorite": embedding_data.get("favorite", False),
|
||||||
"civitai": ModelRouteUtils.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)
|
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)
|
||||||
}
|
}
|
||||||
|
|
||||||
def find_duplicate_hashes(self) -> Dict:
|
def find_duplicate_hashes(self) -> Dict:
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from typing import Dict, List, Optional
|
|||||||
from .base_model_service import BaseModelService
|
from .base_model_service import BaseModelService
|
||||||
from ..utils.models import LoraMetadata
|
from ..utils.models import LoraMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,7 +37,7 @@ class LoraService(BaseModelService):
|
|||||||
"usage_tips": lora_data.get("usage_tips", ""),
|
"usage_tips": lora_data.get("usage_tips", ""),
|
||||||
"notes": lora_data.get("notes", ""),
|
"notes": lora_data.get("notes", ""),
|
||||||
"favorite": lora_data.get("favorite", False),
|
"favorite": lora_data.get("favorite", False),
|
||||||
"civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {}), minimal=True)
|
"civitai": self.filter_civitai_data(lora_data.get("civitai", {}), minimal=True)
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||||
|
|||||||
@@ -69,7 +69,6 @@ class MetadataUpdater:
|
|||||||
# Track that we're refreshing this model
|
# Track that we're refreshing this model
|
||||||
download_progress['refreshed_models'].add(model_hash)
|
download_progress['refreshed_models'].add(model_hash)
|
||||||
|
|
||||||
# Use ModelRouteUtils to refresh metadata
|
|
||||||
async def update_cache_func(old_path, new_path, metadata):
|
async def update_cache_func(old_path, new_path, metadata):
|
||||||
return await scanner.update_single_model_cache(old_path, new_path, metadata)
|
return await scanner.update_single_model_cache(old_path, new_path, metadata)
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from ..services.settings_manager import settings
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# TODO: retire this class
|
||||||
class ModelRouteUtils:
|
class ModelRouteUtils:
|
||||||
"""Shared utilities for model routes (LoRAs, Checkpoints, etc.)"""
|
"""Shared utilities for model routes (LoRAs, Checkpoints, etc.)"""
|
||||||
|
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ class MockHashIndex:
|
|||||||
|
|
||||||
|
|
||||||
class MockCache:
|
class MockCache:
|
||||||
"""Cache object with the attributes consumed by ``ModelRouteUtils``."""
|
"""Cache object with the attributes."""
|
||||||
|
|
||||||
def __init__(self, items: Optional[Sequence[Dict[str, Any]]] = None):
|
def __init__(self, items: Optional[Sequence[Dict[str, Any]]] = None):
|
||||||
self.raw_data: List[Dict[str, Any]] = list(items or [])
|
self.raw_data: List[Dict[str, Any]] = list(items or [])
|
||||||
@@ -89,7 +89,7 @@ class MockCache:
|
|||||||
|
|
||||||
async def resort(self) -> None:
|
async def resort(self) -> None:
|
||||||
self.resort_calls += 1
|
self.resort_calls += 1
|
||||||
# ``ModelRouteUtils`` expects the coroutine interface but does not
|
# expects the coroutine interface but does not
|
||||||
# rely on the return value.
|
# rely on the return value.
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user