# File: ComfyUI-Lora-Manager/py/services/base_model_service.py
# (Python source; the file viewer reported 364 lines, 14 KiB)

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Type
import logging
import os
from ..utils.models import BaseModelMetadata
from ..utils.routes_common import ModelRouteUtils
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
from .settings_manager import settings as default_settings
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class BaseModelService(ABC):
"""Base service class for all model types"""
def __init__(
    self,
    model_type: str,
    scanner,
    metadata_class: Type[BaseModelMetadata],
    *,
    cache_repository: Optional[ModelCacheRepository] = None,
    filter_set: Optional[ModelFilterSet] = None,
    search_strategy: Optional[SearchStrategy] = None,
    settings_provider: Optional[SettingsProvider] = None,
):
    """Store the core dependencies and resolve the optional collaborators.

    Each keyword-only collaborator may be omitted; a falsy value is
    replaced by a default constructed here.

    Args:
        model_type: Type of model (lora, checkpoint, etc.).
        scanner: Model scanner instance.
        metadata_class: Metadata class for this model type.
        cache_repository: Repository used for cache access; defaults to
            ``ModelCacheRepository(scanner)``.
        filter_set: Filter component; defaults to
            ``ModelFilterSet(self.settings)``.
        search_strategy: Search component; defaults to ``SearchStrategy()``.
        settings_provider: Settings object; defaults to the global
            settings manager.
    """
    self.model_type = model_type
    self.scanner = scanner
    self.metadata_class = metadata_class

    # Settings are resolved first because the default filter set is built
    # from them.
    self.settings = settings_provider if settings_provider else default_settings
    self.cache_repository = (
        cache_repository if cache_repository else ModelCacheRepository(scanner)
    )
    self.filter_set = filter_set if filter_set else ModelFilterSet(self.settings)
    self.search_strategy = search_strategy if search_strategy else SearchStrategy()
async def get_paginated_data(
    self,
    page: int,
    page_size: int,
    sort_by: str = 'name',
    folder: str = None,
    search: str = None,
    fuzzy_search: bool = False,
    base_models: list = None,
    tags: list = None,
    search_options: dict = None,
    hash_filters: dict = None,
    favorites_only: bool = False,
    **kwargs,
) -> Dict:
    """Fetch sorted model records, filter them, and return one page.

    When ``hash_filters`` is truthy only the hash filter is applied; the
    common, search and model-specific filters are all skipped. Otherwise
    the common filters run first, then the search filter (only when
    ``search`` is truthy), then the model-specific filters, which receive
    ``**kwargs``.

    Returns:
        The dict produced by ``self._paginate`` for the filtered records.
    """
    ordering = self.cache_repository.parse_sort(sort_by)
    records = await self.cache_repository.fetch_sorted(ordering)

    if hash_filters:
        # Hash lookups bypass every other filter stage.
        records = await self._apply_hash_filters(records, hash_filters)
    else:
        records = await self._apply_common_filters(
            records,
            folder=folder,
            base_models=base_models,
            tags=tags,
            favorites_only=favorites_only,
            search_options=search_options,
        )
        if search:
            records = await self._apply_search_filters(
                records,
                search,
                fuzzy_search,
                search_options,
            )
        records = await self._apply_specific_filters(records, **kwargs)

    return self._paginate(records, page, page_size)
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
"""Apply hash-based filtering"""
single_hash = hash_filters.get('single_hash')
multiple_hashes = hash_filters.get('multiple_hashes')
if single_hash:
# Filter by single hash
single_hash = single_hash.lower()
return [
item for item in data
if item.get('sha256', '').lower() == single_hash
]
elif multiple_hashes:
# Filter by multiple hashes
hash_set = set(hash.lower() for hash in multiple_hashes)
return [
item for item in data
if item.get('sha256', '').lower() in hash_set
]
return data
async def _apply_common_filters(
    self,
    data: List[Dict],
    folder: str = None,
    base_models: list = None,
    tags: list = None,
    favorites_only: bool = False,
    search_options: dict = None,
) -> List[Dict]:
    """Run the shared (model-type-independent) filters over ``data``.

    The arguments are bundled into a ``FilterCriteria`` — with
    ``search_options`` first passed through the search strategy's
    ``normalize_options`` — and handed to ``self.filter_set.apply``.

    Returns:
        Whatever ``self.filter_set.apply`` returns for ``data``.
    """
    criteria = FilterCriteria(
        folder=folder,
        base_models=base_models,
        tags=tags,
        favorites_only=favorites_only,
        search_options=self.search_strategy.normalize_options(search_options),
    )
    filtered = self.filter_set.apply(data, criteria)
    return filtered
