from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Type
import logging
import os

from ..utils.models import BaseModelMetadata
from .model_query import (
    FilterCriteria,
    ModelCacheRepository,
    ModelFilterSet,
    SearchStrategy,
    SettingsProvider,
)
from .settings_manager import settings as default_settings

logger = logging.getLogger(__name__)


class BaseModelService(ABC):
    """Base service class for all model types.

    Wraps a model scanner and composes a cache repository, a filter set and a
    search strategy that are used to list, filter, search and paginate cached
    model entries. Subclasses must implement :meth:`format_response` and may
    override :meth:`_apply_specific_filters`.
    """

    def __init__(
        self,
        model_type: str,
        scanner,
        metadata_class: Type[BaseModelMetadata],
        *,
        cache_repository: Optional[ModelCacheRepository] = None,
        filter_set: Optional[ModelFilterSet] = None,
        search_strategy: Optional[SearchStrategy] = None,
        settings_provider: Optional[SettingsProvider] = None,
    ):
        """Initialize the service.

        Args:
            model_type: Type of model (lora, checkpoint, etc.).
            scanner: Model scanner instance.
            metadata_class: Metadata class for this model type.
            cache_repository: Custom repository for cache access (primarily for tests).
            filter_set: Filter component controlling folder/tag/favorites logic.
            search_strategy: Search component for fuzzy/text matching.
            settings_provider: Settings object; defaults to the global settings manager.
        """
        self.model_type = model_type
        self.scanner = scanner
        self.metadata_class = metadata_class
        # Explicit ``is None`` checks: an injected collaborator must never be
        # silently replaced just because it happens to evaluate as falsy.
        self.settings = default_settings if settings_provider is None else settings_provider
        self.cache_repository = (
            ModelCacheRepository(scanner) if cache_repository is None else cache_repository
        )
        self.filter_set = ModelFilterSet(self.settings) if filter_set is None else filter_set
        self.search_strategy = SearchStrategy() if search_strategy is None else search_strategy

    async def get_paginated_data(
        self,
        page: int,
        page_size: int,
        sort_by: str = 'name',
        folder: Optional[str] = None,
        search: Optional[str] = None,
        fuzzy_search: bool = False,
        base_models: Optional[list] = None,
        tags: Optional[list] = None,
        search_options: Optional[dict] = None,
        hash_filters: Optional[dict] = None,
        favorites_only: bool = False,
        **kwargs,
    ) -> Dict:
        """Get paginated and filtered model data.

        When ``hash_filters`` is non-empty it takes precedence: the sorted data
        is reduced to the matching hashes and paginated directly, skipping the
        folder/tag/favorites, search and model-specific filters. Otherwise the
        common filters, the optional text search and the subclass-specific
        filters (which receive ``**kwargs``) are applied in that order.

        Returns:
            The pagination dict produced by :meth:`_paginate`.
        """
        sort_params = self.cache_repository.parse_sort(sort_by)
        sorted_data = await self.cache_repository.fetch_sorted(sort_params)

        if hash_filters:
            filtered_data = await self._apply_hash_filters(sorted_data, hash_filters)
            return self._paginate(filtered_data, page, page_size)

        filtered_data = await self._apply_common_filters(
            sorted_data,
            folder=folder,
            base_models=base_models,
            tags=tags,
            favorites_only=favorites_only,
            search_options=search_options,
        )

        if search:
            filtered_data = await self._apply_search_filters(
                filtered_data,
                search,
                fuzzy_search,
                search_options,
            )

        filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)

        return self._paginate(filtered_data, page, page_size)

    async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
        """Apply hash-based filtering.

        ``single_hash`` takes precedence over ``multiple_hashes``; matching is
        case-insensitive on each item's ``sha256``. If neither filter is set,
        ``data`` is returned unchanged. Items whose ``sha256`` is missing or
        ``None`` are treated as having an empty hash.
        """
        single_hash = hash_filters.get('single_hash')
        multiple_hashes = hash_filters.get('multiple_hashes')

        if single_hash:
            wanted = {single_hash.lower()}
        elif multiple_hashes:
            wanted = {value.lower() for value in multiple_hashes}
        else:
            return data

        # ``or ''`` guards against entries whose 'sha256' key exists but is None.
        return [item for item in data if (item.get('sha256') or '').lower() in wanted]

    async def _apply_common_filters(
        self,
        data: List[Dict],
        folder: Optional[str] = None,
        base_models: Optional[list] = None,
        tags: Optional[list] = None,
        favorites_only: bool = False,
        search_options: Optional[dict] = None,
    ) -> List[Dict]:
        """Apply common filters that work across all model types.

        Bundles the arguments into a :class:`FilterCriteria` (with search
        options normalized by the search strategy) and delegates to the
        configured filter set.
        """
        normalized_options = self.search_strategy.normalize_options(search_options)
        criteria = FilterCriteria(
            folder=folder,
            base_models=base_models,
            tags=tags,
            favorites_only=favorites_only,
            search_options=normalized_options,
        )
        return self.filter_set.apply(data, criteria)

    async def _apply_search_filters(
        self,
        data: List[Dict],
        search: str,
        fuzzy_search: bool,
        search_options: dict,
    ) -> List[Dict]:
        """Apply search filtering via the configured search strategy."""
        normalized_options = self.search_strategy.normalize_options(search_options)
        return self.search_strategy.apply(data, search, normalized_options, fuzzy_search)

    async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
        """Apply model-specific filters - to be overridden by subclasses if needed."""
        return data

    def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
        """Apply pagination to filtered data.

        Args:
            data: Already filtered and sorted items.
            page: 1-based page number. Pages below 1 contain no items.
            page_size: Number of items per page; must be at least 1.

        Returns:
            Dict with ``items``, ``total``, ``page``, ``page_size`` and ``total_pages``.

        Raises:
            ValueError: If ``page_size`` is less than 1.
        """
        if page_size < 1:
            raise ValueError(f"page_size must be >= 1, got {page_size}")

        total_items = len(data)
        start_idx = (page - 1) * page_size
        # A negative start index would make Python slice from the END of the
        # list and return unrelated items, so pages below 1 are simply empty.
        items = data[start_idx:start_idx + page_size] if start_idx >= 0 else []

        return {
            'items': items,
            'total': total_items,
            'page': page,
            'page_size': page_size,
            # Ceiling division: a partially filled last page still counts.
            'total_pages': (total_items + page_size - 1) // page_size,
        }

    @abstractmethod
    async def format_response(self, model_data: Dict) -> Dict:
        """Format model data for API response - must be implemented by subclasses."""
        pass

    # Common service methods that delegate to scanner

    async def get_top_tags(self, limit: int = 20) -> List[Dict]:
        """Get top tags sorted by frequency (delegates to the scanner)."""
        return await self.scanner.get_top_tags(limit)

    async def get_base_models(self, limit: int = 20) -> List[Dict]:
        """Get base models sorted by frequency (delegates to the scanner)."""
        return await self.scanner.get_base_models(limit)

    def has_hash(self, sha256: str) -> bool:
        """Check if a model with given hash exists."""
        return self.scanner.has_hash(sha256)

    def get_path_by_hash(self, sha256: str) -> Optional[str]:
        """Get file path for a model by its hash."""
        return self.scanner.get_path_by_hash(sha256)

    def get_hash_by_path(self, file_path: str) -> Optional[str]:
        """Get hash for a model by its file path."""
        return self.scanner.get_hash_by_path(file_path)

    async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False):
        """Trigger model scanning by (re)loading the scanner's cached data."""
        return await self.scanner.get_cached_data(
            force_refresh=force_refresh, rebuild_cache=rebuild_cache
        )

    async def get_model_info_by_name(self, name: str):
        """Get model information by name."""
        return await self.scanner.get_model_info_by_name(name)

    def get_model_roots(self) -> List[str]:
        """Get model root directories."""
        return self.scanner.get_model_roots()

    def filter_civitai_data(self, data: Dict, minimal: bool = False) -> Dict:
        """Filter relevant fields from CivitAI data.

        Args:
            data: Raw CivitAI payload; a falsy value (``None``/``{}``) yields ``{}``.
            minimal: Keep only identifiers, name and trained words.

        Returns:
            A new dict containing only the selected keys that are present in ``data``.
        """
        if not data:
            return {}

        fields = ["id", "modelId", "name", "trainedWords"] if minimal else [
            "id", "modelId", "name", "createdAt", "updatedAt",
            "publishedAt", "trainedWords", "baseModel", "description",
            "model", "images", "customImages", "creator"
        ]
        return {k: data[k] for k in fields if k in data}

    @staticmethod
    def _add_folder_to_tree(tree: Dict, folder: str) -> None:
        """Insert the '/'-separated ``folder`` path into the nested dict ``tree``."""
        if not folder:
            return
        current_level = tree
        for part in folder.split('/'):
            current_level = current_level.setdefault(part, {})

    async def get_folder_tree(self, model_root: str) -> Dict:
        """Get hierarchical folder tree for a specific model root.

        Returns a nested dict keyed by folder name, or ``{}`` when
        ``model_root`` is not one of the scanner's model roots.
        """
        cache = await self.scanner.get_cached_data()

        # NOTE(review): membership is decided only by whether ``model_root`` is
        # one of the scanner's roots; individual folders are never compared to
        # it, so for a known root this yields the tree of ALL cached folders.
        # The cached folder entries are plain relative paths, so true per-root
        # filtering needs more data than is available here. The check is
        # loop-invariant, so it is evaluated once.
        if model_root not in self.scanner.get_model_roots():
            return {}

        tree = {}
        for folder in cache.folders:
            self._add_folder_to_tree(tree, folder)
        return tree

    async def get_unified_folder_tree(self) -> Dict:
        """Get unified folder tree across all model roots.

        Cached folders are used as-is (they are already relative paths) and
        merged into a single nested dict; empty folder entries are skipped.
        """
        cache = await self.scanner.get_cached_data()

        unified_tree = {}
        for folder in cache.folders:
            self._add_folder_to_tree(unified_tree, folder)
        return unified_tree

    async def _find_cached_model(self, field: str, value: str) -> Optional[Dict]:
        """Return the first cached model whose ``field`` equals ``value``, else None."""
        cache = await self.scanner.get_cached_data()
        return next((model for model in cache.raw_data if model.get(field) == value), None)

    async def get_model_notes(self, model_name: str) -> Optional[str]:
        """Get notes for a specific model file.

        Returns the model's notes ('' if it has none), or ``None`` when no
        cached model has ``file_name == model_name``.
        """
        model = await self._find_cached_model('file_name', model_name)
        if model is None:
            return None
        return model.get('notes', '')

    async def get_model_preview_url(self, model_name: str) -> str:
        """Get the static preview URL for a model file.

        Uses the first cached model with a matching ``file_name`` that has a
        ``preview_url``; otherwise returns the no-preview placeholder URL.
        """
        cache = await self.scanner.get_cached_data()
        for model in cache.raw_data:
            if model.get('file_name') != model_name:
                continue
            preview_url = model.get('preview_url')
            if preview_url:
                # Local import, kept from the original (presumably to avoid an import cycle).
                from ..config import config
                return config.get_preview_static_url(preview_url)

        return '/loras_static/images/no-preview.png'

    async def get_model_civitai_url(self, model_name: str) -> Dict[str, Optional[str]]:
        """Get the Civitai URL for a model file.

        Uses the first cached model with a matching ``file_name`` whose civitai
        data has a ``modelId``. All values are ``None`` when none is found.
        """
        cache = await self.scanner.get_cached_data()
        for model in cache.raw_data:
            if model.get('file_name') != model_name:
                continue
            # ``or {}``: the 'civitai' key may be present with a None value.
            civitai_data = model.get('civitai') or {}
            model_id = civitai_data.get('modelId')
            version_id = civitai_data.get('id')

            if model_id:
                civitai_url = f"https://civitai.com/models/{model_id}"
                if version_id:
                    civitai_url += f"?modelVersionId={version_id}"

                return {
                    'civitai_url': civitai_url,
                    'model_id': str(model_id),
                    'version_id': str(version_id) if version_id else None
                }

        return {'civitai_url': None, 'model_id': None, 'version_id': None}

    async def get_model_metadata(self, file_path: str) -> Optional[Dict]:
        """Get filtered CivitAI metadata for a model by file path.

        Returns ``None`` when no cached model has the given ``file_path``.
        """
        model = await self._find_cached_model('file_path', file_path)
        if model is None:
            return None
        return self.filter_civitai_data(model.get("civitai", {}))

    async def get_model_description(self, file_path: str) -> Optional[str]:
        """Get model description by file path.

        Returns '' if the model has no description, or ``None`` when no cached
        model has the given ``file_path``.
        """
        model = await self._find_cached_model('file_path', file_path)
        if model is None:
            return None
        return model.get('modelDescription', '')

    async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]:
        """Search model relative file paths for autocomplete functionality.

        Each cached model's file path is made relative to the first model root
        that contains it (as a directory, not merely a string prefix). Relative
        paths containing ``search_term`` (case-insensitive) are collected, up to
        ``2 * limit`` candidates, ranked, and the top ``limit`` are returned.
        """
        cache = await self.scanner.get_cached_data()
        search_lower = search_term.lower()

        # Normalize roots once. A trailing separator makes the prefix test respect
        # directory boundaries, so "/models/lora" does not claim "/models/lora2/x".
        root_prefixes = []
        for root in self.scanner.get_model_roots():
            normalized_root = os.path.normpath(root)
            if not normalized_root.endswith(os.sep):
                normalized_root += os.sep
            root_prefixes.append(normalized_root)

        matching_paths = []
        for model in cache.raw_data:
            file_path = model.get('file_path', '')
            if not file_path:
                continue

            normalized_file = os.path.normpath(file_path)
            relative_path = next(
                (
                    normalized_file[len(prefix):].lstrip(os.sep)
                    for prefix in root_prefixes
                    if normalized_file.startswith(prefix)
                ),
                None,
            )

            if relative_path and search_lower in relative_path.lower():
                matching_paths.append(relative_path)
                # Collect extra candidates so ranking has something to choose from.
                if len(matching_paths) >= limit * 2:
                    break

        # Rank: prefix matches first, then shorter paths, then alphabetically.
        matching_paths.sort(key=lambda x: (
            not x.lower().startswith(search_lower),
            len(x),
            x.lower()
        ))

        return matching_paths[:limit]