async def _apply_search_filters(
self,
data: List[Dict],
search: str,
fuzzy_search: bool,
search_options: dict,
) -> List[Dict]:
"""Apply search filtering"""
normalized_options = self.search_strategy.normalize_options(search_options)
return self.search_strategy.apply(data, search, normalized_options, fuzzy_search)
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
"""Apply model-specific filters - to be overridden by subclasses if needed"""
return data
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
"""Apply pagination to filtered data"""
total_items = len(data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
return {
'items': data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
@abstractmethod
async def format_response(self, model_data: Dict) -> Dict:
    """Format one model record for the API response.

    Abstract: concrete services must override this. The base body does
    nothing, so awaiting it directly yields ``None``.
    """
# Common service methods that delegate to scanner
async def get_top_tags(self, limit: int = 20) -> List[Dict]:
    """Return the top tags as reported by the scanner.

    Args:
        limit: Forwarded unchanged to ``scanner.get_top_tags``.
    """
    top_tags = await self.scanner.get_top_tags(limit)
    return top_tags
async def get_base_models(self, limit: int = 20) -> List[Dict]:
    """Return the base models as reported by the scanner.

    Args:
        limit: Forwarded unchanged to ``scanner.get_base_models``.
    """
    base_models = await self.scanner.get_base_models(limit)
    return base_models
def has_hash(self, sha256: str) -> bool:
    """Ask the scanner whether a model with this hash exists."""
    known = self.scanner.has_hash(sha256)
    return known
def get_path_by_hash(self, sha256: str) -> Optional[str]:
    """Look up a model's file path by its hash via the scanner."""
    file_path = self.scanner.get_path_by_hash(sha256)
    return file_path
def get_hash_by_path(self, file_path: str) -> Optional[str]:
    """Look up a model's hash by its file path via the scanner."""
    sha256 = self.scanner.get_hash_by_path(file_path)
    return sha256
async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False):
    """Trigger a scan by requesting the scanner's cached data.

    Both flags are forwarded as keyword arguments to
    ``scanner.get_cached_data``; its result is returned as-is.
    """
    options = {
        'force_refresh': force_refresh,
        'rebuild_cache': rebuild_cache,
    }
    return await self.scanner.get_cached_data(**options)
async def get_model_info_by_name(self, name: str):
    """Fetch model information for ``name`` from the scanner."""
    model_info = await self.scanner.get_model_info_by_name(name)
    return model_info
def get_model_roots(self) -> List[str]:
    """Return the model root directories reported by the scanner."""
    roots = self.scanner.get_model_roots()
    return roots
async def get_folder_tree(self, model_root: str) -> Dict:
    """Build a nested dict of folder names for a model root.

    Each cached folder path is split on ``/`` and merged into a tree in
    which every key is a folder name and every value is the dict of its
    sub-folders (``{}`` for a leaf).

    NOTE(review): the root check only tests whether ``model_root`` is one
    of the scanner's roots; it never inspects the folder itself. A known
    root therefore yields a tree of *all* cached folders, and an unknown
    root yields ``{}``. Confirm whether per-root filtering was intended.

    Args:
        model_root: Root directory to build the tree for; compared by
            equality against ``scanner.get_model_roots()``.

    Returns:
        The nested folder tree, or ``{}`` when ``model_root`` is not a
        known root.
    """
    cache = await self.scanner.get_cached_data()

    # The membership test is the same for every folder, so decide it once
    # rather than re-querying the roots inside the loop.
    if model_root not in self.scanner.get_model_roots():
        return {}

    tree: Dict = {}
    for folder in cache.folders:
        if not folder:
            # An empty/None folder contributes no path components.
            continue
        current_level = tree
        for part in folder.split('/'):
            current_level = current_level.setdefault(part, {})
    return tree
async def get_unified_folder_tree(self) -> Dict:
    """Build one nested dict of folder names from all cached folders.

    Each non-empty cached folder path is used as-is (it is treated as
    already relative), split on ``/`` and merged into a tree in which
    every key is a folder name and every value is the dict of its
    sub-folders (``{}`` for a leaf).

    Returns:
        The merged folder tree; ``{}`` when there are no folders.
    """
    cache = await self.scanner.get_cached_data()

    unified_tree: Dict = {}
    for folder in cache.folders:
        if not folder:  # Skip empty folders
            continue
        current_level = unified_tree
        for part in folder.split('/'):
            current_level = current_level.setdefault(part, {})
    return unified_tree
async def get_model_notes(self, model_name: str) -> Optional[str]:
    """Get the notes stored for a model, looked up by file name.

    Args:
        model_name: Value compared against each cached record's
            ``file_name``.

    Returns:
        The first matching record's ``notes`` (``''`` when it has none),
        or ``None`` when no record matches. Records without a
        ``file_name`` key are skipped instead of raising ``KeyError``.
    """
    cache = await self.scanner.get_cached_data()
    for model in cache.raw_data:
        if model.get('file_name') == model_name:
            return model.get('notes', '')
    return None
async def get_model_preview_url(self, model_name: str) -> Optional[str]:
    """Get the static preview URL for a model, looked up by file name.

    Args:
        model_name: Value compared against each cached record's
            ``file_name``.

    Returns:
        ``config.get_preview_static_url(preview_url)`` for the first
        matching record that has a truthy ``preview_url``; otherwise the
        placeholder ``'/loras_static/images/no-preview.png'``. Records
        without a ``file_name`` key are skipped instead of raising
        ``KeyError``.
    """
    cache = await self.scanner.get_cached_data()
    for model in cache.raw_data:
        if model.get('file_name') != model_name:
            continue
        preview_url = model.get('preview_url')
        if preview_url:
            # Imported lazily, only when a preview actually needs resolving.
            from ..config import config
            return config.get_preview_static_url(preview_url)
    return '/loras_static/images/no-preview.png'
async def get_model_civitai_url(self, model_name: str) -> Dict[str, Optional[str]]:
    """Get the Civitai URL for a model, looked up by file name.

    The first record whose ``file_name`` matches and whose ``civitai``
    data has a truthy ``modelId`` determines the result. Records without
    a ``file_name`` key, or whose ``civitai`` value is missing or
    ``None``, are skipped instead of raising.

    Args:
        model_name: Value compared against each cached record's
            ``file_name``.

    Returns:
        A dict with ``civitai_url``, ``model_id`` and ``version_id``.
        The URL gets a ``?modelVersionId=`` query when the civitai ``id``
        is truthy; all three values are ``None`` when nothing matches.
    """
    cache = await self.scanner.get_cached_data()
    for model in cache.raw_data:
        if model.get('file_name') != model_name:
            continue

        # ``or {}`` also covers a 'civitai' key explicitly stored as None.
        civitai_data = model.get('civitai') or {}
        model_id = civitai_data.get('modelId')
        if not model_id:
            continue

        version_id = civitai_data.get('id')
        civitai_url = f"https://civitai.com/models/{model_id}"
        if version_id:
            civitai_url += f"?modelVersionId={version_id}"
        return {
            'civitai_url': civitai_url,
            'model_id': str(model_id),
            'version_id': str(version_id) if version_id else None,
        }
    return {'civitai_url': None, 'model_id': None, 'version_id': None}
async def get_model_metadata(self, file_path: str) -> Optional[Dict]:
    """Get filtered CivitAI metadata for the model at ``file_path``.

    Returns:
        ``ModelRouteUtils.filter_civitai_data`` applied to the ``civitai``
        value (``{}`` when the key is absent) of the first cached record
        whose ``file_path`` matches, or ``None`` when no record matches.
    """
    cache = await self.scanner.get_cached_data()
    entry = next(
        (record for record in cache.raw_data if record.get('file_path') == file_path),
        None,
    )
    if entry is None:
        return None
    return ModelRouteUtils.filter_civitai_data(entry.get("civitai", {}))
async def get_model_description(self, file_path: str) -> Optional[str]:
    """Get the description of the model at ``file_path``.

    Returns:
        The ``modelDescription`` of the first cached record whose
        ``file_path`` matches (``''`` when the key is absent), or ``None``
        when no record matches.
    """
    cache = await self.scanner.get_cached_data()
    entry = next(
        (record for record in cache.raw_data if record.get('file_path') == file_path),
        None,
    )
    if entry is None:
        return None
    return entry.get('modelDescription', '')
async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]:
    """Search model paths, relative to their root, for autocomplete.

    Each cached record's ``file_path`` is made relative to the first
    model root that contains it; records outside every root, or without
    a ``file_path``, are ignored. A path matches when it contains
    ``search_term`` case-insensitively.

    Collection stops once ``limit * 2`` matches are found, so the sort
    below ranks only that candidate set, not every possible match.

    Args:
        search_term: Substring to look for in the relative path.
        limit: Maximum number of paths to return.

    Returns:
        Up to ``limit`` relative paths, ordered with paths that start
        with the search term first, then shorter paths, then
        alphabetically (case-insensitive).
    """
    cache = await self.scanner.get_cached_data()
    search_lower = search_term.lower()

    # Normalise every root once, and end it with a separator so that a
    # root like "/models/lora" cannot claim "/models/lora_extra/x".
    root_prefixes = []
    for root in self.scanner.get_model_roots():
        normalized_root = os.path.normpath(root)
        if not normalized_root.endswith(os.sep):
            normalized_root += os.sep
        root_prefixes.append(normalized_root)

    matching_paths = []
    for model in cache.raw_data:
        file_path = model.get('file_path', '')
        if not file_path:
            continue

        normalized_file = os.path.normpath(file_path)
        relative_path = next(
            (
                normalized_file[len(prefix):].lstrip(os.sep)
                for prefix in root_prefixes
                if normalized_file.startswith(prefix)
            ),
            None,
        )

        if relative_path and search_lower in relative_path.lower():
            matching_paths.append(relative_path)
            if len(matching_paths) >= limit * 2:  # Gather extra candidates for sorting
                break

    matching_paths.sort(key=lambda path: (
        not path.lower().startswith(search_lower),  # Prefix matches first
        len(path),  # Then shorter paths
        path.lower(),  # Then alphabetically
    ))
    return matching_paths[:limit]