From 2d00cfdd319fb9d5750db77e7231cef680fc3c05 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sun, 21 Sep 2025 23:13:30 +0800 Subject: [PATCH 01/24] refactor: enhance BaseModelService initialization and improve filtering logic --- py/services/base_model_service.py | 261 +++++++-------------- py/services/model_query.py | 196 ++++++++++++++++ tests/services/test_base_model_service.py | 269 ++++++++++++++++++++++ 3 files changed, 552 insertions(+), 174 deletions(-) create mode 100644 py/services/model_query.py create mode 100644 tests/services/test_base_model_service.py diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index ed1fc930..0b4aaf99 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -5,98 +5,88 @@ import os from ..utils.models import BaseModelMetadata from ..utils.routes_common import ModelRouteUtils -from ..utils.constants import NSFW_LEVELS -from .settings_manager import settings -from ..utils.utils import fuzzy_match +from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider +from .settings_manager import settings as default_settings logger = logging.getLogger(__name__) class BaseModelService(ABC): """Base service class for all model types""" - def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]): - """Initialize the service - + def __init__( + self, + model_type: str, + scanner, + metadata_class: Type[BaseModelMetadata], + *, + cache_repository: Optional[ModelCacheRepository] = None, + filter_set: Optional[ModelFilterSet] = None, + search_strategy: Optional[SearchStrategy] = None, + settings_provider: Optional[SettingsProvider] = None, + ): + """Initialize the service. + Args: - model_type: Type of model (lora, checkpoint, etc.) - scanner: Model scanner instance - metadata_class: Metadata class for this model type + model_type: Type of model (lora, checkpoint, etc.). + scanner: Model scanner instance. + metadata_class: Metadata class for this model type. + cache_repository: Custom repository for cache access (primarily for tests). + filter_set: Filter component controlling folder/tag/favorites logic. + search_strategy: Search component for fuzzy/text matching. + settings_provider: Settings object; defaults to the global settings manager. """ self.model_type = model_type self.scanner = scanner self.metadata_class = metadata_class + self.settings = settings_provider or default_settings + self.cache_repository = cache_repository or ModelCacheRepository(scanner) + self.filter_set = filter_set or ModelFilterSet(self.settings) + self.search_strategy = search_strategy or SearchStrategy() - async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name', - folder: str = None, search: str = None, fuzzy_search: bool = False, - base_models: list = None, tags: list = None, - search_options: dict = None, hash_filters: dict = None, - favorites_only: bool = False, **kwargs) -> Dict: - """Get paginated and filtered model data - - Args: - page: Page number (1-based) - page_size: Number of items per page - sort_by: Sort criteria, e.g. 'name', 'name:asc', 'name:desc', 'date', 'date:asc', 'date:desc' - folder: Folder filter - search: Search term - fuzzy_search: Whether to use fuzzy search - base_models: List of base models to filter by - tags: List of tags to filter by - search_options: Search options dict - hash_filters: Hash filtering options - favorites_only: Filter for favorites only - **kwargs: Additional model-specific filters - - Returns: - Dict containing paginated results - """ - cache = await self.scanner.get_cached_data() + async def get_paginated_data( + self, + page: int, + page_size: int, + sort_by: str = 'name', + folder: str = None, + search: str = None, + fuzzy_search: bool = False, + base_models: list = None, + tags: list = None, + search_options: dict = None, + hash_filters: dict = None, + favorites_only: bool = False, + **kwargs, + ) -> Dict: + """Get paginated and filtered model data""" + sort_params = self.cache_repository.parse_sort(sort_by) + sorted_data = await self.cache_repository.fetch_sorted(sort_params) - # Parse sort_by into sort_key and order - if ':' in sort_by: - sort_key, order = sort_by.split(':', 1) - sort_key = sort_key.strip() - order = order.strip().lower() - if order not in ('asc', 'desc'): - order = 'asc' - else: - sort_key = sort_by.strip() - order = 'asc' - - # Get default search options if not provided - if search_options is None: - search_options = { - 'filename': True, - 'modelname': True, - 'tags': False, - 'recursive': True, - } - - # Get the base data set using new sort logic - filtered_data = await cache.get_sorted_data(sort_key, order) - - # Apply hash filtering if provided (highest priority) if hash_filters: - filtered_data = await self._apply_hash_filters(filtered_data, hash_filters) - - # Jump to pagination for hash filters + filtered_data = await self._apply_hash_filters(sorted_data, hash_filters) return self._paginate(filtered_data, page, page_size) - - # Apply common filters + filtered_data = await self._apply_common_filters( - filtered_data, folder, base_models, tags, favorites_only, search_options + sorted_data, + folder=folder, + base_models=base_models, + tags=tags, + favorites_only=favorites_only, + search_options=search_options, ) - - # Apply search filtering + if search: filtered_data = await self._apply_search_filters( - filtered_data, search, fuzzy_search, search_options + filtered_data, + search, + fuzzy_search, + search_options, ) - - # Apply model-specific filters + filtered_data = await self._apply_specific_filters(filtered_data, **kwargs) - + return self._paginate(filtered_data, page, page_size) + async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]: """Apply hash-based filtering""" @@ -120,113 +110,36 @@ class BaseModelService(ABC): return data - async def _apply_common_filters(self, data: List[Dict], folder: str = None, - base_models: list = None, tags: list = None, - favorites_only: bool = False, search_options: dict = None) -> List[Dict]: + async def _apply_common_filters( + self, + data: List[Dict], + folder: str = None, + base_models: list = None, + tags: list = None, + favorites_only: bool = False, + search_options: dict = None, + ) -> List[Dict]: """Apply common filters that work across all model types""" - # Apply SFW filtering if enabled in settings - if settings.get('show_only_sfw', False): - data = [ - item for item in data - if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R'] - ] - - # Apply favorites filtering if enabled - if favorites_only: - data = [ - item for item in data - if item.get('favorite', False) is True - ] - - # Apply folder filtering - if folder is not None: - if search_options and search_options.get('recursive', True): - # Recursive folder filtering - include all subfolders - # Ensure we match exact folder or its subfolders by checking path boundaries - if folder == "": - # Empty folder means root - include all items - pass # Don't filter anything - else: - # Add trailing slash to ensure we match folder boundaries correctly - folder_with_separator = folder + "/" - data = [ - item for item in data - if (item['folder'] == folder or - item['folder'].startswith(folder_with_separator)) - ] - else: - # Exact folder filtering - data = [ - item for item in data - if item['folder'] == folder - ] - - # Apply base model filtering - if base_models and len(base_models) > 0: - data = [ - item for item in data - if item.get('base_model') in base_models - ] - - # Apply tag filtering - if tags and len(tags) > 0: - data = [ - item for item in data - if any(tag in item.get('tags', []) for tag in tags) - ] - - return data + normalized_options = self.search_strategy.normalize_options(search_options) + criteria = FilterCriteria( + folder=folder, + base_models=base_models, + tags=tags, + favorites_only=favorites_only, + search_options=normalized_options, + ) + return self.filter_set.apply(data, criteria) - async def _apply_search_filters(self, data: List[Dict], search: str, - fuzzy_search: bool, search_options: dict) -> List[Dict]: + async def _apply_search_filters( + self, + data: List[Dict], + search: str, + fuzzy_search: bool, + search_options: dict, + ) -> List[Dict]: """Apply search filtering""" - search_results = [] - - for item in data: - # Search by file name - if search_options.get('filename', True): - if fuzzy_search: - if fuzzy_match(item.get('file_name', ''), search): - search_results.append(item) - continue - elif search.lower() in item.get('file_name', '').lower(): - search_results.append(item) - continue - - # Search by model name - if search_options.get('modelname', True): - if fuzzy_search: - if fuzzy_match(item.get('model_name', ''), search): - search_results.append(item) - continue - elif search.lower() in item.get('model_name', '').lower(): - search_results.append(item) - continue - - # Search by tags - if search_options.get('tags', False) and 'tags' in item: - if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) - for tag in item['tags']): - search_results.append(item) - continue - - # Search by creator - civitai = item.get('civitai') - creator_username = '' - if civitai and isinstance(civitai, dict): - creator = civitai.get('creator') - if creator and isinstance(creator, dict): - creator_username = creator.get('username', '') - if search_options.get('creator', False) and creator_username: - if fuzzy_search: - if fuzzy_match(creator_username, search): - search_results.append(item) - continue - elif search.lower() in creator_username.lower(): - search_results.append(item) - continue - - return search_results + normalized_options = self.search_strategy.normalize_options(search_options) + return self.search_strategy.apply(data, search, normalized_options, fuzzy_search) async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: """Apply model-specific filters - to be overridden by subclasses if needed""" diff --git a/py/services/model_query.py b/py/services/model_query.py new file mode 100644 index 00000000..08ca652f --- /dev/null +++ b/py/services/model_query.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable + +from ..utils.constants import NSFW_LEVELS +from ..utils.utils import fuzzy_match as default_fuzzy_match + + +class SettingsProvider(Protocol): + """Protocol describing the SettingsManager contract used by query helpers.""" + + def get(self, key: str, default: Any = None) -> Any: + ... + + +@dataclass(frozen=True) +class SortParams: + """Normalized representation of sorting instructions.""" + + key: str + order: str + + +@dataclass(frozen=True) +class FilterCriteria: + """Container for model list filtering options.""" + + folder: Optional[str] = None + base_models: Optional[Sequence[str]] = None + tags: Optional[Sequence[str]] = None + favorites_only: bool = False + search_options: Optional[Dict[str, Any]] = None + + +class ModelCacheRepository: + """Adapter around scanner cache access and sort normalisation.""" + + def __init__(self, scanner) -> None: + self._scanner = scanner + + async def get_cache(self): + """Return the underlying cache instance from the scanner.""" + return await self._scanner.get_cached_data() + + async def fetch_sorted(self, params: SortParams) -> List[Dict[str, Any]]: + """Fetch cached data pre-sorted according to ``params``.""" + cache = await self.get_cache() + return await cache.get_sorted_data(params.key, params.order) + + @staticmethod + def parse_sort(sort_by: str) -> SortParams: + """Parse an incoming sort string into key/order primitives.""" + if not sort_by: + return SortParams(key="name", order="asc") + + if ":" in sort_by: + raw_key, raw_order = sort_by.split(":", 1) + sort_key = raw_key.strip().lower() or "name" + order = raw_order.strip().lower() + else: + sort_key = sort_by.strip().lower() or "name" + order = "asc" + + if order not in ("asc", "desc"): + order = "asc" + + return SortParams(key=sort_key, order=order) + + +class ModelFilterSet: + """Applies common filtering rules to the model collection.""" + + def __init__(self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None) -> None: + self._settings = settings + self._nsfw_levels = nsfw_levels or NSFW_LEVELS + + def apply(self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria) -> List[Dict[str, Any]]: + """Return items that satisfy the provided criteria.""" + items = list(data) + + if self._settings.get("show_only_sfw", False): + threshold = self._nsfw_levels.get("R", 0) + items = [ + item for item in items + if not item.get("preview_nsfw_level") or item.get("preview_nsfw_level") < threshold + ] + + if criteria.favorites_only: + items = [item for item in items if item.get("favorite", False)] + + folder = criteria.folder + options = criteria.search_options or {} + recursive = bool(options.get("recursive", True)) + if folder is not None: + if recursive: + if folder: + folder_with_sep = f"{folder}/" + items = [ + item for item in items + if item.get("folder") == folder or item.get("folder", "").startswith(folder_with_sep) + ] + else: + items = [item for item in items if item.get("folder") == folder] + + base_models = criteria.base_models or [] + if base_models: + base_model_set = set(base_models) + items = [item for item in items if item.get("base_model") in base_model_set] + + tags = criteria.tags or [] + if tags: + tag_set = set(tags) + items = [ + item for item in items + if any(tag in tag_set for tag in item.get("tags", [])) + ] + + return items + + +class SearchStrategy: + """Encapsulates text and fuzzy matching behaviour for model queries.""" + + DEFAULT_OPTIONS: Dict[str, Any] = { + "filename": True, + "modelname": True, + "tags": False, + "recursive": True, + "creator": False, + } + + def __init__(self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None) -> None: + self._fuzzy_match = fuzzy_matcher or default_fuzzy_match + + def normalize_options(self, options: Optional[Dict[str, Any]]) -> Dict[str, Any]: + """Merge provided options with defaults without mutating input.""" + normalized = dict(self.DEFAULT_OPTIONS) + if options: + normalized.update(options) + return normalized + + def apply( + self, + data: Iterable[Dict[str, Any]], + search_term: str, + options: Dict[str, Any], + fuzzy: bool = False, + ) -> List[Dict[str, Any]]: + """Return items matching the search term using the configured strategy.""" + if not search_term: + return list(data) + + search_lower = search_term.lower() + results: List[Dict[str, Any]] = [] + + for item in data: + if options.get("filename", True): + candidate = item.get("file_name", "") + if self._matches(candidate, search_term, search_lower, fuzzy): + results.append(item) + continue + + if options.get("modelname", True): + candidate = item.get("model_name", "") + if self._matches(candidate, search_term, search_lower, fuzzy): + results.append(item) + continue + + if options.get("tags", False): + tags = item.get("tags", []) or [] + if any(self._matches(tag, search_term, search_lower, fuzzy) for tag in tags): + results.append(item) + continue + + if options.get("creator", False): + creator_username = "" + civitai = item.get("civitai") + if isinstance(civitai, dict): + creator = civitai.get("creator") + if isinstance(creator, dict): + creator_username = creator.get("username", "") + if creator_username and self._matches(creator_username, search_term, search_lower, fuzzy): + results.append(item) + continue + + return results + + def _matches(self, candidate: str, search_term: str, search_lower: str, fuzzy: bool) -> bool: + if not candidate: + return False + + candidate_lower = candidate.lower() + if fuzzy: + return self._fuzzy_match(candidate, search_term) + return search_lower in candidate_lower diff --git a/tests/services/test_base_model_service.py b/tests/services/test_base_model_service.py new file mode 100644 index 00000000..4acfcc49 --- /dev/null +++ b/tests/services/test_base_model_service.py @@ -0,0 +1,269 @@ +import pytest + +from py.services.base_model_service import BaseModelService +from py.services.model_query import ModelCacheRepository, ModelFilterSet, SearchStrategy, SortParams +from py.utils.models import BaseModelMetadata + + +class StubSettings: + def __init__(self, values): + self._values = dict(values) + + def get(self, key, default=None): + return self._values.get(key, default) + + +class DummyService(BaseModelService): + async def format_response(self, model_data): + return model_data + + +class StubRepository: + def __init__(self, data): + self._data = list(data) + self.parse_sort_calls = [] + self.fetch_sorted_calls = [] + + def parse_sort(self, sort_by): + params = ModelCacheRepository.parse_sort(sort_by) + self.parse_sort_calls.append(sort_by) + return params + + async def fetch_sorted(self, params): + self.fetch_sorted_calls.append(params) + return list(self._data) + + +class StubFilterSet: + def __init__(self, result): + self.result = list(result) + self.calls = [] + + def apply(self, data, criteria): + self.calls.append((list(data), criteria)) + return list(self.result) + + +class StubSearchStrategy: + def __init__(self, search_result): + self.search_result = list(search_result) + self.normalize_calls = [] + self.apply_calls = [] + + def normalize_options(self, options): + self.normalize_calls.append(options) + normalized = {"recursive": True} + if options: + normalized.update(options) + return normalized + + def apply(self, data, search_term, options, fuzzy): + self.apply_calls.append((list(data), search_term, options, fuzzy)) + return list(self.search_result) + + +@pytest.mark.asyncio +async def test_get_paginated_data_uses_injected_collaborators(): + data = [ + {"model_name": "Alpha", "folder": "root"}, + {"model_name": "Beta", "folder": "root"}, + ] + repository = StubRepository(data) + filter_set = StubFilterSet([{"model_name": "Filtered"}]) + search_strategy = StubSearchStrategy([{"model_name": "SearchResult"}]) + settings = StubSettings({}) + + service = DummyService( + model_type="stub", + scanner=object(), + metadata_class=BaseModelMetadata, + cache_repository=repository, + filter_set=filter_set, + search_strategy=search_strategy, + settings_provider=settings, + ) + + response = await service.get_paginated_data( + page=1, + page_size=5, + sort_by="name:desc", + folder="root", + search="query", + fuzzy_search=True, + base_models=["base"], + tags=["tag"], + search_options={"recursive": False}, + favorites_only=True, + ) + + assert repository.parse_sort_calls == ["name:desc"] + assert repository.fetch_sorted_calls and isinstance(repository.fetch_sorted_calls[0], SortParams) + sort_params = repository.fetch_sorted_calls[0] + assert sort_params.key == "name" and sort_params.order == "desc" + + assert filter_set.calls, "FilterSet should be invoked" + call_data, criteria = filter_set.calls[0] + assert call_data == data + assert criteria.folder == "root" + assert criteria.base_models == ["base"] + assert criteria.tags == ["tag"] + assert criteria.favorites_only is True + assert criteria.search_options.get("recursive") is False + + assert search_strategy.normalize_calls == [{"recursive": False}, {"recursive": False}] + assert search_strategy.apply_calls == [([{"model_name": "Filtered"}], "query", {"recursive": False}, True)] + + assert response["items"] == search_strategy.search_result + assert response["total"] == len(search_strategy.search_result) + assert response["page"] == 1 + assert response["page_size"] == 5 + + +class FakeCache: + def __init__(self, items): + self.items = list(items) + + async def get_sorted_data(self, sort_key, order): + if sort_key == "name": + data = sorted(self.items, key=lambda x: x["model_name"].lower()) + if order == "desc": + data.reverse() + else: + data = list(self.items) + return data + + +class FakeScanner: + def __init__(self, cache): + self._cache = cache + + async def get_cached_data(self, *_, **__): + return self._cache + + +@pytest.mark.asyncio +async def test_get_paginated_data_filters_and_searches_combination(): + items = [ + { + "model_name": "Alpha", + "file_name": "alpha.safetensors", + "folder": "root/sub", + "tags": ["tag1"], + "base_model": "v1", + "favorite": True, + "preview_nsfw_level": 0, + }, + { + "model_name": "Beta", + "file_name": "beta.safetensors", + "folder": "root", + "tags": ["tag2"], + "base_model": "v2", + "favorite": False, + "preview_nsfw_level": 999, + }, + { + "model_name": "Gamma", + "file_name": "gamma.safetensors", + "folder": "root/sub2", + "tags": ["tag1", "tag3"], + "base_model": "v1", + "favorite": True, + "preview_nsfw_level": 0, + "civitai": {"creator": {"username": "artist"}}, + }, + ] + + cache = FakeCache(items) + scanner = FakeScanner(cache) + settings = StubSettings({"show_only_sfw": True}) + + service = DummyService( + model_type="stub", + scanner=scanner, + metadata_class=BaseModelMetadata, + cache_repository=ModelCacheRepository(scanner), + filter_set=ModelFilterSet(settings), + search_strategy=SearchStrategy(), + settings_provider=settings, + ) + + response = await service.get_paginated_data( + page=1, + page_size=1, + sort_by="name:asc", + folder="root", + search="artist", + base_models=["v1"], + tags=["tag1"], + search_options={"creator": True, "tags": True}, + favorites_only=True, + ) + + assert response["items"] == [items[2]] + assert response["total"] == 1 + assert response["page"] == 1 + assert response["page_size"] == 1 + assert response["total_pages"] == 1 + + +class PassThroughFilterSet: + def __init__(self): + self.calls = [] + + def apply(self, data, criteria): + self.calls.append(criteria) + return list(data) + + +class NoSearchStrategy: + def __init__(self): + self.normalize_calls = [] + self.apply_called = False + + def normalize_options(self, options): + self.normalize_calls.append(options) + return {"recursive": True} + + def apply(self, *args, **kwargs): + self.apply_called = True + pytest.fail("Search should not be invoked when no search term is provided") + + +@pytest.mark.asyncio +async def test_get_paginated_data_paginates_without_search(): + items = [ + {"model_name": name, "folder": "root"} + for name in ["Alpha", "Beta", "Gamma", "Delta", "Epsilon"] + ] + + repository = StubRepository(items) + filter_set = PassThroughFilterSet() + search_strategy = NoSearchStrategy() + settings = StubSettings({}) + + service = DummyService( + model_type="stub", + scanner=object(), + metadata_class=BaseModelMetadata, + cache_repository=repository, + filter_set=filter_set, + search_strategy=search_strategy, + settings_provider=settings, + ) + + response = await service.get_paginated_data( + page=2, + page_size=2, + sort_by="name:asc", + ) + + assert repository.parse_sort_calls == ["name:asc"] + assert len(repository.fetch_sorted_calls) == 1 + assert filter_set.calls and filter_set.calls[0].favorites_only is False + assert search_strategy.apply_called is False + assert response["items"] == items[2:4] + assert response["total"] == len(items) + assert response["page"] == 2 + assert response["page_size"] == 2 + assert response["total_pages"] == 3 From 21772feaddad2a6393ac79bc7432f6ae458dea35 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Sun, 21 Sep 2025 23:34:46 +0800 Subject: [PATCH 02/24] refactor(routes): extract route utilities into services --- py/routes/base_model_routes.py | 45 ++- py/routes/handlers/model_handlers.py | 230 +++++++++++- py/services/download_coordinator.py | 100 +++++ py/services/metadata_sync_service.py | 355 ++++++++++++++++++ py/services/preview_asset_service.py | 168 +++++++++ py/services/tag_update_service.py | 47 +++ py/utils/example_images_metadata.py | 36 +- pytest.ini | 3 + tests/conftest.py | 9 + tests/services/test_base_model_service.py | 33 +- tests/services/test_route_support_services.py | 273 ++++++++++++++ 11 files changed, 1269 insertions(+), 30 deletions(-) create mode 100644 py/services/download_coordinator.py create mode 100644 py/services/metadata_sync_service.py create mode 100644 py/services/preview_asset_service.py create mode 100644 py/services/tag_update_service.py create mode 100644 tests/services/test_route_support_services.py diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index 458a5e87..65103ece 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -8,12 +8,20 @@ import jinja2 from aiohttp import web from ..config import config -from ..services.metadata_service import get_default_metadata_provider +from ..services.download_coordinator import DownloadCoordinator +from ..services.downloader import get_downloader +from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider +from ..services.metadata_sync_service import MetadataSyncService from ..services.model_file_service import ModelFileService, ModelMoveService +from ..services.preview_asset_service import PreviewAssetService +from ..services.server_i18n import server_i18n as default_server_i18n +from ..services.service_registry import ServiceRegistry from ..services.settings_manager import settings as default_settings +from ..services.tag_update_service import TagUpdateService from ..services.websocket_manager import ws_manager as default_ws_manager from ..services.websocket_progress_callback import WebSocketProgressCallback -from ..services.server_i18n import server_i18n as default_server_i18n +from ..utils.exif_utils import ExifUtils +from ..utils.metadata_manager import MetadataManager from ..utils.routes_common import ModelRouteUtils from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar from .handlers.model_handlers import ( @@ -64,6 +72,24 @@ class BaseModelRoutes(ABC): self._handler_set: ModelHandlerSet | None = None self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None + self._preview_service = PreviewAssetService( + metadata_manager=MetadataManager, + downloader_factory=get_downloader, + exif_utils=ExifUtils, + ) + self._metadata_sync_service = MetadataSyncService( + metadata_manager=MetadataManager, + preview_service=self._preview_service, + settings=settings_service, + default_metadata_provider_factory=metadata_provider_factory, + metadata_provider_selector=get_metadata_provider, + ) + self._tag_update_service = TagUpdateService(metadata_manager=MetadataManager) + self._download_coordinator = DownloadCoordinator( + ws_manager=self._ws_manager, + download_manager_factory=ServiceRegistry.get_download_manager, + ) + if service is not None: self.attach_service(service) @@ -98,9 +124,19 @@ class BaseModelRoutes(ABC): parse_specific_params=self._parse_specific_params, logger=logger, ) - management = ModelManagementHandler(service=service, logger=logger) + management = ModelManagementHandler( + service=service, + logger=logger, + metadata_sync=self._metadata_sync_service, + preview_service=self._preview_service, + tag_update_service=self._tag_update_service, + ) query = ModelQueryHandler(service=service, logger=logger) - download = ModelDownloadHandler(ws_manager=self._ws_manager, logger=logger) + download = ModelDownloadHandler( + ws_manager=self._ws_manager, + logger=logger, + download_coordinator=self._download_coordinator, + ) civitai = ModelCivitaiHandler( service=service, settings_service=self._settings, @@ -110,6 +146,7 @@ class BaseModelRoutes(ABC): validate_model_type=self._validate_civitai_model_type, expected_model_types=self._get_expected_model_types, find_model_file=self._find_model_file, + metadata_sync=self._metadata_sync_service, ) move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger) auto_organize = ModelAutoOrganizeHandler( diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 66a7123a..5f9eaf3b 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -4,16 +4,23 @@ from __future__ import annotations import asyncio import json import logging +import os from dataclasses import dataclass from typing import Awaitable, Callable, Dict, Iterable, Mapping, Optional from aiohttp import web import jinja2 +from ...config import config +from ...services.download_coordinator import DownloadCoordinator +from ...services.metadata_sync_service import MetadataSyncService from ...services.model_file_service import ModelFileService, ModelMoveService -from ...services.websocket_progress_callback import WebSocketProgressCallback -from ...services.websocket_manager import WebSocketManager +from ...services.preview_asset_service import PreviewAssetService from ...services.settings_manager import SettingsManager +from ...services.tag_update_service import TagUpdateService +from ...services.websocket_manager import WebSocketManager +from ...services.websocket_progress_callback import WebSocketProgressCallback +from ...utils.file_utils import calculate_sha256 from ...utils.routes_common import ModelRouteUtils @@ -168,9 +175,20 @@ class ModelListingHandler: class ModelManagementHandler: """Handle mutation operations on models.""" - def __init__(self, *, service, logger: logging.Logger) -> None: + def __init__( + self, + *, + service, + logger: logging.Logger, + metadata_sync: MetadataSyncService, + preview_service: PreviewAssetService, + tag_update_service: TagUpdateService, + ) -> None: self._service = service self._logger = logger + self._metadata_sync = metadata_sync + self._preview_service = preview_service + self._tag_update_service = tag_update_service async def delete_model(self, request: web.Request) -> web.Response: return await ModelRouteUtils.handle_delete_model(request, self._service.scanner) @@ -192,7 +210,7 @@ class ModelManagementHandler: if not model_data.get("sha256"): return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400) - success, error = await ModelRouteUtils.fetch_and_update_model( + success, error = await self._metadata_sync.fetch_and_update_model( sha256=model_data["sha256"], file_path=file_path, model_data=model_data, @@ -208,16 +226,144 @@ class ModelManagementHandler: return web.json_response({"success": False, "error": str(exc)}, status=500) async def relink_civitai(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_relink_civitai(request, self._service.scanner) + try: + data = await request.json() + file_path = data.get("file_path") + model_id = data.get("model_id") + model_version_id = data.get("model_version_id") + + if not file_path or model_id is None: + return web.json_response( + {"success": False, "error": "Both file_path and model_id are required"}, + status=400, + ) + + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" + local_metadata = await self._metadata_sync.load_local_metadata(metadata_path) + + updated_metadata = await self._metadata_sync.relink_metadata( + file_path=file_path, + metadata=local_metadata, + model_id=int(model_id), + model_version_id=int(model_version_id) if model_version_id else None, + ) + + await self._service.scanner.update_single_model_cache( + file_path, file_path, updated_metadata + ) + + message = ( + f"Model successfully re-linked to Civitai model {model_id}" + + (f" version {model_version_id}" if model_version_id else "") + ) + return web.json_response( + {"success": True, "message": message, "hash": updated_metadata.get("sha256", "")} + ) + except Exception as exc: + self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) async def replace_preview(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_replace_preview(request, self._service.scanner) + try: + reader = await request.multipart() + + field = await reader.next() + if field is None or field.name != "preview_file": + raise ValueError("Expected 'preview_file' field") + content_type = field.headers.get("Content-Type", "image/png") + content_disposition = field.headers.get("Content-Disposition", "") + + original_filename = None + import re + + match = re.search(r'filename="(.*?)"', content_disposition) + if match: + original_filename = match.group(1) + + preview_data = await field.read() + + field = await reader.next() + if field is None or field.name != "model_path": + raise ValueError("Expected 'model_path' field") + model_path = (await field.read()).decode() + + nsfw_level = 0 + field = await reader.next() + if field and field.name == "nsfw_level": + try: + nsfw_level = int((await field.read()).decode()) + except (ValueError, TypeError): + self._logger.warning("Invalid NSFW level format, using default 0") + + result = await self._preview_service.replace_preview( + model_path=model_path, + preview_data=preview_data, + content_type=content_type, + original_filename=original_filename, + nsfw_level=nsfw_level, + update_preview_in_cache=self._service.scanner.update_preview_in_cache, + metadata_loader=self._metadata_sync.load_local_metadata, + ) + + return web.json_response( + { + "success": True, + "preview_url": config.get_preview_static_url(result["preview_path"]), + "preview_nsfw_level": result["preview_nsfw_level"], + } + ) + except Exception as exc: + self._logger.error("Error replacing preview: %s", exc, exc_info=True) + return web.Response(text=str(exc), status=500) async def save_metadata(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_save_metadata(request, self._service.scanner) + try: + data = await request.json() + file_path = data.get("file_path") + if not file_path: + return web.Response(text="File path is required", status=400) + + metadata_updates = {k: v for k, v in data.items() if k != "file_path"} + + await self._metadata_sync.save_metadata_updates( + file_path=file_path, + updates=metadata_updates, + metadata_loader=self._metadata_sync.load_local_metadata, + update_cache=self._service.scanner.update_single_model_cache, + ) + + if "model_name" in metadata_updates: + cache = await self._service.scanner.get_cached_data() + await cache.resort() + + return web.json_response({"success": True}) + except Exception as exc: + self._logger.error("Error saving metadata: %s", exc, exc_info=True) + return web.Response(text=str(exc), status=500) async def add_tags(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_add_tags(request, self._service.scanner) + try: + data = await request.json() + file_path = data.get("file_path") + new_tags = data.get("tags", []) + + if not file_path: + return web.Response(text="File path is required", status=400) + + if not isinstance(new_tags, list): + return web.Response(text="Tags must be a list", status=400) + + tags = await self._tag_update_service.add_tags( + file_path=file_path, + new_tags=new_tags, + metadata_loader=self._metadata_sync.load_local_metadata, + update_cache=self._service.scanner.update_single_model_cache, + ) + + return web.json_response({"success": True, "tags": tags}) + except Exception as exc: + self._logger.error("Error adding tags: %s", exc, exc_info=True) + return web.Response(text=str(exc), status=500) async def rename_model(self, request: web.Request) -> web.Response: return await ModelRouteUtils.handle_rename_model(request, self._service.scanner) @@ -226,7 +372,27 @@ class ModelManagementHandler: return await ModelRouteUtils.handle_bulk_delete_models(request, self._service.scanner) async def verify_duplicates(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_verify_duplicates(request, self._service.scanner) + try: + data = await request.json() + file_paths = data.get("file_paths", []) + + if not file_paths: + return web.json_response( + {"success": False, "error": "No file paths provided for verification"}, + status=400, + ) + + results = await self._metadata_sync.verify_duplicate_hashes( + file_paths=file_paths, + metadata_loader=self._metadata_sync.load_local_metadata, + hash_calculator=calculate_sha256, + update_cache=self._service.scanner.update_single_model_cache, + ) + + return web.json_response({"success": True, **results}) + except Exception as exc: + self._logger.error("Error verifying duplicate models: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) class ModelQueryHandler: @@ -429,12 +595,39 @@ class ModelQueryHandler: class ModelDownloadHandler: """Coordinate downloads and progress reporting.""" - def __init__(self, *, ws_manager: WebSocketManager, logger: logging.Logger) -> None: + def __init__( + self, + *, + ws_manager: WebSocketManager, + logger: logging.Logger, + download_coordinator: DownloadCoordinator, + ) -> None: self._ws_manager = ws_manager self._logger = logger + self._download_coordinator = download_coordinator async def download_model(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_download_model(request) + try: + payload = await request.json() + result = await self._download_coordinator.schedule_download(payload) + if not result.get("success", False): + return web.json_response(result, status=500) + return web.json_response(result) + except ValueError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except Exception as exc: + error_message = str(exc) + if "401" in error_message: + self._logger.warning("Early access error (401): %s", error_message) + return web.json_response( + { + "success": False, + "error": "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com.", + }, + status=401, + ) + self._logger.error("Error downloading model: %s", error_message) + return web.json_response({"success": False, "error": error_message}, status=500) async def download_model_get(self, request: web.Request) -> web.Response: try: @@ -460,7 +653,12 @@ class ModelDownloadHandler: future.set_result(data) mock_request = type("MockRequest", (), {"json": lambda self=None: future})() - return await ModelRouteUtils.handle_download_model(mock_request) + result = await self._download_coordinator.schedule_download(data) + if not result.get("success", False): + return web.json_response(result, status=500) + return web.json_response(result) + except ValueError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) except Exception as exc: self._logger.error("Error downloading model via GET: %s", exc, exc_info=True) return web.Response(status=500, text=str(exc)) @@ -470,8 +668,8 @@ class ModelDownloadHandler: download_id = request.query.get("download_id") if not download_id: return web.json_response({"success": False, "error": "Download ID is required"}, status=400) - mock_request = type("MockRequest", (), {"match_info": {"download_id": download_id}})() - return await ModelRouteUtils.handle_cancel_download(mock_request) + result = await self._download_coordinator.cancel_download(download_id) + return web.json_response(result) except Exception as exc: self._logger.error("Error cancelling download via GET: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) @@ -504,6 +702,7 @@ class ModelCivitaiHandler: validate_model_type: Callable[[str], bool], expected_model_types: Callable[[], str], find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]], + metadata_sync: MetadataSyncService, ) -> None: self._service = service self._settings = settings_service @@ -513,6 +712,7 @@ class ModelCivitaiHandler: self._validate_model_type = validate_model_type self._expected_model_types = expected_model_types self._find_model_file = find_model_file + self._metadata_sync = metadata_sync async def fetch_all_civitai(self, request: web.Request) -> web.Response: try: @@ -545,7 +745,7 @@ class ModelCivitaiHandler: for model in to_process: try: original_name = model.get("model_name") - result, error = await ModelRouteUtils.fetch_and_update_model( + result, error = await self._metadata_sync.fetch_and_update_model( sha256=model["sha256"], file_path=model["file_path"], model_data=model, diff --git a/py/services/download_coordinator.py b/py/services/download_coordinator.py new file mode 100644 index 00000000..4cf866e5 --- /dev/null +++ b/py/services/download_coordinator.py @@ -0,0 +1,100 @@ +"""Service wrapper for coordinating download lifecycle events.""" + +from __future__ import annotations + +import logging +from typing import Any, Awaitable, Callable, Dict, Optional + + +logger = logging.getLogger(__name__) + + +class DownloadCoordinator: + """Manage download scheduling, cancellation and introspection.""" + + def __init__( + self, + *, + ws_manager, + download_manager_factory: Callable[[], Awaitable], + ) -> None: + self._ws_manager = ws_manager + self._download_manager_factory = download_manager_factory + + async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """Schedule a download using the provided payload.""" + + download_manager = await self._download_manager_factory() + + download_id = payload.get("download_id") or self._ws_manager.generate_download_id() + payload.setdefault("download_id", download_id) + + async def progress_callback(progress: Any) -> None: + await self._ws_manager.broadcast_download_progress( + download_id, + { + "status": "progress", + "progress": progress, + "download_id": download_id, + }, + ) + + model_id = self._parse_optional_int(payload.get("model_id"), "model_id") + model_version_id = self._parse_optional_int( + payload.get("model_version_id"), "model_version_id" + ) + + if model_id is None and model_version_id is None: + raise ValueError( + "Missing required parameter: Please provide either 'model_id' or 'model_version_id'" + ) + + result = await download_manager.download_from_civitai( + model_id=model_id, + model_version_id=model_version_id, + save_dir=payload.get("model_root"), + relative_path=payload.get("relative_path", ""), + use_default_paths=payload.get("use_default_paths", False), + progress_callback=progress_callback, + download_id=download_id, + source=payload.get("source"), + ) + + result["download_id"] = download_id + return result + + async def cancel_download(self, download_id: str) -> Dict[str, Any]: + """Cancel an active download and emit a broadcast event.""" + + download_manager = await self._download_manager_factory() + result = await download_manager.cancel_download(download_id) + + await self._ws_manager.broadcast_download_progress( + download_id, + { + "status": "cancelled", + "progress": 0, + "download_id": download_id, + "message": "Download cancelled by user", + }, + ) + + return result + + async def list_active_downloads(self) -> Dict[str, Any]: + """Return the active download map from the underlying manager.""" + + download_manager = await self._download_manager_factory() + return await download_manager.get_active_downloads() + + def _parse_optional_int(self, value: Any, field: str) -> Optional[int]: + """Parse an optional integer from user input.""" + + if value is None or value == "": + return None + + try: + return int(value) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid {field}: Must be an integer") from exc + diff --git a/py/services/metadata_sync_service.py b/py/services/metadata_sync_service.py new file mode 100644 index 00000000..aaf2f248 --- /dev/null +++ b/py/services/metadata_sync_service.py @@ -0,0 +1,355 @@ +"""Services for synchronising metadata with remote providers.""" + +from __future__ import annotations + +import json +import logging +import os +from datetime import datetime +from typing import Any, Awaitable, Callable, Dict, Iterable, Optional + +from ..services.settings_manager import SettingsManager +from ..utils.model_utils import determine_base_model + +logger = logging.getLogger(__name__) + + +class MetadataProviderProtocol: + """Subset of metadata provider interface consumed by the sync service.""" + + async def get_model_by_hash(self, sha256: str) -> tuple[Optional[Dict[str, Any]], Optional[str]]: + ... + + async def get_model_version( + self, model_id: int, model_version_id: Optional[int] + ) -> Optional[Dict[str, Any]]: + ... + + +class MetadataSyncService: + """High level orchestration for metadata synchronisation flows.""" + + def __init__( + self, + *, + metadata_manager, + preview_service, + settings: SettingsManager, + default_metadata_provider_factory: Callable[[], Awaitable[MetadataProviderProtocol]], + metadata_provider_selector: Callable[[str], Awaitable[MetadataProviderProtocol]], + ) -> None: + self._metadata_manager = metadata_manager + self._preview_service = preview_service + self._settings = settings + self._get_default_provider = default_metadata_provider_factory + self._get_provider = metadata_provider_selector + + async def load_local_metadata(self, metadata_path: str) -> Dict[str, Any]: + """Load metadata JSON from disk, returning an empty structure when missing.""" + + if not os.path.exists(metadata_path): + return {} + + try: + with open(metadata_path, "r", encoding="utf-8") as handle: + return json.load(handle) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Error loading metadata from %s: %s", metadata_path, exc) + return {} + + async def mark_not_found_on_civitai( + self, metadata_path: str, local_metadata: Dict[str, Any] + ) -> None: + """Persist the not-found flag for a metadata payload.""" + + local_metadata["from_civitai"] = False + await self._metadata_manager.save_metadata(metadata_path, local_metadata) + + @staticmethod + def is_civitai_api_metadata(meta: Dict[str, Any]) -> bool: + """Determine if the metadata originated from the CivitAI public API.""" + + if not isinstance(meta, dict): + return False + files = meta.get("files") + images = meta.get("images") + source = meta.get("source") + return bool(files) and bool(images) and source != "archive_db" + + async def update_model_metadata( + self, + metadata_path: str, + local_metadata: Dict[str, Any], + civitai_metadata: Dict[str, Any], + metadata_provider: Optional[MetadataProviderProtocol] = None, + ) -> Dict[str, Any]: + """Merge remote metadata into the local record and persist the result.""" + + existing_civitai = local_metadata.get("civitai") or {} + + if ( + civitai_metadata.get("source") == "archive_db" + and self.is_civitai_api_metadata(existing_civitai) + ): + logger.info( + "Skip civitai update for %s (%s)", + local_metadata.get("model_name", ""), + existing_civitai.get("name", ""), + ) + else: + merged_civitai = existing_civitai.copy() + merged_civitai.update(civitai_metadata) + + if civitai_metadata.get("source") == "archive_db": + model_name = civitai_metadata.get("model", {}).get("name", "") + version_name = civitai_metadata.get("name", "") + logger.info( + "Recovered metadata from archive_db for deleted model: %s (%s)", + model_name, + version_name, + ) + + if "trainedWords" in existing_civitai: + existing_trained = existing_civitai.get("trainedWords", []) + new_trained = civitai_metadata.get("trainedWords", []) + merged_trained = list(set(existing_trained + new_trained)) + merged_civitai["trainedWords"] = merged_trained + + local_metadata["civitai"] = merged_civitai + + if "model" in civitai_metadata and civitai_metadata["model"]: + model_data = civitai_metadata["model"] + + if model_data.get("name"): + local_metadata["model_name"] = model_data["name"] + + if not local_metadata.get("modelDescription") and model_data.get("description"): + local_metadata["modelDescription"] = model_data["description"] + + if not local_metadata.get("tags") and model_data.get("tags"): + local_metadata["tags"] = model_data["tags"] + + if model_data.get("creator") and not local_metadata.get("civitai", {}).get( + "creator" + ): + local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"] + + local_metadata["base_model"] = determine_base_model( + civitai_metadata.get("baseModel") + ) + + await self._preview_service.ensure_preview_for_metadata( + metadata_path, local_metadata, civitai_metadata.get("images", []) + ) + + await self._metadata_manager.save_metadata(metadata_path, local_metadata) + return local_metadata + + async def fetch_and_update_model( + self, + *, + sha256: str, + file_path: str, + model_data: Dict[str, Any], + update_cache_func: Callable[[str, str, Dict[str, Any]], Awaitable[bool]], + ) -> tuple[bool, Optional[str]]: + """Fetch metadata for a model and update both disk and cache state.""" + + if not isinstance(model_data, dict): + error = f"Invalid model_data type: {type(model_data)}" + logger.error(error) + return False, error + + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" + enable_archive = self._settings.get("enable_metadata_archive_db", False) + + try: + if model_data.get("civitai_deleted") is True: + if not enable_archive or model_data.get("db_checked") is True: + return ( + False, + "CivitAI model is deleted and metadata archive DB is not enabled", + ) + metadata_provider = await self._get_provider("sqlite") + else: + metadata_provider = await self._get_default_provider() + + civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256) + if not civitai_metadata: + if error == "Model not found": + model_data["from_civitai"] = False + model_data["civitai_deleted"] = True + model_data["db_checked"] = enable_archive + model_data["last_checked_at"] = datetime.now().timestamp() + + data_to_save = model_data.copy() + data_to_save.pop("folder", None) + await self._metadata_manager.save_metadata(file_path, data_to_save) + + error_msg = ( + f"Error fetching metadata: {error} (model_name={model_data.get('model_name', '')})" + ) + logger.error(error_msg) + return False, error_msg + + model_data["from_civitai"] = True + model_data["civitai_deleted"] = civitai_metadata.get("source") == "archive_db" + model_data["db_checked"] = enable_archive + model_data["last_checked_at"] = datetime.now().timestamp() + + local_metadata = model_data.copy() + local_metadata.pop("folder", None) + + await self.update_model_metadata( + metadata_path, + local_metadata, + civitai_metadata, + metadata_provider, + ) + + update_payload = { + "model_name": local_metadata.get("model_name"), + "preview_url": local_metadata.get("preview_url"), + "civitai": local_metadata.get("civitai"), + } + model_data.update(update_payload) + + await update_cache_func(file_path, file_path, local_metadata) + return True, None + except KeyError as exc: + error_msg = f"Error fetching metadata - Missing key: {exc} in model_data={model_data}" + logger.error(error_msg) + return False, error_msg + except Exception as exc: # pragma: no cover - error path + error_msg = f"Error fetching metadata: {exc}" + logger.error(error_msg, exc_info=True) + return False, error_msg + + async def fetch_metadata_by_sha( + self, sha256: str, metadata_provider: Optional[MetadataProviderProtocol] = None + ) -> tuple[Optional[Dict[str, Any]], Optional[str]]: + """Fetch metadata for a SHA256 hash from the configured provider.""" + + provider = metadata_provider or await self._get_default_provider() + return await provider.get_model_by_hash(sha256) + + async def relink_metadata( + self, + *, + file_path: str, + metadata: Dict[str, Any], + model_id: int, + model_version_id: Optional[int], + ) -> Dict[str, Any]: + """Relink a local metadata record to a specific CivitAI model version.""" + + provider = await self._get_default_provider() + civitai_metadata = await provider.get_model_version(model_id, model_version_id) + if not civitai_metadata: + raise ValueError( + f"Model version not found on CivitAI for ID: {model_id}" + + (f" with version: {model_version_id}" if model_version_id else "") + ) + + primary_model_file: Optional[Dict[str, Any]] = None + for file_info in civitai_metadata.get("files", []): + if file_info.get("primary", False) and file_info.get("type") == "Model": + primary_model_file = file_info + break + + if primary_model_file and primary_model_file.get("hashes", {}).get("SHA256"): + metadata["sha256"] = primary_model_file["hashes"]["SHA256"].lower() + + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" + await self.update_model_metadata( + metadata_path, + metadata, + civitai_metadata, + provider, + ) + + return metadata + + async def save_metadata_updates( + self, + *, + file_path: str, + updates: Dict[str, Any], + metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]], + update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]], + ) -> Dict[str, Any]: + """Apply metadata updates and persist to disk and cache.""" + + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" + metadata = await metadata_loader(metadata_path) + + for key, value in updates.items(): + if isinstance(value, dict) and isinstance(metadata.get(key), dict): + metadata[key].update(value) + else: + metadata[key] = value + + await self._metadata_manager.save_metadata(file_path, metadata) + await update_cache(file_path, file_path, metadata) + + if "model_name" in updates: + logger.debug("Metadata update touched model_name; cache resort required") + + return metadata + + async def verify_duplicate_hashes( + self, + *, + file_paths: Iterable[str], + metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]], + hash_calculator: Callable[[str], Awaitable[str]], + update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]], + ) -> Dict[str, Any]: + """Verify a collection of files share the same SHA256 hash.""" + + file_paths = list(file_paths) + if not file_paths: + raise ValueError("No file paths provided for verification") + + results = { + "verified_as_duplicates": True, + "mismatched_files": [], + "new_hash_map": {}, + } + + expected_hash: Optional[str] = None + first_metadata_path = os.path.splitext(file_paths[0])[0] + ".metadata.json" + first_metadata = await metadata_loader(first_metadata_path) + if first_metadata and "sha256" in first_metadata: + expected_hash = first_metadata["sha256"].lower() + + for path in file_paths: + if not os.path.exists(path): + continue + + try: + actual_hash = await hash_calculator(path) + metadata_path = os.path.splitext(path)[0] + ".metadata.json" + metadata = await metadata_loader(metadata_path) + stored_hash = metadata.get("sha256", "").lower() + + if not expected_hash: + expected_hash = stored_hash + + if actual_hash != expected_hash: + results["verified_as_duplicates"] = False + results["mismatched_files"].append(path) + results["new_hash_map"][path] = actual_hash + + if actual_hash != stored_hash: + metadata["sha256"] = actual_hash + await self._metadata_manager.save_metadata(path, metadata) + await update_cache(path, path, metadata) + except Exception as exc: # pragma: no cover - defensive path + logger.error("Error verifying hash for %s: %s", path, exc) + results["mismatched_files"].append(path) + results["new_hash_map"][path] = "error_calculating_hash" + results["verified_as_duplicates"] = False + + return results + diff --git a/py/services/preview_asset_service.py b/py/services/preview_asset_service.py new file mode 100644 index 00000000..42baadac --- /dev/null +++ b/py/services/preview_asset_service.py @@ -0,0 +1,168 @@ +"""Service for processing preview assets for models.""" + +from __future__ import annotations + +import logging +import os +from typing import Awaitable, Callable, Dict, Optional, Sequence + +from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS + +logger = logging.getLogger(__name__) + + +class PreviewAssetService: + """Manage fetching and persisting preview assets.""" + + def __init__( + self, + *, + metadata_manager, + downloader_factory: Callable[[], Awaitable], + exif_utils, + ) -> None: + self._metadata_manager = metadata_manager + self._downloader_factory = downloader_factory + self._exif_utils = exif_utils + + async def ensure_preview_for_metadata( + self, + metadata_path: str, + local_metadata: Dict[str, object], + images: Sequence[Dict[str, object]] | None, + ) -> None: + """Ensure preview assets exist for the supplied metadata entry.""" + + if local_metadata.get("preview_url") and os.path.exists( + str(local_metadata["preview_url"]) + ): + return + + if not images: + return + + first_preview = images[0] + base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] + preview_dir = os.path.dirname(metadata_path) + is_video = first_preview.get("type") == "video" + + if is_video: + extension = ".mp4" + preview_path = os.path.join(preview_dir, base_name + extension) + downloader = await self._downloader_factory() + success, result = await downloader.download_file( + first_preview["url"], preview_path, use_auth=False + ) + if success: + local_metadata["preview_url"] = preview_path.replace(os.sep, "/") + local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0) + else: + extension = ".webp" + preview_path = os.path.join(preview_dir, base_name + extension) + downloader = await self._downloader_factory() + success, content, _headers = await downloader.download_to_memory( + first_preview["url"], use_auth=False + ) + if not success: + return + + try: + optimized_data, _ = self._exif_utils.optimize_image( + image_data=content, + target_width=CARD_PREVIEW_WIDTH, + format="webp", + quality=85, + preserve_metadata=False, + ) + with open(preview_path, "wb") as handle: + handle.write(optimized_data) + except Exception as exc: # pragma: no cover - defensive path + logger.error("Error optimizing preview image: %s", exc) + try: + with open(preview_path, "wb") as handle: + handle.write(content) + except Exception as save_exc: + logger.error("Error saving preview image: %s", save_exc) + return + + local_metadata["preview_url"] = preview_path.replace(os.sep, "/") + local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0) + + async def replace_preview( + self, + *, + model_path: str, + preview_data: bytes, + content_type: str, + original_filename: Optional[str], + nsfw_level: int, + update_preview_in_cache: Callable[[str, str, int], Awaitable[bool]], + metadata_loader: Callable[[str], Awaitable[Dict[str, object]]], + ) -> Dict[str, object]: + """Replace an existing preview asset for a model.""" + + base_name = os.path.splitext(os.path.basename(model_path))[0] + folder = os.path.dirname(model_path) + + extension, optimized_data = await self._convert_preview( + preview_data, content_type, original_filename + ) + + for ext in PREVIEW_EXTENSIONS: + existing_preview = os.path.join(folder, base_name + ext) + if os.path.exists(existing_preview): + try: + os.remove(existing_preview) + except Exception as exc: # pragma: no cover - defensive path + logger.warning( + "Failed to delete existing preview %s: %s", existing_preview, exc + ) + + preview_path = os.path.join(folder, base_name + extension).replace(os.sep, "/") + with open(preview_path, "wb") as handle: + handle.write(optimized_data) + + metadata_path = os.path.splitext(model_path)[0] + ".metadata.json" + metadata = await metadata_loader(metadata_path) + metadata["preview_url"] = preview_path + metadata["preview_nsfw_level"] = nsfw_level + await self._metadata_manager.save_metadata(model_path, metadata) + + await update_preview_in_cache(model_path, preview_path, nsfw_level) + + return {"preview_path": preview_path, "preview_nsfw_level": nsfw_level} + + async def _convert_preview( + self, data: bytes, content_type: str, original_filename: Optional[str] + ) -> tuple[str, bytes]: + """Convert preview bytes to the persisted representation.""" + + if content_type.startswith("video/"): + extension = self._resolve_video_extension(content_type, original_filename) + return extension, data + + original_ext = (original_filename or "").lower() + if original_ext.endswith(".gif") or content_type.lower() == "image/gif": + return ".gif", data + + optimized_data, _ = self._exif_utils.optimize_image( + image_data=data, + target_width=CARD_PREVIEW_WIDTH, + format="webp", + quality=85, + preserve_metadata=False, + ) + return ".webp", optimized_data + + def _resolve_video_extension(self, content_type: str, original_filename: Optional[str]) -> str: + """Infer the best extension for a video preview.""" + + if original_filename: + extension = os.path.splitext(original_filename)[1].lower() + if extension in {".mp4", ".webm", ".mov", ".avi"}: + return extension + + if "webm" in content_type: + return ".webm" + return ".mp4" + diff --git a/py/services/tag_update_service.py b/py/services/tag_update_service.py new file mode 100644 index 00000000..d560e7d6 --- /dev/null +++ b/py/services/tag_update_service.py @@ -0,0 +1,47 @@ +"""Service for updating tag collections on metadata records.""" + +from __future__ import annotations + +import os + +from typing import Awaitable, Callable, Dict, List, Sequence + + +class TagUpdateService: + """Encapsulate tag manipulation for models.""" + + def __init__(self, *, metadata_manager) -> None: + self._metadata_manager = metadata_manager + + async def add_tags( + self, + *, + file_path: str, + new_tags: Sequence[str], + metadata_loader: Callable[[str], Awaitable[Dict[str, object]]], + update_cache: Callable[[str, str, Dict[str, object]], Awaitable[bool]], + ) -> List[str]: + """Add tags to a metadata entry while keeping case-insensitive uniqueness.""" + + base, _ = os.path.splitext(file_path) + metadata_path = f"{base}.metadata.json" + metadata = await metadata_loader(metadata_path) + + existing_tags = list(metadata.get("tags", [])) + existing_lower = [tag.lower() for tag in existing_tags] + + tags_added: List[str] = [] + for tag in new_tags: + if isinstance(tag, str) and tag.strip(): + normalized = tag.strip() + if normalized.lower() not in existing_lower: + existing_tags.append(normalized) + existing_lower.append(normalized.lower()) + tags_added.append(normalized) + + metadata["tags"] = existing_tags + await self._metadata_manager.save_metadata(file_path, metadata) + await update_cache(file_path, file_path, metadata) + + return existing_tags + diff --git a/py/utils/example_images_metadata.py b/py/utils/example_images_metadata.py index 71566bff..66db05a3 100644 --- a/py/utils/example_images_metadata.py +++ b/py/utils/example_images_metadata.py @@ -1,14 +1,34 @@ import logging import os import re -from ..utils.metadata_manager import MetadataManager -from ..utils.routes_common import ModelRouteUtils + +from ..recipes.constants import GEN_PARAM_KEYS +from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider +from ..services.metadata_sync_service import MetadataSyncService +from ..services.preview_asset_service import PreviewAssetService +from ..services.settings_manager import settings +from ..services.downloader import get_downloader from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS from ..utils.exif_utils import ExifUtils -from ..recipes.constants import GEN_PARAM_KEYS +from ..utils.metadata_manager import MetadataManager logger = logging.getLogger(__name__) +_preview_service = PreviewAssetService( + metadata_manager=MetadataManager, + downloader_factory=get_downloader, + exif_utils=ExifUtils, +) + +_metadata_sync_service = MetadataSyncService( + metadata_manager=MetadataManager, + preview_service=_preview_service, + settings=settings, + default_metadata_provider_factory=get_default_metadata_provider, + metadata_provider_selector=get_metadata_provider, +) + + class MetadataUpdater: """Handles updating model metadata related to example images""" @@ -53,11 +73,11 @@ class MetadataUpdater: async def update_cache_func(old_path, new_path, metadata): return await scanner.update_single_model_cache(old_path, new_path, metadata) - success, error = await ModelRouteUtils.fetch_and_update_model( - model_hash, - file_path, - model_data, - update_cache_func + success, error = await _metadata_sync_service.fetch_and_update_model( + sha256=model_hash, + file_path=file_path, + model_data=model_data, + update_cache_func=update_cache_func, ) if success: diff --git a/pytest.ini b/pytest.ini index 44f4dc04..6f82885c 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,5 +4,8 @@ testpaths = tests python_files = test_*.py python_classes = Test* python_functions = test_* +# Register async marker for coroutine-style tests +markers = + asyncio: execute test within asyncio event loop # Skip problematic directories to avoid import conflicts norecursedirs = .git .tox dist build *.egg __pycache__ py \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index dfe99691..818aeb9b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ import types from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Sequence +import asyncio +import inspect from unittest import mock import sys @@ -39,6 +41,13 @@ nodes_mock.NODE_CLASS_MAPPINGS = {} sys.modules['nodes'] = nodes_mock +def pytest_pyfunc_call(pyfuncitem): + if inspect.iscoroutinefunction(pyfuncitem.function): + asyncio.run(pyfuncitem.obj(**pyfuncitem.funcargs)) + return True + return None + + @dataclass class MockHashIndex: """Minimal hash index stub mirroring the scanner contract.""" diff --git a/tests/services/test_base_model_service.py b/tests/services/test_base_model_service.py index 4acfcc49..c3fdc884 100644 --- a/tests/services/test_base_model_service.py +++ b/tests/services/test_base_model_service.py @@ -1,8 +1,35 @@ import pytest -from py.services.base_model_service import BaseModelService -from py.services.model_query import ModelCacheRepository, ModelFilterSet, SearchStrategy, SortParams -from py.utils.models import BaseModelMetadata +import importlib +import importlib.util +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[2] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + + +def import_from(module_name: str): + existing = sys.modules.get("py") + if existing is None or getattr(existing, "__file__", "") != str(ROOT / "py/__init__.py"): + sys.modules.pop("py", None) + spec = importlib.util.spec_from_file_location("py", ROOT / "py/__init__.py") + module = importlib.util.module_from_spec(spec) + assert spec and spec.loader + spec.loader.exec_module(module) # type: ignore[union-attr] + module.__path__ = [str(ROOT / "py")] + sys.modules["py"] = module + return importlib.import_module(module_name) + + +BaseModelService = import_from("py.services.base_model_service").BaseModelService +model_query_module = import_from("py.services.model_query") +ModelCacheRepository = model_query_module.ModelCacheRepository +ModelFilterSet = model_query_module.ModelFilterSet +SearchStrategy = model_query_module.SearchStrategy +SortParams = model_query_module.SortParams +BaseModelMetadata = import_from("py.utils.models").BaseModelMetadata class StubSettings: diff --git a/tests/services/test_route_support_services.py b/tests/services/test_route_support_services.py new file mode 100644 index 00000000..978438c3 --- /dev/null +++ b/tests/services/test_route_support_services.py @@ -0,0 +1,273 @@ +import asyncio +import json +import os +import sys +from pathlib import Path +from typing import Any, Dict, List + +ROOT = Path(__file__).resolve().parents[2] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +import importlib +import importlib.util + +import pytest + + +def import_from(module_name: str): + existing = sys.modules.get("py") + if existing is None or getattr(existing, "__file__", "") != str(ROOT / "py/__init__.py"): + sys.modules.pop("py", None) + spec = importlib.util.spec_from_file_location("py", ROOT / "py/__init__.py") + module = importlib.util.module_from_spec(spec) + assert spec and spec.loader + spec.loader.exec_module(module) # type: ignore[union-attr] + module.__path__ = [str(ROOT / "py")] + sys.modules["py"] = module + return importlib.import_module(module_name) + + +DownloadCoordinator = import_from("py.services.download_coordinator").DownloadCoordinator +MetadataSyncService = import_from("py.services.metadata_sync_service").MetadataSyncService +PreviewAssetService = import_from("py.services.preview_asset_service").PreviewAssetService +TagUpdateService = import_from("py.services.tag_update_service").TagUpdateService + + +class DummySettings: + def __init__(self, values: Dict[str, Any] | None = None) -> None: + self._values = values or {} + + def get(self, key: str, default: Any = None) -> Any: + return self._values.get(key, default) + + +class RecordingMetadataManager: + def __init__(self) -> None: + self.saved: List[tuple[str, Dict[str, Any]]] = [] + + async def save_metadata(self, path: str, metadata: Dict[str, Any]) -> bool: + self.saved.append((path, json.loads(json.dumps(metadata)))) + metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json" + Path(metadata_path).write_text(json.dumps(metadata)) + return True + + +class RecordingPreviewService: + def __init__(self) -> None: + self.calls: List[tuple[str, List[Dict[str, Any]]]] = [] + + async def ensure_preview_for_metadata( + self, metadata_path: str, local_metadata: Dict[str, Any], images + ) -> None: + self.calls.append((metadata_path, list(images or []))) + local_metadata["preview_url"] = "preview.webp" + local_metadata["preview_nsfw_level"] = 1 + + +class DummyProvider: + def __init__(self, payload: Dict[str, Any]) -> None: + self.payload = payload + + async def get_model_by_hash(self, sha256: str): + return self.payload, None + + async def get_model_version(self, model_id: int, model_version_id: int | None): + return self.payload + + +class FakeExifUtils: + @staticmethod + def optimize_image(**kwargs): + return kwargs["image_data"], {} + + +def test_metadata_sync_merges_remote_fields(tmp_path: Path) -> None: + manager = RecordingMetadataManager() + preview = RecordingPreviewService() + provider = DummyProvider({ + "baseModel": "SD15", + "model": {"name": "Merged", "description": "desc", "tags": ["tag"], "creator": {"username": "user"}}, + "trainedWords": ["word"], + "images": [{"url": "http://example", "nsfwLevel": 2, "type": "image"}], + }) + + service = MetadataSyncService( + metadata_manager=manager, + preview_service=preview, + settings=DummySettings(), + default_metadata_provider_factory=lambda: asyncio.sleep(0, result=provider), + metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=provider), + ) + + metadata_path = str(tmp_path / "model.metadata.json") + local_metadata = {"civitai": {"trainedWords": ["existing"]}} + + updated = asyncio.run(service.update_model_metadata(metadata_path, local_metadata, provider.payload)) + + assert updated["model_name"] == "Merged" + assert updated["modelDescription"] == "desc" + assert set(updated["civitai"]["trainedWords"]) == {"existing", "word"} + assert manager.saved + assert preview.calls + + +def test_metadata_sync_fetch_and_update_updates_cache(tmp_path: Path) -> None: + manager = RecordingMetadataManager() + preview = RecordingPreviewService() + provider = DummyProvider({ + "baseModel": "SDXL", + "model": {"name": "Updated"}, + "images": [], + }) + + update_cache_calls: List[Dict[str, Any]] = [] + + async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool: + update_cache_calls.append({"original": original, "metadata": metadata}) + return True + + service = MetadataSyncService( + metadata_manager=manager, + preview_service=preview, + settings=DummySettings(), + default_metadata_provider_factory=lambda: asyncio.sleep(0, result=provider), + metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=provider), + ) + + model_data = {"sha256": "abc", "file_path": str(tmp_path / "model.safetensors")} + success, error = asyncio.run( + service.fetch_and_update_model( + sha256="abc", + file_path=str(tmp_path / "model.safetensors"), + model_data=model_data, + update_cache_func=update_cache, + ) + ) + + assert success is True + assert error is None + assert update_cache_calls + assert manager.saved + + +def test_preview_asset_service_replace_preview(tmp_path: Path) -> None: + metadata_path = tmp_path / "sample.metadata.json" + metadata_path.write_text(json.dumps({})) + + async def metadata_loader(path: str) -> Dict[str, Any]: + return json.loads(Path(path).read_text()) + + manager = RecordingMetadataManager() + + service = PreviewAssetService( + metadata_manager=manager, + downloader_factory=lambda: asyncio.sleep(0, result=None), + exif_utils=FakeExifUtils(), + ) + + preview_calls: List[Dict[str, Any]] = [] + + async def update_preview(model_path: str, preview_path: str, nsfw: int) -> bool: + preview_calls.append({"model_path": model_path, "preview_path": preview_path, "nsfw": nsfw}) + return True + + model_path = str(tmp_path / "sample.safetensors") + Path(model_path).write_bytes(b"model") + + result = asyncio.run( + service.replace_preview( + model_path=model_path, + preview_data=b"image-bytes", + content_type="image/png", + original_filename="preview.png", + nsfw_level=2, + update_preview_in_cache=update_preview, + metadata_loader=metadata_loader, + ) + ) + + assert result["preview_nsfw_level"] == 2 + assert preview_calls + saved_metadata = json.loads(metadata_path.read_text()) + assert saved_metadata["preview_nsfw_level"] == 2 + + +def test_download_coordinator_emits_progress() -> None: + class WSStub: + def __init__(self) -> None: + self.progress_events: List[Dict[str, Any]] = [] + self.counter = 0 + + def generate_download_id(self) -> str: + self.counter += 1 + return f"dl-{self.counter}" + + async def broadcast_download_progress(self, download_id: str, payload: Dict[str, Any]) -> None: + self.progress_events.append({"id": download_id, **payload}) + + class DownloadManagerStub: + def __init__(self) -> None: + self.calls: List[Dict[str, Any]] = [] + + async def download_from_civitai(self, **kwargs) -> Dict[str, Any]: + self.calls.append(kwargs) + await kwargs["progress_callback"](10) + return {"success": True} + + async def cancel_download(self, download_id: str) -> Dict[str, Any]: + return {"success": True, "download_id": download_id} + + async def get_active_downloads(self) -> Dict[str, Any]: + return {"active": []} + + ws_stub = WSStub() + manager_stub = DownloadManagerStub() + + coordinator = DownloadCoordinator( + ws_manager=ws_stub, + download_manager_factory=lambda: asyncio.sleep(0, result=manager_stub), + ) + + result = asyncio.run(coordinator.schedule_download({"model_id": 1})) + + assert result["success"] is True + assert manager_stub.calls + assert ws_stub.progress_events + + cancel_result = asyncio.run(coordinator.cancel_download(result["download_id"])) + assert cancel_result["success"] is True + + active = asyncio.run(coordinator.list_active_downloads()) + assert active == {"active": []} + + +def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None: + metadata_path = tmp_path / "model.metadata.json" + metadata_path.write_text(json.dumps({"tags": ["Existing"]})) + + async def loader(path: str) -> Dict[str, Any]: + return json.loads(Path(path).read_text()) + + manager = RecordingMetadataManager() + + service = TagUpdateService(metadata_manager=manager) + + cache_updates: List[Dict[str, Any]] = [] + + async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool: + cache_updates.append(metadata) + return True + + tags = asyncio.run( + service.add_tags( + file_path=str(tmp_path / "model.safetensors"), + new_tags=["New", "existing"], + metadata_loader=loader, + update_cache=update_cache, + ) + ) + + assert tags == ["Existing", "New"] + assert manager.saved + assert cache_updates From 8cf99dd928a69ecf3176f9e12d0c4ebce51dd339 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sun, 21 Sep 2025 23:39:21 +0800 Subject: [PATCH 03/24] refactor(tests): remove deprecated test runner script --- run_tests.py | 50 -------------------------------------------------- 1 file changed, 50 deletions(-) delete mode 100644 run_tests.py diff --git a/run_tests.py b/run_tests.py deleted file mode 100644 index af5f96ff..00000000 --- a/run_tests.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python3 -""" -Test runner script for ComfyUI-Lora-Manager. - -This script runs pytest from the tests directory to avoid import issues -with the root __init__.py file. -""" -import subprocess -import sys -import os -from pathlib import Path - -# Set environment variable to indicate standalone mode -# HF_HUB_DISABLE_TELEMETRY is from ComfyUI main.py -standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" - -def main(): - """Run pytest from the tests directory to avoid import issues.""" - # Get the script directory - script_dir = Path(__file__).parent.absolute() - tests_dir = script_dir / "tests" - - if not tests_dir.exists(): - print(f"Error: Tests directory not found at {tests_dir}") - return 1 - - # Change to tests directory - original_cwd = os.getcwd() - os.chdir(tests_dir) - - try: - # Build pytest command - cmd = [ - sys.executable, "-m", "pytest", - "-v", - "--rootdir=.", - ] + sys.argv[1:] # Pass any additional arguments - - print(f"Running: {' '.join(cmd)}") - print(f"Working directory: {tests_dir}") - - # Run pytest - result = subprocess.run(cmd, cwd=tests_dir) - return result.returncode - finally: - # Restore original working directory - os.chdir(original_cwd) - -if __name__ == "__main__": - sys.exit(main()) From c063854b511be529e8a4ef8b70c28b2d557d7db3 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Mon, 22 Sep 2025 05:25:27 +0800 Subject: [PATCH 04/24] feat(routes): extract orchestration use cases --- py/routes/base_model_routes.py | 27 ++- py/routes/handlers/model_handlers.py | 153 +++++--------- py/services/use_cases/__init__.py | 25 +++ .../use_cases/auto_organize_use_case.py | 56 +++++ .../bulk_metadata_refresh_use_case.py | 122 +++++++++++ .../use_cases/download_model_use_case.py | 37 ++++ py/services/websocket_progress_callback.py | 30 ++- tests/routes/test_base_model_routes_smoke.py | 80 ++++++++ tests/services/test_use_cases.py | 191 ++++++++++++++++++ 9 files changed, 609 insertions(+), 112 deletions(-) create mode 100644 py/services/use_cases/__init__.py create mode 100644 py/services/use_cases/auto_organize_use_case.py create mode 100644 py/services/use_cases/bulk_metadata_refresh_use_case.py create mode 100644 py/services/use_cases/download_model_use_case.py create mode 100644 tests/services/test_use_cases.py diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index 65103ece..872dca8b 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -19,7 +19,15 @@ from ..services.service_registry import ServiceRegistry from ..services.settings_manager import settings as default_settings from ..services.tag_update_service import TagUpdateService from ..services.websocket_manager import ws_manager as default_ws_manager -from ..services.websocket_progress_callback import WebSocketProgressCallback +from ..services.use_cases import ( + AutoOrganizeUseCase, + BulkMetadataRefreshUseCase, + DownloadModelUseCase, +) +from ..services.websocket_progress_callback import ( + WebSocketBroadcastCallback, + WebSocketProgressCallback, +) from ..utils.exif_utils import ExifUtils from ..utils.metadata_manager import MetadataManager from ..utils.routes_common import ModelRouteUtils @@ -68,6 +76,7 @@ class BaseModelRoutes(ABC): self.model_file_service: ModelFileService | None = None self.model_move_service: ModelMoveService | None = None self.websocket_progress_callback = WebSocketProgressCallback() + self.metadata_progress_callback = WebSocketBroadcastCallback() self._handler_set: ModelHandlerSet | None = None self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None @@ -132,11 +141,19 @@ class BaseModelRoutes(ABC): tag_update_service=self._tag_update_service, ) query = ModelQueryHandler(service=service, logger=logger) + download_use_case = DownloadModelUseCase(download_coordinator=self._download_coordinator) download = ModelDownloadHandler( ws_manager=self._ws_manager, logger=logger, + download_use_case=download_use_case, download_coordinator=self._download_coordinator, ) + metadata_refresh_use_case = BulkMetadataRefreshUseCase( + service=service, + metadata_sync=self._metadata_sync_service, + settings_service=self._settings, + logger=logger, + ) civitai = ModelCivitaiHandler( service=service, settings_service=self._settings, @@ -147,10 +164,16 @@ class BaseModelRoutes(ABC): expected_model_types=self._get_expected_model_types, find_model_file=self._find_model_file, metadata_sync=self._metadata_sync_service, + metadata_refresh_use_case=metadata_refresh_use_case, + metadata_progress_callback=self.metadata_progress_callback, ) move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger) - auto_organize = ModelAutoOrganizeHandler( + auto_organize_use_case = AutoOrganizeUseCase( file_service=self._ensure_file_service(), + lock_provider=self._ws_manager, + ) + auto_organize = ModelAutoOrganizeHandler( + use_case=auto_organize_use_case, progress_callback=self.websocket_progress_callback, ws_manager=self._ws_manager, logger=logger, diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 5f9eaf3b..ba0628c6 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -14,10 +14,19 @@ import jinja2 from ...config import config from ...services.download_coordinator import DownloadCoordinator from ...services.metadata_sync_service import MetadataSyncService -from ...services.model_file_service import ModelFileService, ModelMoveService +from ...services.model_file_service import ModelMoveService from ...services.preview_asset_service import PreviewAssetService from ...services.settings_manager import SettingsManager from ...services.tag_update_service import TagUpdateService +from ...services.use_cases import ( + AutoOrganizeInProgressError, + AutoOrganizeUseCase, + BulkMetadataRefreshUseCase, + DownloadModelEarlyAccessError, + DownloadModelUseCase, + DownloadModelValidationError, + MetadataRefreshProgressReporter, +) from ...services.websocket_manager import WebSocketManager from ...services.websocket_progress_callback import WebSocketProgressCallback from ...utils.file_utils import calculate_sha256 @@ -600,33 +609,29 @@ class ModelDownloadHandler: *, ws_manager: WebSocketManager, logger: logging.Logger, + download_use_case: DownloadModelUseCase, download_coordinator: DownloadCoordinator, ) -> None: self._ws_manager = ws_manager self._logger = logger + self._download_use_case = download_use_case self._download_coordinator = download_coordinator async def download_model(self, request: web.Request) -> web.Response: try: payload = await request.json() - result = await self._download_coordinator.schedule_download(payload) + result = await self._download_use_case.execute(payload) if not result.get("success", False): return web.json_response(result, status=500) return web.json_response(result) - except ValueError as exc: + except DownloadModelValidationError as exc: return web.json_response({"success": False, "error": str(exc)}, status=400) + except DownloadModelEarlyAccessError as exc: + self._logger.warning("Early access error: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=401) except Exception as exc: error_message = str(exc) - if "401" in error_message: - self._logger.warning("Early access error (401): %s", error_message) - return web.json_response( - { - "success": False, - "error": "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com.", - }, - status=401, - ) - self._logger.error("Error downloading model: %s", error_message) + self._logger.error("Error downloading model: %s", error_message, exc_info=True) return web.json_response({"success": False, "error": error_message}, status=500) async def download_model_get(self, request: web.Request) -> web.Response: @@ -653,12 +658,15 @@ class ModelDownloadHandler: future.set_result(data) mock_request = type("MockRequest", (), {"json": lambda self=None: future})() - result = await self._download_coordinator.schedule_download(data) + result = await self._download_use_case.execute(data) if not result.get("success", False): return web.json_response(result, status=500) return web.json_response(result) - except ValueError as exc: + except DownloadModelValidationError as exc: return web.json_response({"success": False, "error": str(exc)}, status=400) + except DownloadModelEarlyAccessError as exc: + self._logger.warning("Early access error: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=401) except Exception as exc: self._logger.error("Error downloading model via GET: %s", exc, exc_info=True) return web.Response(status=500, text=str(exc)) @@ -703,6 +711,8 @@ class ModelCivitaiHandler: expected_model_types: Callable[[], str], find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]], metadata_sync: MetadataSyncService, + metadata_refresh_use_case: BulkMetadataRefreshUseCase, + metadata_progress_callback: MetadataRefreshProgressReporter, ) -> None: self._service = service self._settings = settings_service @@ -713,75 +723,16 @@ class ModelCivitaiHandler: self._expected_model_types = expected_model_types self._find_model_file = find_model_file self._metadata_sync = metadata_sync + self._metadata_refresh_use_case = metadata_refresh_use_case + self._metadata_progress_callback = metadata_progress_callback async def fetch_all_civitai(self, request: web.Request) -> web.Response: try: - cache = await self._service.scanner.get_cached_data() - total = len(cache.raw_data) - processed = 0 - success = 0 - needs_resort = False - - enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False) - to_process = [ - model - for model in cache.raw_data - if model.get("sha256") - and (not model.get("civitai") or not model["civitai"].get("id")) - and ( - (enable_metadata_archive_db and not model.get("db_checked", False)) - or (not enable_metadata_archive_db and model.get("from_civitai") is True) - ) - ] - total_to_process = len(to_process) - - await self._ws_manager.broadcast({ - "status": "started", - "total": total_to_process, - "processed": 0, - "success": 0, - }) - - for model in to_process: - try: - original_name = model.get("model_name") - result, error = await self._metadata_sync.fetch_and_update_model( - sha256=model["sha256"], - file_path=model["file_path"], - model_data=model, - update_cache_func=self._service.scanner.update_single_model_cache, - ) - if result: - success += 1 - if original_name != model.get("model_name"): - needs_resort = True - processed += 1 - await self._ws_manager.broadcast({ - "status": "processing", - "total": total_to_process, - "processed": processed, - "success": success, - "current_name": model.get("model_name", "Unknown"), - }) - except Exception as exc: # pragma: no cover - logging path - self._logger.error("Error fetching CivitAI data for %s: %s", model["file_path"], exc) - - if needs_resort: - await cache.resort() - - await self._ws_manager.broadcast({ - "status": "completed", - "total": total_to_process, - "processed": processed, - "success": success, - }) - - return web.json_response({ - "success": True, - "message": f"Successfully updated {success} of {processed} processed {self._service.model_type}s (total: {total})", - }) + result = await self._metadata_refresh_use_case.execute_with_error_handling( + progress_callback=self._metadata_progress_callback + ) + return web.json_response(result) except Exception as exc: - await self._ws_manager.broadcast({"status": "error", "error": str(exc)}) self._logger.error("Error in fetch_all_civitai for %ss: %s", self._service.model_type, exc) return web.Response(text=str(exc), status=500) @@ -887,31 +838,18 @@ class ModelAutoOrganizeHandler: def __init__( self, *, - file_service: ModelFileService, + use_case: AutoOrganizeUseCase, progress_callback: WebSocketProgressCallback, ws_manager: WebSocketManager, logger: logging.Logger, ) -> None: - self._file_service = file_service + self._use_case = use_case self._progress_callback = progress_callback self._ws_manager = ws_manager self._logger = logger async def auto_organize_models(self, request: web.Request) -> web.Response: try: - if self._ws_manager.is_auto_organize_running(): - return web.json_response( - {"success": False, "error": "Auto-organize is already running. Please wait for it to complete."}, - status=409, - ) - - auto_organize_lock = await self._ws_manager.get_auto_organize_lock() - if auto_organize_lock.locked(): - return web.json_response( - {"success": False, "error": "Auto-organize is already running. Please wait for it to complete."}, - status=409, - ) - file_paths = None if request.method == "POST": try: @@ -920,17 +858,24 @@ class ModelAutoOrganizeHandler: except Exception: # pragma: no cover - permissive path pass - async with auto_organize_lock: - result = await self._file_service.auto_organize_models( - file_paths=file_paths, - progress_callback=self._progress_callback, - ) - return web.json_response(result.to_dict()) + result = await self._use_case.execute( + file_paths=file_paths, + progress_callback=self._progress_callback, + ) + return web.json_response(result.to_dict()) + except AutoOrganizeInProgressError: + return web.json_response( + {"success": False, "error": "Auto-organize is already running. Please wait for it to complete."}, + status=409, + ) except Exception as exc: self._logger.error("Error in auto_organize_models: %s", exc, exc_info=True) - await self._ws_manager.broadcast_auto_organize_progress( - {"type": "auto_organize_progress", "status": "error", "error": str(exc)} - ) + try: + await self._progress_callback.on_progress( + {"type": "auto_organize_progress", "status": "error", "error": str(exc)} + ) + except Exception: # pragma: no cover - defensive reporting + pass return web.json_response({"success": False, "error": str(exc)}, status=500) async def get_auto_organize_progress(self, request: web.Request) -> web.Response: diff --git a/py/services/use_cases/__init__.py b/py/services/use_cases/__init__.py new file mode 100644 index 00000000..986f0f57 --- /dev/null +++ b/py/services/use_cases/__init__.py @@ -0,0 +1,25 @@ +"""Application-level orchestration services for model routes.""" + +from .auto_organize_use_case import ( + AutoOrganizeInProgressError, + AutoOrganizeUseCase, +) +from .bulk_metadata_refresh_use_case import ( + BulkMetadataRefreshUseCase, + MetadataRefreshProgressReporter, +) +from .download_model_use_case import ( + DownloadModelEarlyAccessError, + DownloadModelUseCase, + DownloadModelValidationError, +) + +__all__ = [ + "AutoOrganizeInProgressError", + "AutoOrganizeUseCase", + "BulkMetadataRefreshUseCase", + "MetadataRefreshProgressReporter", + "DownloadModelEarlyAccessError", + "DownloadModelUseCase", + "DownloadModelValidationError", +] diff --git a/py/services/use_cases/auto_organize_use_case.py b/py/services/use_cases/auto_organize_use_case.py new file mode 100644 index 00000000..0914739f --- /dev/null +++ b/py/services/use_cases/auto_organize_use_case.py @@ -0,0 +1,56 @@ +"""Auto-organize use case orchestrating concurrency and progress handling.""" + +from __future__ import annotations + +import asyncio +from typing import Optional, Protocol, Sequence + +from ..model_file_service import AutoOrganizeResult, ModelFileService, ProgressCallback + + +class AutoOrganizeLockProvider(Protocol): + """Minimal protocol for objects exposing auto-organize locking primitives.""" + + def is_auto_organize_running(self) -> bool: + """Return ``True`` when an auto-organize operation is in-flight.""" + + async def get_auto_organize_lock(self) -> asyncio.Lock: + """Return the asyncio lock guarding auto-organize operations.""" + + +class AutoOrganizeInProgressError(RuntimeError): + """Raised when an auto-organize run is already active.""" + + +class AutoOrganizeUseCase: + """Coordinate auto-organize execution behind a shared lock.""" + + def __init__( + self, + *, + file_service: ModelFileService, + lock_provider: AutoOrganizeLockProvider, + ) -> None: + self._file_service = file_service + self._lock_provider = lock_provider + + async def execute( + self, + *, + file_paths: Optional[Sequence[str]] = None, + progress_callback: Optional[ProgressCallback] = None, + ) -> AutoOrganizeResult: + """Run the auto-organize routine guarded by a shared lock.""" + + if self._lock_provider.is_auto_organize_running(): + raise AutoOrganizeInProgressError("Auto-organize is already running") + + lock = await self._lock_provider.get_auto_organize_lock() + if lock.locked(): + raise AutoOrganizeInProgressError("Auto-organize is already running") + + async with lock: + return await self._file_service.auto_organize_models( + file_paths=list(file_paths) if file_paths is not None else None, + progress_callback=progress_callback, + ) diff --git a/py/services/use_cases/bulk_metadata_refresh_use_case.py b/py/services/use_cases/bulk_metadata_refresh_use_case.py new file mode 100644 index 00000000..6a809955 --- /dev/null +++ b/py/services/use_cases/bulk_metadata_refresh_use_case.py @@ -0,0 +1,122 @@ +"""Use case encapsulating the bulk metadata refresh orchestration.""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, Optional, Protocol, Sequence + +from ..metadata_sync_service import MetadataSyncService + + +class MetadataRefreshProgressReporter(Protocol): + """Protocol for progress reporters used during metadata refresh.""" + + async def on_progress(self, payload: Dict[str, Any]) -> None: + """Handle a metadata refresh progress update.""" + + +class BulkMetadataRefreshUseCase: + """Coordinate bulk metadata refreshes with progress emission.""" + + def __init__( + self, + *, + service, + metadata_sync: MetadataSyncService, + settings_service, + logger: Optional[logging.Logger] = None, + ) -> None: + self._service = service + self._metadata_sync = metadata_sync + self._settings = settings_service + self._logger = logger or logging.getLogger(__name__) + + async def execute( + self, + *, + progress_callback: Optional[MetadataRefreshProgressReporter] = None, + ) -> Dict[str, Any]: + """Refresh metadata for all qualifying models.""" + + cache = await self._service.scanner.get_cached_data() + total_models = len(cache.raw_data) + + enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False) + to_process: Sequence[Dict[str, Any]] = [ + model + for model in cache.raw_data + if model.get("sha256") + and (not model.get("civitai") or not model["civitai"].get("id")) + and ( + (enable_metadata_archive_db and not model.get("db_checked", False)) + or (not enable_metadata_archive_db and model.get("from_civitai") is True) + ) + ] + + total_to_process = len(to_process) + processed = 0 + success = 0 + needs_resort = False + + async def emit(status: str, **extra: Any) -> None: + if progress_callback is None: + return + payload = {"status": status, "total": total_to_process, "processed": processed, "success": success} + payload.update(extra) + await progress_callback.on_progress(payload) + + await emit("started") + + for model in to_process: + try: + original_name = model.get("model_name") + result, _ = await self._metadata_sync.fetch_and_update_model( + sha256=model["sha256"], + file_path=model["file_path"], + model_data=model, + update_cache_func=self._service.scanner.update_single_model_cache, + ) + if result: + success += 1 + if original_name != model.get("model_name"): + needs_resort = True + processed += 1 + await emit( + "processing", + processed=processed, + success=success, + current_name=model.get("model_name", "Unknown"), + ) + except Exception as exc: # pragma: no cover - logging path + processed += 1 + self._logger.error( + "Error fetching CivitAI data for %s: %s", + model.get("file_path"), + exc, + ) + + if needs_resort: + await cache.resort() + + await emit("completed", processed=processed, success=success) + + message = ( + "Successfully updated " + f"{success} of {processed} processed {self._service.model_type}s (total: {total_models})" + ) + + return {"success": True, "message": message, "processed": processed, "updated": success, "total": total_models} + + async def execute_with_error_handling( + self, + *, + progress_callback: Optional[MetadataRefreshProgressReporter] = None, + ) -> Dict[str, Any]: + """Wrapper providing progress notification on unexpected failures.""" + + try: + return await self.execute(progress_callback=progress_callback) + except Exception as exc: + if progress_callback is not None: + await progress_callback.on_progress({"status": "error", "error": str(exc)}) + raise diff --git a/py/services/use_cases/download_model_use_case.py b/py/services/use_cases/download_model_use_case.py new file mode 100644 index 00000000..5aa25bda --- /dev/null +++ b/py/services/use_cases/download_model_use_case.py @@ -0,0 +1,37 @@ +"""Use case for scheduling model downloads with consistent error handling.""" + +from __future__ import annotations + +from typing import Any, Dict + +from ..download_coordinator import DownloadCoordinator + + +class DownloadModelValidationError(ValueError): + """Raised when incoming payload validation fails.""" + + +class DownloadModelEarlyAccessError(RuntimeError): + """Raised when the download is gated behind Civitai early access.""" + + +class DownloadModelUseCase: + """Coordinate download scheduling through the coordinator service.""" + + def __init__(self, *, download_coordinator: DownloadCoordinator) -> None: + self._download_coordinator = download_coordinator + + async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """Schedule a download and normalize error conditions.""" + + try: + return await self._download_coordinator.schedule_download(payload) + except ValueError as exc: + raise DownloadModelValidationError(str(exc)) from exc + except Exception as exc: # pragma: no cover - defensive logging path + message = str(exc) + if "401" in message: + raise DownloadModelEarlyAccessError( + "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com." + ) from exc + raise diff --git a/py/services/websocket_progress_callback.py b/py/services/websocket_progress_callback.py index 1a390f30..21423044 100644 --- a/py/services/websocket_progress_callback.py +++ b/py/services/websocket_progress_callback.py @@ -1,11 +1,29 @@ -from typing import Dict, Any +"""Progress callback implementations backed by the shared WebSocket manager.""" + +from typing import Any, Dict, Protocol + from .model_file_service import ProgressCallback from .websocket_manager import ws_manager -class WebSocketProgressCallback(ProgressCallback): - """WebSocket implementation of progress callback""" - +class ProgressReporter(Protocol): + """Protocol representing an async progress callback.""" + async def on_progress(self, progress_data: Dict[str, Any]) -> None: - """Send progress data via WebSocket""" - await ws_manager.broadcast_auto_organize_progress(progress_data) \ No newline at end of file + """Handle a progress update payload.""" + + +class WebSocketProgressCallback(ProgressCallback): + """WebSocket implementation of progress callback.""" + + async def on_progress(self, progress_data: Dict[str, Any]) -> None: + """Send progress data via WebSocket.""" + await ws_manager.broadcast_auto_organize_progress(progress_data) + + +class WebSocketBroadcastCallback: + """Generic WebSocket progress callback broadcasting to all clients.""" + + async def on_progress(self, progress_data: Dict[str, Any]) -> None: + """Send the provided payload to all connected clients.""" + await ws_manager.broadcast(progress_data) diff --git a/tests/routes/test_base_model_routes_smoke.py b/tests/routes/test_base_model_routes_smoke.py index 2b9ed805..25ebaabc 100644 --- a/tests/routes/test_base_model_routes_smoke.py +++ b/tests/routes/test_base_model_routes_smoke.py @@ -28,6 +28,7 @@ spec.loader.exec_module(py_local) sys.modules.setdefault("py_local", py_local) from py_local.routes.base_model_routes import BaseModelRoutes +from py_local.services.model_file_service import AutoOrganizeResult from py_local.services.service_registry import ServiceRegistry from py_local.services.websocket_manager import ws_manager from py_local.utils.routes_common import ExifUtils @@ -222,6 +223,25 @@ def test_download_model_invokes_download_manager( asyncio.run(scenario()) +def test_download_model_requires_identifier(mock_service, download_manager_stub): + async def scenario(): + client = await create_test_client(mock_service) + try: + response = await client.post( + "/api/lm/download-model", + json={"model_root": "/tmp"}, + ) + payload = await response.json() + + assert response.status == 400 + assert payload["success"] is False + assert "Missing required" in payload["error"] + finally: + await client.close() + + asyncio.run(scenario()) + + def test_auto_organize_progress_returns_latest_snapshot(mock_service): async def scenario(): client = await create_test_client(mock_service) @@ -235,5 +255,65 @@ def test_auto_organize_progress_returns_latest_snapshot(mock_service): assert payload == {"success": True, "progress": {"status": "processing", "percent": 50}} finally: await client.close() + + asyncio.run(scenario()) + + +def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.MonkeyPatch): + async def fake_auto_organize(self, file_paths=None, progress_callback=None): + result = AutoOrganizeResult() + result.total = 1 + result.processed = 1 + result.success_count = 1 + result.skipped_count = 0 + result.failure_count = 0 + result.operation_type = "bulk" + if progress_callback is not None: + await progress_callback.on_progress({"type": "auto_organize_progress", "status": "started"}) + await progress_callback.on_progress({"type": "auto_organize_progress", "status": "completed"}) + return result + + monkeypatch.setattr( + py_local.services.model_file_service.ModelFileService, + "auto_organize_models", + fake_auto_organize, + ) + + async def scenario(): + client = await create_test_client(mock_service) + try: + response = await client.post("/api/lm/test-models/auto-organize", json={"file_paths": []}) + payload = await response.json() + + assert response.status == 200 + assert payload["success"] is True + + progress = ws_manager.get_auto_organize_progress() + assert progress is not None + assert progress["status"] == "completed" + finally: + await client.close() + + asyncio.run(scenario()) + + +def test_auto_organize_conflict_when_running(mock_service): + async def scenario(): + client = await create_test_client(mock_service) + try: + await ws_manager.broadcast_auto_organize_progress( + {"type": "auto_organize_progress", "status": "started"} + ) + + response = await client.post("/api/lm/test-models/auto-organize") + payload = await response.json() + + assert response.status == 409 + assert payload == { + "success": False, + "error": "Auto-organize is already running. Please wait for it to complete.", + } + finally: + await client.close() asyncio.run(scenario()) diff --git a/tests/services/test_use_cases.py b/tests/services/test_use_cases.py new file mode 100644 index 00000000..64057fc6 --- /dev/null +++ b/tests/services/test_use_cases.py @@ -0,0 +1,191 @@ +import asyncio +import logging +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import pytest + +from py_local.services.model_file_service import AutoOrganizeResult +from py_local.services.use_cases import ( + AutoOrganizeInProgressError, + AutoOrganizeUseCase, + BulkMetadataRefreshUseCase, + DownloadModelEarlyAccessError, + DownloadModelUseCase, + DownloadModelValidationError, +) +from tests.conftest import MockModelService, MockScanner + + +class StubLockProvider: + def __init__(self) -> None: + self._lock = asyncio.Lock() + self.running = False + + def is_auto_organize_running(self) -> bool: + return self.running + + async def get_auto_organize_lock(self) -> asyncio.Lock: + return self._lock + + +class StubFileService: + def __init__(self) -> None: + self.calls: List[Dict[str, Any]] = [] + + async def auto_organize_models( + self, + *, + file_paths: Optional[List[str]] = None, + progress_callback=None, + ) -> AutoOrganizeResult: + result = AutoOrganizeResult() + result.total = len(file_paths or []) + self.calls.append({"file_paths": file_paths, "progress_callback": progress_callback}) + return result + + +class StubMetadataSync: + def __init__(self) -> None: + self.calls: List[Dict[str, Any]] = [] + + async def fetch_and_update_model(self, **kwargs: Any): + self.calls.append(kwargs) + model_data = kwargs["model_data"] + model_data["model_name"] = model_data.get("model_name", "model") + "-updated" + return True, None + + +@dataclass +class StubSettings: + enable_metadata_archive_db: bool = False + + def get(self, key: str, default: Any = None) -> Any: + if key == "enable_metadata_archive_db": + return self.enable_metadata_archive_db + return default + + +class ProgressCollector: + def __init__(self) -> None: + self.events: List[Dict[str, Any]] = [] + + async def on_progress(self, payload: Dict[str, Any]) -> None: + self.events.append(payload) + + +class StubDownloadCoordinator: + def __init__(self, *, error: Optional[str] = None) -> None: + self.error = error + self.payloads: List[Dict[str, Any]] = [] + + async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]: + self.payloads.append(payload) + if self.error == "validation": + raise ValueError("Missing required parameter: Please provide either 'model_id' or 'model_version_id'") + if self.error == "401": + raise RuntimeError("401 Unauthorized") + return {"success": True, "download_id": "abc123"} + + +async def test_auto_organize_use_case_executes_with_lock() -> None: + file_service = StubFileService() + lock_provider = StubLockProvider() + use_case = AutoOrganizeUseCase(file_service=file_service, lock_provider=lock_provider) + + result = await use_case.execute(file_paths=["model1"], progress_callback=None) + + assert isinstance(result, AutoOrganizeResult) + assert file_service.calls[0]["file_paths"] == ["model1"] + + +async def test_auto_organize_use_case_rejects_when_running() -> None: + file_service = StubFileService() + lock_provider = StubLockProvider() + lock_provider.running = True + use_case = AutoOrganizeUseCase(file_service=file_service, lock_provider=lock_provider) + + with pytest.raises(AutoOrganizeInProgressError): + await use_case.execute(file_paths=None, progress_callback=None) + + +async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None: + scanner = MockScanner() + scanner._cache.raw_data = [ + { + "file_path": "model1.safetensors", + "sha256": "hash", + "from_civitai": True, + "model_name": "Demo", + } + ] + service = MockModelService(scanner) + metadata_sync = StubMetadataSync() + settings = StubSettings() + progress = ProgressCollector() + + use_case = BulkMetadataRefreshUseCase( + service=service, + metadata_sync=metadata_sync, + settings_service=settings, + logger=logging.getLogger("test"), + ) + + result = await use_case.execute_with_error_handling(progress_callback=progress) + + assert result["success"] is True + assert progress.events[0]["status"] == "started" + assert progress.events[-1]["status"] == "completed" + assert metadata_sync.calls + assert scanner._cache.resort_calls == 1 + + +async def test_bulk_metadata_refresh_reports_errors() -> None: + class FailingScanner(MockScanner): + async def get_cached_data(self, force_refresh: bool = False): + raise RuntimeError("boom") + + service = MockModelService(FailingScanner()) + metadata_sync = StubMetadataSync() + settings = StubSettings() + progress = ProgressCollector() + + use_case = BulkMetadataRefreshUseCase( + service=service, + metadata_sync=metadata_sync, + settings_service=settings, + logger=logging.getLogger("test"), + ) + + with pytest.raises(RuntimeError): + await use_case.execute_with_error_handling(progress_callback=progress) + + assert progress.events + assert progress.events[-1]["status"] == "error" + assert progress.events[-1]["error"] == "boom" + + +async def test_download_model_use_case_raises_validation_error() -> None: + coordinator = StubDownloadCoordinator(error="validation") + use_case = DownloadModelUseCase(download_coordinator=coordinator) + + with pytest.raises(DownloadModelValidationError): + await use_case.execute({}) + + +async def test_download_model_use_case_raises_early_access() -> None: + coordinator = StubDownloadCoordinator(error="401") + use_case = DownloadModelUseCase(download_coordinator=coordinator) + + with pytest.raises(DownloadModelEarlyAccessError): + await use_case.execute({"model_id": 1}) + + +async def test_download_model_use_case_returns_result() -> None: + coordinator = StubDownloadCoordinator() + use_case = DownloadModelUseCase(download_coordinator=coordinator) + + result = await use_case.execute({"model_id": 1}) + + assert result["success"] is True + assert result["download_id"] == "abc123" From 66a3f3f59a0c80113b9ec25c6eb88381ba7b12db Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Mon, 22 Sep 2025 05:37:24 +0800 Subject: [PATCH 05/24] refactor(tests): enhance async test handling in pytest_pyfunc_call --- tests/conftest.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 818aeb9b..006a1adc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,8 +42,30 @@ sys.modules['nodes'] = nodes_mock def pytest_pyfunc_call(pyfuncitem): - if inspect.iscoroutinefunction(pyfuncitem.function): - asyncio.run(pyfuncitem.obj(**pyfuncitem.funcargs)) + """Allow bare async tests to run without pytest.mark.asyncio.""" + test_function = pyfuncitem.function + if inspect.iscoroutinefunction(test_function): + func = pyfuncitem.obj + signature = inspect.signature(func) + accepted_kwargs: Dict[str, Any] = {} + for name, parameter in signature.parameters.items(): + if parameter.kind is inspect.Parameter.VAR_POSITIONAL: + continue + if parameter.kind is inspect.Parameter.VAR_KEYWORD: + accepted_kwargs = dict(pyfuncitem.funcargs) + break + if name in pyfuncitem.funcargs: + accepted_kwargs[name] = pyfuncitem.funcargs[name] + + original_policy = asyncio.get_event_loop_policy() + policy = pyfuncitem.funcargs.get("event_loop_policy") + if policy is not None and policy is not original_policy: + asyncio.set_event_loop_policy(policy) + try: + asyncio.run(func(**accepted_kwargs)) + finally: + if policy is not None and policy is not original_policy: + asyncio.set_event_loop_policy(original_policy) return True return None @@ -196,3 +218,5 @@ def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockS @pytest.fixture def mock_service(mock_scanner: MockScanner) -> MockModelService: return MockModelService(scanner=mock_scanner) + + From 1c4096f3d5836b53fd73e6d5a7c6678890e2902e Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Mon, 22 Sep 2025 06:28:30 +0800 Subject: [PATCH 06/24] test(routes): add tests for service readiness and error handling in download model --- docs/architecture/model_routes.md | 189 ++++++++----------- tests/routes/test_base_model_routes_smoke.py | 71 +++++++ 2 files changed, 152 insertions(+), 108 deletions(-) diff --git a/docs/architecture/model_routes.md b/docs/architecture/model_routes.md index 00329299..564ba010 100644 --- a/docs/architecture/model_routes.md +++ b/docs/architecture/model_routes.md @@ -1,127 +1,100 @@ # Base model route architecture -The `BaseModelRoutes` controller centralizes HTTP endpoints that every model type -(LoRAs, checkpoints, embeddings, etc.) share. Each handler either forwards the -request to the injected service, delegates to a utility in -`ModelRouteUtils`, or orchestrates long‑running operations via helper services -such as the download or WebSocket managers. The table below lists every handler -exposed in `py/routes/base_model_routes.py`, the collaborators it leans on, and -any cache or WebSocket side effects implemented in -`py/utils/routes_common.py`. +The model routing stack now splits HTTP wiring, orchestration logic, and +business rules into discrete layers. The goal is to make it obvious where a +new collaborator should live and which contract it must honour. The diagram +below captures the end-to-end flow for a typical request: -## Contents +```mermaid +graph TD + subgraph HTTP + A[ModelRouteRegistrar] -->|binds| B[BaseModelRoutes handler proxy] + end + subgraph Application + B --> C[ModelHandlerSet] + C --> D1[Handlers] + D1 --> E1[Use cases] + E1 --> F1[Services / scanners] + end + subgraph Side Effects + F1 --> G1[Cache & metadata] + F1 --> G2[Filesystem] + F1 --> G3[WebSocket state] + end +``` -- [Handler catalogue](#handler-catalogue) -- [Dependency map and contracts](#dependency-map-and-contracts) - - [Cache and metadata mutations](#cache-and-metadata-mutations) - - [Download and WebSocket flows](#download-and-websocket-flows) - - [Read-only queries](#read-only-queries) - - [Template rendering and initialization](#template-rendering-and-initialization) +Every box maps to a concrete module: -## Handler catalogue +| Layer | Module(s) | Responsibility | +| --- | --- | --- | +| Registrar | `py/routes/model_route_registrar.py` | Declarative list of routes shared by every model type and helper methods for binding them to an `aiohttp` application. | +| Route controller | `py/routes/base_model_routes.py` | Constructs the handler graph, injects shared services, exposes proxies that surface `503 Service not ready` when the model service has not been attached. | +| Handler set | `py/routes/handlers/model_handlers.py` | Thin HTTP adapters grouped by concern (page rendering, listings, mutations, queries, downloads, CivitAI integration, move operations, auto-organize). | +| Use cases | `py/services/use_cases/*.py` | Encapsulate long-running flows (`DownloadModelUseCase`, `BulkMetadataRefreshUseCase`, `AutoOrganizeUseCase`). They normalise validation errors and concurrency constraints before returning control to the handlers. | +| Services | `py/services/*.py` | Existing services and scanners that mutate caches, write metadata, move files, and broadcast WebSocket updates. | -The routes exposed by `BaseModelRoutes` combine HTTP wiring with a handful of -shared helper classes. Services surface filesystem and metadata operations, -`ModelRouteUtils` bundles cache-sensitive mutations, and `ws_manager` -coordinates fan-out to browser clients. The tables below expand the existing -catalogue into explicit dependency maps and invariants so refactors can reason -about the expectations each collaborator must uphold. +## Handler responsibilities & contracts -## Dependency map and contracts +`ModelHandlerSet` flattens the handler objects into the exact callables used by +the registrar. The table below highlights the separation of concerns within +the set and the invariants that must hold after each handler returns. -### Cache and metadata mutations - -| Endpoint(s) | Delegate(s) | State touched | Invariants / contracts | +| Handler | Key endpoints | Collaborators | Contracts | | --- | --- | --- | --- | -| `/api/lm/{prefix}/delete` | `ModelRouteUtils.handle_delete_model()` | Removes files from disk, prunes `scanner._cache.raw_data`, awaits `scanner._cache.resort()`, calls `scanner._hash_index.remove_by_path()`. | Cache and hash index must no longer reference the deleted path; resort must complete before responding to keep pagination deterministic. | -| `/api/lm/{prefix}/exclude` | `ModelRouteUtils.handle_exclude_model()` | Mutates metadata records, `scanner._cache.raw_data`, `scanner._hash_index`, `scanner._tags_count`, and `scanner._excluded_models`. | Excluded models remain discoverable via exclusion list while being hidden from listings; tag counts stay balanced after removal. | -| `/api/lm/{prefix}/fetch-civitai` | `ModelRouteUtils.fetch_and_update_model()` | Reads `scanner._cache.raw_data`, writes metadata JSON through `MetadataManager`, syncs cache via `scanner.update_single_model_cache`. | Requires a cached SHA256 hash; cache entries must reflect merged metadata before formatted response is returned. | -| `/api/lm/{prefix}/fetch-all-civitai` | `ModelRouteUtils.fetch_and_update_model()`, `ws_manager.broadcast()` | Iterates over cache, updates metadata files and cache records, optionally awaits `scanner._cache.resort()`. | Progress broadcasts follow started → processing → completed; if any model name changes, cache resort must run once before completion broadcast. | -| `/api/lm/{prefix}/relink-civitai` | `ModelRouteUtils.handle_relink_civitai()` | Updates metadata on disk and resynchronizes the cache entry. | The new association must propagate to `scanner.update_single_model_cache` so duplicate resolution and listings reflect the change immediately. | -| `/api/lm/{prefix}/replace-preview` | `ModelRouteUtils.handle_replace_preview()` | Writes optimized preview file, persists metadata via `MetadataManager`, updates cache with `scanner.update_preview_in_cache()`. | Preview path stored in metadata and cache must match the normalized file system path; NSFW level integer is synchronized across metadata and cache. | -| `/api/lm/{prefix}/save-metadata` | `ModelRouteUtils.handle_save_metadata()` | Writes metadata JSON and ensures cache entry mirrors the latest content. | Metadata persistence must be atomic—cache data should match on-disk metadata before response emits success. | -| `/api/lm/{prefix}/add-tags` | `ModelRouteUtils.handle_add_tags()` | Updates metadata tags, increments `scanner._tags_count`, and patches cached item. | Tag frequency map remains in sync with cache and metadata after increments. | -| `/api/lm/{prefix}/rename` | `ModelRouteUtils.handle_rename_model()` | Renames files, metadata, previews; updates cache indices and hash mappings. | File moves succeed or rollback as a unit so cache state never points to a missing file; hash index entries track the new path. | -| `/api/lm/{prefix}/bulk-delete` | `ModelRouteUtils.handle_bulk_delete_models()` | Delegates to `scanner.bulk_delete_models()` to delete files, trim cache, resort, and drop hash index entries. | Every requested path is removed from cache and index; resort happens once after bulk deletion. | -| `/api/lm/{prefix}/verify-duplicates` | `ModelRouteUtils.handle_verify_duplicates()` | Recomputes hashes, updates metadata and cached entries if discrepancies found. | Hash metadata stored in cache must mirror recomputed values to guarantee future duplicate checks operate on current data. | -| `/api/lm/{prefix}/scan` | `service.scan_models()` | Rescans filesystem, rebuilding scanner cache. | Scanner replaces its cache atomically so subsequent requests observe a consistent snapshot. | -| `/api/lm/{prefix}/move_model` | `ModelMoveService.move_model()` | Moves files/directories and notifies scanner via service layer conventions. | Move operations respect filesystem invariants (target path exists, metadata follows file) and emit success/failure without leaving partial moves. | -| `/api/lm/{prefix}/move_models_bulk` | `ModelMoveService.move_models_bulk()` | Batch move behavior as above. | Aggregated result enumerates successes/failures while preserving per-model atomicity. | -| `/api/lm/{prefix}/auto-organize` (GET/POST) | `ModelFileService.auto_organize_models()`, `ws_manager.get_auto_organize_lock()`, `WebSocketProgressCallback` | Writes organized files, updates metadata, and streams progress snapshots. | Only one auto-organize job may run; lock must guard reentrancy and WebSocket updates must include latest progress payload consumed by polling route. | +| `ModelPageView` | `/{prefix}` | `SettingsManager`, `server_i18n`, Jinja environment, `service.scanner` | Template is rendered with `is_initializing` flag when caches are cold; i18n filter is registered exactly once per environment instance. | +| `ModelListingHandler` | `/api/lm/{prefix}/list` | `service.get_paginated_data`, `service.format_response` | Listings respect pagination query parameters and cap `page_size` at 100; every item is formatted before response. | +| `ModelManagementHandler` | Mutations (delete, exclude, metadata, preview, tags, rename, bulk delete, duplicate verification) | `ModelRouteUtils`, `MetadataSyncService`, `PreviewAssetService`, `TagUpdateService`, scanner cache/index | Cache state mirrors filesystem changes: deletes prune cache & hash index, preview replacements synchronise metadata and cache NSFW levels, metadata saves trigger cache resort when names change. | +| `ModelQueryHandler` | Read-only queries (top tags, folders, duplicates, metadata, URLs) | Service query helpers & scanner cache | Outputs always wrapped in `{"success": True}` when no error; duplicate/filename grouping omits empty entries; invalid parameters (e.g. missing `model_root`) return HTTP 400. | +| `ModelDownloadHandler` | `/api/lm/download-model`, `/download-model-get`, `/download-progress/{id}`, `/cancel-download-get` | `DownloadModelUseCase`, `DownloadCoordinator`, `WebSocketManager` | Payload validation errors become HTTP 400 without mutating download progress cache; early-access failures surface as HTTP 401; successful downloads cache progress snapshots that back both WebSocket broadcasts and polling endpoints. | +| `ModelCivitaiHandler` | CivitAI metadata routes | `MetadataSyncService`, metadata provider factory, `BulkMetadataRefreshUseCase` | `fetch_all_civitai` streams progress via `WebSocketBroadcastCallback`; version lookups validate model type before returning; local availability fields derive from hash lookups without mutating cache state. | +| `ModelMoveHandler` | `move_model`, `move_models_bulk` | `ModelMoveService` | Moves execute atomically per request; bulk operations aggregate success/failure per file set. | +| `ModelAutoOrganizeHandler` | `/api/lm/{prefix}/auto-organize` (GET/POST), `/auto-organize-progress` | `AutoOrganizeUseCase`, `WebSocketProgressCallback`, `WebSocketManager` | Enforces single-flight execution using the shared lock; progress broadcasts remain available to polling clients until explicitly cleared; conflicts return HTTP 409 with a descriptive error. | -### Download and WebSocket flows +## Use case boundaries -| Endpoint(s) | Delegate(s) | State touched | Invariants / contracts | +Each use case exposes a narrow asynchronous API that hides the underlying +services. Their error mapping is essential for predictable HTTP responses. + +| Use case | Entry point | Dependencies | Guarantees | | --- | --- | --- | --- | -| `/api/lm/download-model` (POST) & `/api/lm/download-model-get` (GET) | `ModelRouteUtils.handle_download_model()`, `ServiceRegistry.get_download_manager()` | Schedules downloads, registers `ws_manager.broadcast_download_progress()` callback that stores progress in `ws_manager._download_progress`. | Download IDs remain stable across POST/GET helpers; every progress callback persists a timestamped entry so `/download-progress` and WebSocket clients share consistent snapshots. | -| `/api/lm/cancel-download-get` | `ModelRouteUtils.handle_cancel_download()` | Signals download manager, prunes `ws_manager._download_progress`, and emits cancellation broadcast. | Cancel requests must tolerate missing IDs gracefully while ensuring cached progress is removed once cancellation succeeds. | -| `/api/lm/download-progress/{download_id}` | `ws_manager.get_download_progress()` | Reads cached progress dictionary. | Returns `404` when progress is absent; successful payload surfaces the numeric `progress` stored during broadcasts. | -| `/api/lm/{prefix}/fetch-all-civitai` | `ws_manager.broadcast()` | Broadcast loop described above. | Broadcast cadence cannot skip completion/error messages so clients know when to clear UI spinners. | -| `/api/lm/{prefix}/auto-organize-progress` | `ws_manager.get_auto_organize_progress()` | Reads cached progress snapshot. | Route returns cached payload verbatim; absence yields `404` to signal idle state. | +| `DownloadModelUseCase` | `execute(payload)` | `DownloadCoordinator.schedule_download` | Translates `ValueError` into `DownloadModelValidationError` for HTTP 400, recognises early-access errors (`"401"` in message) and surfaces them as `DownloadModelEarlyAccessError`, forwards success dictionaries untouched. | +| `AutoOrganizeUseCase` | `execute(file_paths, progress_callback)` | `ModelFileService.auto_organize_models`, `WebSocketManager` lock | Guarded by `ws_manager` lock + status checks; raises `AutoOrganizeInProgressError` before invoking the file service when another run is already active. | +| `BulkMetadataRefreshUseCase` | `execute_with_error_handling(progress_callback)` | `MetadataSyncService`, `SettingsManager`, `WebSocketBroadcastCallback` | Iterates through cached models, applies metadata sync, emits progress snapshots that handlers broadcast unchanged. | -### Read-only queries +## Maintaining legacy contracts -| Endpoint(s) | Delegate(s) | State touched | Invariants / contracts | -| --- | --- | --- | --- | -| `/api/lm/{prefix}/list` | `service.get_paginated_data()`, `service.format_response()` | Reads service-managed pagination data. | Formatting must be applied to every item before response; pagination metadata echoes service result. | -| `/api/lm/{prefix}/top-tags` | `service.get_top_tags()` | Reads aggregated tag counts. | Limit parameter bounded to `[1, 100]`; response always wraps tags in `{success: True}` envelope. | -| `/api/lm/{prefix}/base-models` | `service.get_base_models()` | Reads service data. | Same limit handling as tags. | -| `/api/lm/{prefix}/roots` | `service.get_model_roots()` | Reads configured roots. | Always returns `{success: True, roots: [...]}`. | -| `/api/lm/{prefix}/folders` | `service.scanner.get_cached_data()` | Reads folder summaries from cache. | Cache access must tolerate initialization phases by surfacing errors via HTTP 500. | -| `/api/lm/{prefix}/folder-tree` | `service.get_folder_tree()` | Reads derived tree for requested root. | Rejects missing `model_root` with HTTP 400. | -| `/api/lm/{prefix}/unified-folder-tree` | `service.get_unified_folder_tree()` | Aggregated folder tree. | Returns `{success: True, tree: ...}` or 500 on error. | -| `/api/lm/{prefix}/find-duplicates` | `service.find_duplicate_hashes()`, `service.scanner.get_cached_data()`, `service.get_path_by_hash()` | Reads cache and hash index to format duplicates. | Only returns groups with more than one resolved model. | -| `/api/lm/{prefix}/find-filename-conflicts` | `service.find_duplicate_filenames()`, `service.scanner.get_cached_data()`, `service.scanner.get_hash_by_filename()` | Similar read-only assembly. | Includes resolved main index entry when available; empty `models` groups are omitted. | -| `/api/lm/{prefix}/get-notes` | `service.get_model_notes()` | Reads persisted notes. | Missing notes produce HTTP 404 with explicit error message. | -| `/api/lm/{prefix}/preview-url` | `service.get_model_preview_url()` | Resolves static URL. | Successful responses wrap URL in `{success: True}`; missing preview yields 404 error payload. | -| `/api/lm/{prefix}/civitai-url` | `service.get_model_civitai_url()` | Returns remote permalink info. | Response envelope matches preview pattern. | -| `/api/lm/{prefix}/metadata` | `service.get_model_metadata()` | Reads metadata JSON. | Responds with raw metadata dict or 500 on failure. | -| `/api/lm/{prefix}/model-description` | `service.get_model_description()` | Returns formatted description string. | Always JSON with success boolean. | -| `/api/lm/{prefix}/relative-paths` | `service.get_relative_paths()` | Resolves filesystem suggestions. | Maintains read-only contract. | -| `/api/lm/{prefix}/civitai/versions/{model_id}` | `get_default_metadata_provider()`, `service.has_hash()`, `service.get_path_by_hash()` | Reads remote API, cross-references cache. | Versions payload includes `existsLocally`/`localPath` only when hashes match local indices. | -| `/api/lm/{prefix}/civitai/model/version/{modelVersionId}` | `get_default_metadata_provider()` | Remote metadata lookup. | Errors propagate as JSON with `{success: False}` payload. | -| `/api/lm/{prefix}/civitai/model/hash/{hash}` | `get_default_metadata_provider()` | Remote metadata lookup. | Missing hashes return 404 with `{success: False}`. | +The refactor preserves the invariants called out in the previous architecture +notes. The most critical ones are reiterated here to emphasise the +collaboration points: -### Template rendering and initialization +1. **Cache mutations** – Delete, exclude, rename, and bulk delete operations are + channelled through `ModelManagementHandler`. The handler delegates to + `ModelRouteUtils` or `MetadataSyncService`, and the scanner cache is mutated + in-place before the handler returns. The accompanying tests assert that + `scanner._cache.raw_data` and `scanner._hash_index` stay in sync after each + mutation. +2. **Preview updates** – `PreviewAssetService.replace_preview` writes the new + asset, `MetadataSyncService` persists the JSON metadata, and + `scanner.update_preview_in_cache` mirrors the change. The handler returns + the static URL produced by `config.get_preview_static_url`, keeping browser + clients in lockstep with disk state. +3. **Download progress** – `DownloadCoordinator.schedule_download` generates the + download identifier, registers a WebSocket progress callback, and caches the + latest numeric progress via `WebSocketManager`. Both `download_model` + responses and `/download-progress/{id}` polling read from the same cache to + guarantee consistent progress reporting across transports. -| Endpoint(s) | Delegate(s) | State touched | Invariants / contracts | -| --- | --- | --- | --- | -| `/{prefix}` | `handle_models_page` | Reads configuration via `settings`, sets locale with `server_i18n`, pulls cached folders through `service.scanner.get_cached_data()`, renders Jinja template. | Template rendering must tolerate scanner initialization by flagging `is_initializing`; i18n filter is attached exactly once per environment to avoid duplicate registration errors. | +## Extending the stack -### Contract sequences +To add a new shared route: -The following high-level sequences show how the collaborating services work -together for the most stateful operations: +1. Declare it in `COMMON_ROUTE_DEFINITIONS` using a unique handler name. +2. Implement the corresponding coroutine on one of the handlers inside + `ModelHandlerSet` (or introduce a new handler class when the concern does not + fit existing ones). +3. Inject additional dependencies in `BaseModelRoutes._create_handler_set` by + wiring services or use cases through the constructor parameters. -``` -delete_model request - → BaseModelRoutes.delete_model - → ModelRouteUtils.handle_delete_model - → filesystem delete + metadata cleanup - → scanner._cache.raw_data prune - → await scanner._cache.resort() - → scanner._hash_index.remove_by_path() -``` - -``` -replace_preview request - → BaseModelRoutes.replace_preview - → ModelRouteUtils.handle_replace_preview - → ExifUtils.optimize_image / config.get_preview_static_url - → MetadataManager.save_metadata - → scanner.update_preview_in_cache(model_path, preview_path, nsfw_level) -``` - -``` -download_model request - → BaseModelRoutes.download_model - → ModelRouteUtils.handle_download_model - → ServiceRegistry.get_download_manager().download_from_civitai(..., progress_callback) - → ws_manager.broadcast_download_progress(download_id, data) - → ws_manager._download_progress[download_id] updated with timestamp - → /api/lm/download-progress/{id} polls ws_manager.get_download_progress -``` - -These contracts complement the tables above: if any collaborator changes its -behavior, the invariants called out here must continue to hold for the routes -to remain predictable. +Model-specific routes should continue to be registered inside the subclass +implementation of `setup_specific_routes`, reusing the shared registrar where +possible. diff --git a/tests/routes/test_base_model_routes_smoke.py b/tests/routes/test_base_model_routes_smoke.py index 25ebaabc..136bd0a8 100644 --- a/tests/routes/test_base_model_routes_smoke.py +++ b/tests/routes/test_base_model_routes_smoke.py @@ -67,12 +67,24 @@ def download_manager_stub(): class FakeDownloadManager: def __init__(self): self.calls = [] + self.error = None + self.cancelled = [] + self.active_downloads = {} async def download_from_civitai(self, **kwargs): self.calls.append(kwargs) + if self.error is not None: + raise self.error await kwargs["progress_callback"](42) return {"success": True, "path": "/tmp/model.safetensors"} + async def cancel_download(self, download_id): + self.cancelled.append(download_id) + return {"success": True, "download_id": download_id} + + async def get_active_downloads(self): + return self.active_downloads + stub = FakeDownloadManager() previous = ServiceRegistry._services.get("download_manager") asyncio.run(ServiceRegistry.register_service("download_manager", stub)) @@ -104,6 +116,21 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner): asyncio.run(scenario()) +def test_routes_return_service_not_ready_when_unattached(): + async def scenario(): + client = await create_test_client(None) + try: + response = await client.get("/api/lm/test-models/list") + payload = await response.json() + + assert response.status == 503 + assert payload == {"success": False, "error": "Service not ready"} + finally: + await client.close() + + asyncio.run(scenario()) + + def test_delete_model_updates_cache_and_hash_index(mock_service, mock_scanner, tmp_path: Path): model_path = tmp_path / "sample.safetensors" model_path.write_bytes(b"model") @@ -242,6 +269,50 @@ def test_download_model_requires_identifier(mock_service, download_manager_stub) asyncio.run(scenario()) +def test_download_model_maps_validation_errors(mock_service, download_manager_stub): + download_manager_stub.error = ValueError("Invalid relative path") + + async def scenario(): + client = await create_test_client(mock_service) + try: + response = await client.post( + "/api/lm/download-model", + json={"model_version_id": 123}, + ) + payload = await response.json() + + assert response.status == 400 + assert payload == {"success": False, "error": "Invalid relative path"} + assert ws_manager._download_progress == {} + finally: + await client.close() + + asyncio.run(scenario()) + + +def test_download_model_maps_early_access_errors(mock_service, download_manager_stub): + download_manager_stub.error = RuntimeError("401 early access") + + async def scenario(): + client = await create_test_client(mock_service) + try: + response = await client.post( + "/api/lm/download-model", + json={"model_id": 4}, + ) + payload = await response.json() + + assert response.status == 401 + assert payload == { + "success": False, + "error": "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com.", + } + finally: + await client.close() + + asyncio.run(scenario()) + + def test_auto_organize_progress_returns_latest_snapshot(mock_service): async def scenario(): client = await create_test_client(mock_service) From 08baf884d3e4b050034f3690f006550d6f376037 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Mon, 22 Sep 2025 08:28:30 +0800 Subject: [PATCH 07/24] refactor(routes): migrate lifecycle mutations to service --- docs/architecture/model_routes.md | 10 +- py/routes/base_model_routes.py | 20 ++ py/routes/handlers/model_handlers.py | 83 ++++++- py/services/model_lifecycle_service.py | 245 ++++++++++++++++++ py/services/model_scanner.py | 7 +- py/utils/routes_common.py | 327 +------------------------ 6 files changed, 352 insertions(+), 340 deletions(-) create mode 100644 py/services/model_lifecycle_service.py diff --git a/docs/architecture/model_routes.md b/docs/architecture/model_routes.md index 564ba010..a9fbf967 100644 --- a/docs/architecture/model_routes.md +++ b/docs/architecture/model_routes.md @@ -43,7 +43,7 @@ the set and the invariants that must hold after each handler returns. | --- | --- | --- | --- | | `ModelPageView` | `/{prefix}` | `SettingsManager`, `server_i18n`, Jinja environment, `service.scanner` | Template is rendered with `is_initializing` flag when caches are cold; i18n filter is registered exactly once per environment instance. | | `ModelListingHandler` | `/api/lm/{prefix}/list` | `service.get_paginated_data`, `service.format_response` | Listings respect pagination query parameters and cap `page_size` at 100; every item is formatted before response. | -| `ModelManagementHandler` | Mutations (delete, exclude, metadata, preview, tags, rename, bulk delete, duplicate verification) | `ModelRouteUtils`, `MetadataSyncService`, `PreviewAssetService`, `TagUpdateService`, scanner cache/index | Cache state mirrors filesystem changes: deletes prune cache & hash index, preview replacements synchronise metadata and cache NSFW levels, metadata saves trigger cache resort when names change. | +| `ModelManagementHandler` | Mutations (delete, exclude, metadata, preview, tags, rename, bulk delete, duplicate verification) | `ModelLifecycleService`, `MetadataSyncService`, `PreviewAssetService`, `TagUpdateService`, scanner cache/index | Cache state mirrors filesystem changes: deletes prune cache & hash index, preview replacements synchronise metadata and cache NSFW levels, metadata saves trigger cache resort when names change. | | `ModelQueryHandler` | Read-only queries (top tags, folders, duplicates, metadata, URLs) | Service query helpers & scanner cache | Outputs always wrapped in `{"success": True}` when no error; duplicate/filename grouping omits empty entries; invalid parameters (e.g. missing `model_root`) return HTTP 400. | | `ModelDownloadHandler` | `/api/lm/download-model`, `/download-model-get`, `/download-progress/{id}`, `/cancel-download-get` | `DownloadModelUseCase`, `DownloadCoordinator`, `WebSocketManager` | Payload validation errors become HTTP 400 without mutating download progress cache; early-access failures surface as HTTP 401; successful downloads cache progress snapshots that back both WebSocket broadcasts and polling endpoints. | | `ModelCivitaiHandler` | CivitAI metadata routes | `MetadataSyncService`, metadata provider factory, `BulkMetadataRefreshUseCase` | `fetch_all_civitai` streams progress via `WebSocketBroadcastCallback`; version lookups validate model type before returning; local availability fields derive from hash lookups without mutating cache state. | @@ -69,10 +69,10 @@ collaboration points: 1. **Cache mutations** – Delete, exclude, rename, and bulk delete operations are channelled through `ModelManagementHandler`. The handler delegates to - `ModelRouteUtils` or `MetadataSyncService`, and the scanner cache is mutated - in-place before the handler returns. The accompanying tests assert that - `scanner._cache.raw_data` and `scanner._hash_index` stay in sync after each - mutation. + `ModelLifecycleService` or `MetadataSyncService`, and the scanner cache is + mutated in-place before the handler returns. The accompanying tests assert + that `scanner._cache.raw_data` and `scanner._hash_index` stay in sync after + each mutation. 2. **Preview updates** – `PreviewAssetService.replace_preview` writes the new asset, `MetadataSyncService` persists the JSON metadata, and `scanner.update_preview_in_cache` mirrors the change. The handler returns diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index 872dca8b..35415331 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -13,6 +13,7 @@ from ..services.downloader import get_downloader from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider from ..services.metadata_sync_service import MetadataSyncService from ..services.model_file_service import ModelFileService, ModelMoveService +from ..services.model_lifecycle_service import ModelLifecycleService from ..services.preview_asset_service import PreviewAssetService from ..services.server_i18n import server_i18n as default_server_i18n from ..services.service_registry import ServiceRegistry @@ -75,6 +76,7 @@ class BaseModelRoutes(ABC): self.model_file_service: ModelFileService | None = None self.model_move_service: ModelMoveService | None = None + self.model_lifecycle_service: ModelLifecycleService | None = None self.websocket_progress_callback = WebSocketProgressCallback() self.metadata_progress_callback = WebSocketBroadcastCallback() @@ -108,6 +110,12 @@ class BaseModelRoutes(ABC): self.model_type = service.model_type self.model_file_service = ModelFileService(service.scanner, service.model_type) self.model_move_service = ModelMoveService(service.scanner) + self.model_lifecycle_service = ModelLifecycleService( + scanner=service.scanner, + metadata_manager=MetadataManager, + metadata_loader=self._metadata_sync_service.load_local_metadata, + recipe_scanner_factory=ServiceRegistry.get_recipe_scanner, + ) self._handler_set = None self._handler_mapping = None @@ -139,6 +147,7 @@ class BaseModelRoutes(ABC): metadata_sync=self._metadata_sync_service, preview_service=self._preview_service, tag_update_service=self._tag_update_service, + lifecycle_service=self._ensure_lifecycle_service(), ) query = ModelQueryHandler(service=service, logger=logger) download_use_case = DownloadModelUseCase(download_coordinator=self._download_coordinator) @@ -248,6 +257,17 @@ class BaseModelRoutes(ABC): self.model_move_service = ModelMoveService(service.scanner) return self.model_move_service + def _ensure_lifecycle_service(self) -> ModelLifecycleService: + if self.model_lifecycle_service is None: + service = self._ensure_service() + self.model_lifecycle_service = ModelLifecycleService( + scanner=service.scanner, + metadata_manager=MetadataManager, + metadata_loader=self._metadata_sync_service.load_local_metadata, + recipe_scanner_factory=ServiceRegistry.get_recipe_scanner, + ) + return self.model_lifecycle_service + def _make_handler_proxy(self, name: str) -> Callable[[web.Request], web.StreamResponse]: async def proxy(request: web.Request) -> web.StreamResponse: try: diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index ba0628c6..a6fe4091 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -30,7 +30,6 @@ from ...services.use_cases import ( from ...services.websocket_manager import WebSocketManager from ...services.websocket_progress_callback import WebSocketProgressCallback from ...utils.file_utils import calculate_sha256 -from ...utils.routes_common import ModelRouteUtils class ModelPageView: @@ -192,18 +191,44 @@ class ModelManagementHandler: metadata_sync: MetadataSyncService, preview_service: PreviewAssetService, tag_update_service: TagUpdateService, + lifecycle_service, ) -> None: self._service = service self._logger = logger self._metadata_sync = metadata_sync self._preview_service = preview_service self._tag_update_service = tag_update_service + self._lifecycle_service = lifecycle_service async def delete_model(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_delete_model(request, self._service.scanner) + try: + data = await request.json() + file_path = data.get("file_path") + if not file_path: + return web.Response(text="Model path is required", status=400) + + result = await self._lifecycle_service.delete_model(file_path) + return web.json_response(result) + except ValueError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except Exception as exc: + self._logger.error("Error deleting model: %s", exc, exc_info=True) + return web.Response(text=str(exc), status=500) async def exclude_model(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_exclude_model(request, self._service.scanner) + try: + data = await request.json() + file_path = data.get("file_path") + if not file_path: + return web.Response(text="Model path is required", status=400) + + result = await self._lifecycle_service.exclude_model(file_path) + return web.json_response(result) + except ValueError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except Exception as exc: + self._logger.error("Error excluding model: %s", exc, exc_info=True) + return web.Response(text=str(exc), status=500) async def fetch_civitai(self, request: web.Request) -> web.Response: try: @@ -375,10 +400,58 @@ class ModelManagementHandler: return web.Response(text=str(exc), status=500) async def rename_model(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_rename_model(request, self._service.scanner) + try: + data = await request.json() + file_path = data.get("file_path") + new_file_name = data.get("new_file_name") + + if not file_path or not new_file_name: + return web.json_response( + { + "success": False, + "error": "File path and new file name are required", + }, + status=400, + ) + + result = await self._lifecycle_service.rename_model( + file_path=file_path, new_file_name=new_file_name + ) + + return web.json_response( + { + **result, + "new_preview_path": config.get_preview_static_url( + result.get("new_preview_path") + ), + } + ) + except ValueError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except Exception as exc: + self._logger.error("Error renaming model: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) async def bulk_delete_models(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_bulk_delete_models(request, self._service.scanner) + try: + data = await request.json() + file_paths = data.get("file_paths", []) + if not file_paths: + return web.json_response( + { + "success": False, + "error": "No file paths provided for deletion", + }, + status=400, + ) + + result = await self._lifecycle_service.bulk_delete_models(file_paths) + return web.json_response(result) + except ValueError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except Exception as exc: + self._logger.error("Error in bulk delete: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) async def verify_duplicates(self, request: web.Request) -> web.Response: try: diff --git a/py/services/model_lifecycle_service.py b/py/services/model_lifecycle_service.py new file mode 100644 index 00000000..9aa87b04 --- /dev/null +++ b/py/services/model_lifecycle_service.py @@ -0,0 +1,245 @@ +"""Service routines for model lifecycle mutations.""" + +from __future__ import annotations + +import logging +import os +from typing import Awaitable, Callable, Dict, Iterable, List, Optional + +from ..services.service_registry import ServiceRegistry +from ..utils.constants import PREVIEW_EXTENSIONS + +logger = logging.getLogger(__name__) + + +async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]: + """Delete the primary model artefacts within ``target_dir``.""" + + patterns = [ + f"{file_name}.safetensors", + f"{file_name}.metadata.json", + ] + for ext in PREVIEW_EXTENSIONS: + patterns.append(f"{file_name}{ext}") + + deleted: List[str] = [] + main_file = patterns[0] + main_path = os.path.join(target_dir, main_file).replace(os.sep, "/") + + if os.path.exists(main_path): + os.remove(main_path) + deleted.append(main_path) + else: + logger.warning("Model file not found: %s", main_file) + + for pattern in patterns[1:]: + path = os.path.join(target_dir, pattern) + if os.path.exists(path): + try: + os.remove(path) + deleted.append(pattern) + except Exception as exc: # pragma: no cover - defensive path + logger.warning("Failed to delete %s: %s", pattern, exc) + + return deleted + + +class ModelLifecycleService: + """Co-ordinate destructive and mutating model operations.""" + + def __init__( + self, + *, + scanner, + metadata_manager, + metadata_loader: Callable[[str], Awaitable[Dict[str, object]]], + recipe_scanner_factory: Callable[[], Awaitable] | None = None, + ) -> None: + self._scanner = scanner + self._metadata_manager = metadata_manager + self._metadata_loader = metadata_loader + self._recipe_scanner_factory = ( + recipe_scanner_factory or ServiceRegistry.get_recipe_scanner + ) + + async def delete_model(self, file_path: str) -> Dict[str, object]: + """Delete a model file and associated artefacts.""" + + if not file_path: + raise ValueError("Model path is required") + + target_dir = os.path.dirname(file_path) + file_name = os.path.splitext(os.path.basename(file_path))[0] + + deleted_files = await delete_model_artifacts(target_dir, file_name) + + cache = await self._scanner.get_cached_data() + cache.raw_data = [item for item in cache.raw_data if item["file_path"] != file_path] + await cache.resort() + + if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index: + self._scanner._hash_index.remove_by_path(file_path) + + return {"success": True, "deleted_files": deleted_files} + + async def exclude_model(self, file_path: str) -> Dict[str, object]: + """Mark a model as excluded and prune cache references.""" + + if not file_path: + raise ValueError("Model path is required") + + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" + metadata = await self._metadata_loader(metadata_path) + metadata["exclude"] = True + + await self._metadata_manager.save_metadata(file_path, metadata) + + cache = await self._scanner.get_cached_data() + model_to_remove = next( + (item for item in cache.raw_data if item["file_path"] == file_path), + None, + ) + + if model_to_remove: + for tag in model_to_remove.get("tags", []): + if tag in getattr(self._scanner, "_tags_count", {}): + self._scanner._tags_count[tag] = max( + 0, self._scanner._tags_count[tag] - 1 + ) + if self._scanner._tags_count[tag] == 0: + del self._scanner._tags_count[tag] + + if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index: + self._scanner._hash_index.remove_by_path(file_path) + + cache.raw_data = [ + item for item in cache.raw_data if item["file_path"] != file_path + ] + await cache.resort() + + excluded = getattr(self._scanner, "_excluded_models", None) + if isinstance(excluded, list): + excluded.append(file_path) + + message = f"Model {os.path.basename(file_path)} excluded" + return {"success": True, "message": message} + + async def bulk_delete_models(self, file_paths: Iterable[str]) -> Dict[str, object]: + """Delete a collection of models via the scanner bulk operation.""" + + file_paths = list(file_paths) + if not file_paths: + raise ValueError("No file paths provided for deletion") + + return await self._scanner.bulk_delete_models(file_paths) + + async def rename_model( + self, *, file_path: str, new_file_name: str + ) -> Dict[str, object]: + """Rename a model and its companion artefacts.""" + + if not file_path or not new_file_name: + raise ValueError("File path and new file name are required") + + invalid_chars = {"/", "\\", ":", "*", "?", '"', "<", ">", "|"} + if any(char in new_file_name for char in invalid_chars): + raise ValueError("Invalid characters in file name") + + target_dir = os.path.dirname(file_path) + old_file_name = os.path.splitext(os.path.basename(file_path))[0] + new_file_path = os.path.join(target_dir, f"{new_file_name}.safetensors").replace( + os.sep, "/" + ) + + if os.path.exists(new_file_path): + raise ValueError("A file with this name already exists") + + patterns = [ + f"{old_file_name}.safetensors", + f"{old_file_name}.metadata.json", + f"{old_file_name}.metadata.json.bak", + ] + for ext in PREVIEW_EXTENSIONS: + patterns.append(f"{old_file_name}{ext}") + + existing_files: List[tuple[str, str]] = [] + for pattern in patterns: + path = os.path.join(target_dir, pattern) + if os.path.exists(path): + existing_files.append((path, pattern)) + + metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json") + metadata: Optional[Dict[str, object]] = None + hash_value: Optional[str] = None + + if os.path.exists(metadata_path): + metadata = await self._metadata_loader(metadata_path) + hash_value = metadata.get("sha256") if isinstance(metadata, dict) else None + + renamed_files: List[str] = [] + new_metadata_path: Optional[str] = None + new_preview: Optional[str] = None + + for old_path, pattern in existing_files: + ext = self._get_multipart_ext(pattern) + new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace( + os.sep, "/" + ) + os.rename(old_path, new_path) + renamed_files.append(new_path) + + if ext == ".metadata.json": + new_metadata_path = new_path + + if metadata and new_metadata_path: + metadata["file_name"] = new_file_name + metadata["file_path"] = new_file_path + + if metadata.get("preview_url"): + old_preview = str(metadata["preview_url"]) + ext = self._get_multipart_ext(old_preview) + new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace( + os.sep, "/" + ) + metadata["preview_url"] = new_preview + + await self._metadata_manager.save_metadata(new_file_path, metadata) + + if metadata: + await self._scanner.update_single_model_cache( + file_path, new_file_path, metadata + ) + + if hash_value and getattr(self._scanner, "model_type", "") == "lora": + recipe_scanner = await self._recipe_scanner_factory() + if recipe_scanner: + try: + await recipe_scanner.update_lora_filename_by_hash( + hash_value, new_file_name + ) + except Exception as exc: # pragma: no cover - defensive logging + logger.error( + "Error updating recipe references for %s: %s", + file_path, + exc, + ) + + return { + "success": True, + "new_file_path": new_file_path, + "new_preview_path": new_preview, + "renamed_files": renamed_files, + "reload_required": False, + } + + @staticmethod + def _get_multipart_ext(filename: str) -> str: + """Return the extension for files with compound suffixes.""" + + parts = filename.split(".") + if len(parts) == 3: + return "." + ".".join(parts[-2:]) + if len(parts) >= 4: + return "." + ".".join(parts[-3:]) + return os.path.splitext(filename)[1] + diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index f0ae3177..51aa4507 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -13,6 +13,7 @@ from ..utils.metadata_manager import MetadataManager from .model_cache import ModelCache from .model_hash_index import ModelHashIndex from ..utils.constants import PREVIEW_EXTENSIONS +from .model_lifecycle_service import delete_model_artifacts from .service_registry import ServiceRegistry from .websocket_manager import ws_manager @@ -1040,10 +1041,8 @@ class ModelScanner: target_dir = os.path.dirname(file_path) file_name = os.path.splitext(os.path.basename(file_path))[0] - # Delete all associated files for the model - from ..utils.routes_common import ModelRouteUtils - deleted_files = await ModelRouteUtils.delete_model_files( - target_dir, + deleted_files = await delete_model_artifacts( + target_dir, file_name ) diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index b5f6af30..84e2fc9d 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -1,7 +1,7 @@ import os import json import logging -from typing import Dict, List, Callable, Awaitable +from typing import Dict, Callable, Awaitable from aiohttp import web from datetime import datetime @@ -284,104 +284,6 @@ class ModelRouteUtils: ] return {k: data[k] for k in fields if k in data} - @staticmethod - async def delete_model_files(target_dir: str, file_name: str) -> List[str]: - """Delete model and associated files - - Args: - target_dir: Directory containing the model files - file_name: Base name of the model file without extension - - Returns: - List of deleted file paths - """ - patterns = [ - f"{file_name}.safetensors", # Required - f"{file_name}.metadata.json", - ] - - # Add all preview file extensions - for ext in PREVIEW_EXTENSIONS: - patterns.append(f"{file_name}{ext}") - - deleted = [] - main_file = patterns[0] - main_path = os.path.join(target_dir, main_file).replace(os.sep, '/') - - if os.path.exists(main_path): - # Delete file - os.remove(main_path) - deleted.append(main_path) - else: - logger.warning(f"Model file not found: {main_file}") - - # Delete optional files - for pattern in patterns[1:]: - path = os.path.join(target_dir, pattern) - if os.path.exists(path): - try: - os.remove(path) - deleted.append(pattern) - except Exception as e: - logger.warning(f"Failed to delete {pattern}: {e}") - - return deleted - - @staticmethod - def get_multipart_ext(filename): - """Get extension that may have multiple parts like .metadata.json or .metadata.json.bak""" - parts = filename.split(".") - if len(parts) == 3: # If contains 2-part extension - return "." + ".".join(parts[-2:]) # Take the last two parts, like ".metadata.json" - elif len(parts) >= 4: # If contains 3-part or more extensions - return "." + ".".join(parts[-3:]) # Take the last three parts, like ".metadata.json.bak" - return os.path.splitext(filename)[1] # Otherwise take the regular extension, like ".safetensors" - - # New common endpoint handlers - - @staticmethod - async def handle_delete_model(request: web.Request, scanner) -> web.Response: - """Handle model deletion request - - Args: - request: The aiohttp request - scanner: The model scanner instance with cache management methods - - Returns: - web.Response: The HTTP response - """ - try: - data = await request.json() - file_path = data.get('file_path') - if not file_path: - return web.Response(text='Model path is required', status=400) - - target_dir = os.path.dirname(file_path) - file_name = os.path.splitext(os.path.basename(file_path))[0] - - deleted_files = await ModelRouteUtils.delete_model_files( - target_dir, - file_name - ) - - # Remove from cache - cache = await scanner.get_cached_data() - cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path] - await cache.resort() - - # Update hash index if available - if hasattr(scanner, '_hash_index') and scanner._hash_index: - scanner._hash_index.remove_by_path(file_path) - - return web.json_response({ - 'success': True, - 'deleted_files': deleted_files - }) - - except Exception as e: - logger.error(f"Error deleting model: {e}", exc_info=True) - return web.Response(text=str(e), status=500) - @staticmethod async def handle_fetch_civitai(request: web.Request, scanner) -> web.Response: """Handle CivitAI metadata fetch request @@ -544,64 +446,6 @@ class ModelRouteUtils: logger.error(f"Error replacing preview: {e}", exc_info=True) return web.Response(text=str(e), status=500) - @staticmethod - async def handle_exclude_model(request: web.Request, scanner) -> web.Response: - """Handle model exclusion request - - Args: - request: The aiohttp request - scanner: The model scanner instance with cache management methods - - Returns: - web.Response: The HTTP response - """ - try: - data = await request.json() - file_path = data.get('file_path') - if not file_path: - return web.Response(text='Model path is required', status=400) - - # Update metadata to mark as excluded - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - metadata['exclude'] = True - - # Save updated metadata - await MetadataManager.save_metadata(file_path, metadata) - - # Update cache - cache = await scanner.get_cached_data() - - # Find and remove model from cache - model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None) - if model_to_remove: - # Update tags count - for tag in model_to_remove.get('tags', []): - if tag in scanner._tags_count: - scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1) - if scanner._tags_count[tag] == 0: - del scanner._tags_count[tag] - - # Remove from hash index if available - if hasattr(scanner, '_hash_index') and scanner._hash_index: - scanner._hash_index.remove_by_path(file_path) - - # Remove from cache data - cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path] - await cache.resort() - - # Add to excluded models list - scanner._excluded_models.append(file_path) - - return web.json_response({ - 'success': True, - 'message': f"Model {os.path.basename(file_path)} excluded" - }) - - except Exception as e: - logger.error(f"Error excluding model: {e}", exc_info=True) - return web.Response(text=str(e), status=500) - @staticmethod async def handle_download_model(request: web.Request) -> web.Response: """Handle model download request""" @@ -755,44 +599,6 @@ class ModelRouteUtils: 'error': str(e) }, status=500) - @staticmethod - async def handle_bulk_delete_models(request: web.Request, scanner) -> web.Response: - """Handle bulk deletion of models - - Args: - request: The aiohttp request - scanner: The model scanner instance with cache management methods - - Returns: - web.Response: The HTTP response - """ - try: - data = await request.json() - file_paths = data.get('file_paths', []) - - if not file_paths: - return web.json_response({ - 'success': False, - 'error': 'No file paths provided for deletion' - }, status=400) - - # Use the scanner's bulk delete method to handle all cache and file operations - result = await scanner.bulk_delete_models(file_paths) - - return web.json_response({ - 'success': result.get('success', False), - 'total_deleted': result.get('total_deleted', 0), - 'total_attempted': result.get('total_attempted', len(file_paths)), - 'results': result.get('results', []) - }) - - except Exception as e: - logger.error(f"Error in bulk delete: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - @staticmethod async def handle_relink_civitai(request: web.Request, scanner) -> web.Response: """Handle CivitAI metadata re-linking request by model ID and/or version ID @@ -948,137 +754,6 @@ class ModelRouteUtils: 'error': str(e) }, status=500) - @staticmethod - async def handle_rename_model(request: web.Request, scanner) -> web.Response: - """Handle renaming a model file and its associated files - - Args: - request: The aiohttp request - scanner: The model scanner instance - - Returns: - web.Response: The HTTP response - """ - try: - data = await request.json() - file_path = data.get('file_path') - new_file_name = data.get('new_file_name') - - if not file_path or not new_file_name: - return web.json_response({ - 'success': False, - 'error': 'File path and new file name are required' - }, status=400) - - # Validate the new file name (no path separators or invalid characters) - invalid_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|'] - if any(char in new_file_name for char in invalid_chars): - return web.json_response({ - 'success': False, - 'error': 'Invalid characters in file name' - }, status=400) - - # Get the directory and current file name - target_dir = os.path.dirname(file_path) - old_file_name = os.path.splitext(os.path.basename(file_path))[0] - - # Check if the target file already exists - new_file_path = os.path.join(target_dir, f"{new_file_name}.safetensors").replace(os.sep, '/') - if os.path.exists(new_file_path): - return web.json_response({ - 'success': False, - 'error': 'A file with this name already exists' - }, status=400) - - # Define the patterns for associated files - patterns = [ - f"{old_file_name}.safetensors", # Required - f"{old_file_name}.metadata.json", - f"{old_file_name}.metadata.json.bak", - ] - - # Add all preview file extensions - for ext in PREVIEW_EXTENSIONS: - patterns.append(f"{old_file_name}{ext}") - - # Find all matching files - existing_files = [] - for pattern in patterns: - path = os.path.join(target_dir, pattern) - if os.path.exists(path): - existing_files.append((path, pattern)) - - # Get the hash from the main file to update hash index - hash_value = None - metadata = None - metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json") - - if os.path.exists(metadata_path): - metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - hash_value = metadata.get('sha256') - logger.info(f"hash_value: {hash_value}, metadata_path: {metadata_path}, metadata: {metadata}") - # Rename all files - renamed_files = [] - new_metadata_path = None - new_preview = None - - for old_path, pattern in existing_files: - # Get the file extension like .safetensors or .metadata.json - ext = ModelRouteUtils.get_multipart_ext(pattern) - - # Create the new path - new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/') - - # Rename the file - os.rename(old_path, new_path) - renamed_files.append(new_path) - - # Keep track of metadata path for later update - if ext == '.metadata.json': - new_metadata_path = new_path - - # Update the metadata file with new file name and paths - if new_metadata_path and metadata: - # Update file_name, file_path and preview_url in metadata - metadata['file_name'] = new_file_name - metadata['file_path'] = new_file_path - - # Update preview_url if it exists - if 'preview_url' in metadata and metadata['preview_url']: - old_preview = metadata['preview_url'] - ext = ModelRouteUtils.get_multipart_ext(old_preview) - new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/') - metadata['preview_url'] = new_preview - - # Save updated metadata - await MetadataManager.save_metadata(new_file_path, metadata) - - # Update the scanner cache - if metadata: - await scanner.update_single_model_cache(file_path, new_file_path, metadata) - - # Update recipe files and cache if hash is available and recipe_scanner exists - if hash_value and hasattr(scanner, 'update_lora_filename_by_hash'): - recipe_scanner = await ServiceRegistry.get_recipe_scanner() - if recipe_scanner: - recipes_updated, cache_updated = await recipe_scanner.update_lora_filename_by_hash(hash_value, new_file_name) - logger.info(f"Updated {recipes_updated} recipe files and {cache_updated} cache entries for renamed model") - - return web.json_response({ - 'success': True, - 'new_file_path': new_file_path, - 'new_preview_path': config.get_preview_static_url(new_preview), - 'renamed_files': renamed_files, - 'reload_required': False - }) - - except Exception as e: - logger.error(f"Error renaming model: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - @staticmethod async def handle_save_metadata(request: web.Request, scanner) -> web.Response: """Handle saving metadata updates From c3b9c73541b96c416d9e9d8dfc7e38704b6d2b54 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Mon, 22 Sep 2025 09:09:40 +0800 Subject: [PATCH 08/24] refactor: remove ModelRouteUtils usage and implement filtering directly in services --- py/routes/base_model_routes.py | 5 ----- py/services/base_model_service.py | 15 +++++++++++++-- py/services/checkpoint_service.py | 5 ++--- py/services/embedding_service.py | 5 ++--- py/services/lora_service.py | 3 +-- py/utils/example_images_metadata.py | 1 - py/utils/routes_common.py | 2 +- tests/conftest.py | 4 ++-- 8 files changed, 21 insertions(+), 19 deletions(-) diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index 35415331..84b9f43f 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -31,7 +31,6 @@ from ..services.websocket_progress_callback import ( ) from ..utils.exif_utils import ExifUtils from ..utils.metadata_manager import MetadataManager -from ..utils.routes_common import ModelRouteUtils from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar from .handlers.model_handlers import ( ModelAutoOrganizeHandler, @@ -236,10 +235,6 @@ class BaseModelRoutes(ABC): """Expose handlers for subclasses or tests.""" return self._ensure_handler_mapping()[name] - @property - def utils(self) -> ModelRouteUtils: # pragma: no cover - compatibility shim - return ModelRouteUtils - def _ensure_service(self): if self.service is None: raise RuntimeError("Model service has not been attached") diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index 0b4aaf99..2c2c0ad8 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -4,7 +4,6 @@ import logging import os from ..utils.models import BaseModelMetadata -from ..utils.routes_common import ModelRouteUtils from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider from .settings_manager import settings as default_settings @@ -197,6 +196,18 @@ class BaseModelService(ABC): """Get model root directories""" return self.scanner.get_model_roots() + def filter_civitai_data(self, data: Dict, minimal: bool = False) -> Dict: + """Filter relevant fields from CivitAI data""" + if not data: + return {} + + fields = ["id", "modelId", "name", "trainedWords"] if minimal else [ + "id", "modelId", "name", "createdAt", "updatedAt", + "publishedAt", "trainedWords", "baseModel", "description", + "model", "images", "customImages", "creator" + ] + return {k: data[k] for k in fields if k in data} + async def get_folder_tree(self, model_root: str) -> Dict: """Get hierarchical folder tree for a specific model root""" cache = await self.scanner.get_cached_data() @@ -307,7 +318,7 @@ class BaseModelService(ABC): for model in cache.raw_data: if model.get('file_path') == file_path: - return ModelRouteUtils.filter_civitai_data(model.get("civitai", {})) + return self.filter_civitai_data(model.get("civitai", {})) return None diff --git a/py/services/checkpoint_service.py b/py/services/checkpoint_service.py index ef3dc4a8..2f7b8a96 100644 --- a/py/services/checkpoint_service.py +++ b/py/services/checkpoint_service.py @@ -1,11 +1,10 @@ import os import logging -from typing import Dict, List, Optional +from typing import Dict from .base_model_service import BaseModelService from ..utils.models import CheckpointMetadata from ..config import config -from ..utils.routes_common import ModelRouteUtils logger = logging.getLogger(__name__) @@ -38,7 +37,7 @@ class CheckpointService(BaseModelService): "notes": checkpoint_data.get("notes", ""), "model_type": checkpoint_data.get("model_type", "checkpoint"), "favorite": checkpoint_data.get("favorite", False), - "civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True) + "civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True) } def find_duplicate_hashes(self) -> Dict: diff --git a/py/services/embedding_service.py b/py/services/embedding_service.py index bab067d9..46396fc5 100644 --- a/py/services/embedding_service.py +++ b/py/services/embedding_service.py @@ -1,11 +1,10 @@ import os import logging -from typing import Dict, List, Optional +from typing import Dict from .base_model_service import BaseModelService from ..utils.models import EmbeddingMetadata from ..config import config -from ..utils.routes_common import ModelRouteUtils logger = logging.getLogger(__name__) @@ -38,7 +37,7 @@ class EmbeddingService(BaseModelService): "notes": embedding_data.get("notes", ""), "model_type": embedding_data.get("model_type", "embedding"), "favorite": embedding_data.get("favorite", False), - "civitai": ModelRouteUtils.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True) + "civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True) } def find_duplicate_hashes(self) -> Dict: diff --git a/py/services/lora_service.py b/py/services/lora_service.py index d1e522a3..551c4d3c 100644 --- a/py/services/lora_service.py +++ b/py/services/lora_service.py @@ -5,7 +5,6 @@ from typing import Dict, List, Optional from .base_model_service import BaseModelService from ..utils.models import LoraMetadata from ..config import config -from ..utils.routes_common import ModelRouteUtils logger = logging.getLogger(__name__) @@ -38,7 +37,7 @@ class LoraService(BaseModelService): "usage_tips": lora_data.get("usage_tips", ""), "notes": lora_data.get("notes", ""), "favorite": lora_data.get("favorite", False), - "civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {}), minimal=True) + "civitai": self.filter_civitai_data(lora_data.get("civitai", {}), minimal=True) } async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: diff --git a/py/utils/example_images_metadata.py b/py/utils/example_images_metadata.py index 66db05a3..8820b49b 100644 --- a/py/utils/example_images_metadata.py +++ b/py/utils/example_images_metadata.py @@ -69,7 +69,6 @@ class MetadataUpdater: # Track that we're refreshing this model download_progress['refreshed_models'].add(model_hash) - # Use ModelRouteUtils to refresh metadata async def update_cache_func(old_path, new_path, metadata): return await scanner.update_single_model_cache(old_path, new_path, metadata) diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index 84e2fc9d..642bfcad 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -18,7 +18,7 @@ from ..services.settings_manager import settings logger = logging.getLogger(__name__) - +# TODO: retire this class class ModelRouteUtils: """Shared utilities for model routes (LoRAs, Checkpoints, etc.)""" diff --git a/tests/conftest.py b/tests/conftest.py index 006a1adc..58263c8a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -81,7 +81,7 @@ class MockHashIndex: class MockCache: - """Cache object with the attributes consumed by ``ModelRouteUtils``.""" + """Cache object with the attributes.""" def __init__(self, items: Optional[Sequence[Dict[str, Any]]] = None): self.raw_data: List[Dict[str, Any]] = list(items or []) @@ -89,7 +89,7 @@ class MockCache: async def resort(self) -> None: self.resort_calls += 1 - # ``ModelRouteUtils`` expects the coroutine interface but does not + # expects the coroutine interface but does not # rely on the return value. From b92e7aa446392eaa95270ab37c5eb18517384c29 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Mon, 22 Sep 2025 12:15:12 +0800 Subject: [PATCH 09/24] chore(routes): dedupe os import --- py/routes/base_recipe_routes.py | 109 +++++++++++++++++ py/routes/recipe_route_registrar.py | 64 ++++++++++ py/routes/recipe_routes.py | 181 +++++++++------------------- 3 files changed, 227 insertions(+), 127 deletions(-) create mode 100644 py/routes/base_recipe_routes.py create mode 100644 py/routes/recipe_route_registrar.py diff --git a/py/routes/base_recipe_routes.py b/py/routes/base_recipe_routes.py new file mode 100644 index 00000000..e2b726da --- /dev/null +++ b/py/routes/base_recipe_routes.py @@ -0,0 +1,109 @@ +"""Base infrastructure shared across recipe routes.""" +from __future__ import annotations + +import logging +from typing import Callable, Mapping + +import jinja2 +from aiohttp import web + +from ..config import config +from ..services.server_i18n import server_i18n +from ..services.service_registry import ServiceRegistry +from ..services.settings_manager import settings +from .recipe_route_registrar import ROUTE_DEFINITIONS + +logger = logging.getLogger(__name__) + + +class BaseRecipeRoutes: + """Common dependency and startup wiring for recipe routes.""" + + _HANDLER_NAMES: tuple[str, ...] = tuple( + definition.handler_name for definition in ROUTE_DEFINITIONS + ) + + def __init__(self) -> None: + self.recipe_scanner = None + self.lora_scanner = None + self.civitai_client = None + self.settings = settings + self.server_i18n = server_i18n + self.template_env = jinja2.Environment( + loader=jinja2.FileSystemLoader(config.templates_path), + autoescape=True, + ) + + self._i18n_registered = False + self._startup_hooks_registered = False + self._handler_mapping: dict[str, Callable] | None = None + + async def attach_dependencies(self, app: web.Application | None = None) -> None: + """Resolve shared services from the registry.""" + + await self._ensure_services() + self._ensure_i18n_filter() + + async def ensure_dependencies_ready(self) -> None: + """Ensure dependencies are available for request handlers.""" + + if self.recipe_scanner is None or self.civitai_client is None: + await self.attach_dependencies() + + def register_startup_hooks(self, app: web.Application) -> None: + """Register startup hooks once for dependency wiring.""" + + if self._startup_hooks_registered: + return + + app.on_startup.append(self.attach_dependencies) + app.on_startup.append(self.prewarm_cache) + self._startup_hooks_registered = True + + async def prewarm_cache(self, app: web.Application | None = None) -> None: + """Pre-load recipe and LoRA caches on startup.""" + + try: + await self.attach_dependencies(app) + + if self.lora_scanner is not None: + await self.lora_scanner.get_cached_data() + hash_index = getattr(self.lora_scanner, "_hash_index", None) + if hash_index is not None and hasattr(hash_index, "_hash_to_path"): + _ = len(hash_index._hash_to_path) + + if self.recipe_scanner is not None: + await self.recipe_scanner.get_cached_data(force_refresh=True) + except Exception as exc: + logger.error("Error pre-warming recipe cache: %s", exc, exc_info=True) + + def to_route_mapping(self) -> Mapping[str, Callable]: + """Return a mapping of handler name to coroutine for registrar binding.""" + + if self._handler_mapping is None: + owner = self.get_handler_owner() + self._handler_mapping = { + name: getattr(owner, name) for name in self._HANDLER_NAMES + } + return self._handler_mapping + + # Internal helpers ------------------------------------------------- + + async def _ensure_services(self) -> None: + if self.recipe_scanner is None: + self.recipe_scanner = await ServiceRegistry.get_recipe_scanner() + self.lora_scanner = getattr(self.recipe_scanner, "_lora_scanner", None) + + if self.civitai_client is None: + self.civitai_client = await ServiceRegistry.get_civitai_client() + + def _ensure_i18n_filter(self) -> None: + if not self._i18n_registered: + self.template_env.filters["t"] = self.server_i18n.create_template_filter() + self._i18n_registered = True + + def get_handler_owner(self): + """Return the object supplying bound handler coroutines.""" + + return self + diff --git a/py/routes/recipe_route_registrar.py b/py/routes/recipe_route_registrar.py new file mode 100644 index 00000000..471edf19 --- /dev/null +++ b/py/routes/recipe_route_registrar.py @@ -0,0 +1,64 @@ +"""Route registrar for recipe endpoints.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Mapping + +from aiohttp import web + + +@dataclass(frozen=True) +class RouteDefinition: + """Declarative definition for a recipe HTTP route.""" + + method: str + path: str + handler_name: str + + +ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( + RouteDefinition("GET", "/loras/recipes", "render_page"), + RouteDefinition("GET", "/api/lm/recipes", "list_recipes"), + RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"), + RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"), + RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"), + RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"), + RouteDefinition("DELETE", "/api/lm/recipe/{recipe_id}", "delete_recipe"), + RouteDefinition("GET", "/api/lm/recipes/top-tags", "get_top_tags"), + RouteDefinition("GET", "/api/lm/recipes/base-models", "get_base_models"), + RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share", "share_recipe"), + RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"), + RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/syntax", "get_recipe_syntax"), + RouteDefinition("PUT", "/api/lm/recipe/{recipe_id}/update", "update_recipe"), + RouteDefinition("POST", "/api/lm/recipe/lora/reconnect", "reconnect_lora"), + RouteDefinition("GET", "/api/lm/recipes/find-duplicates", "find_duplicates"), + RouteDefinition("POST", "/api/lm/recipes/bulk-delete", "bulk_delete"), + RouteDefinition("POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"), + RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"), + RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"), +) + + +class RecipeRouteRegistrar: + """Bind declarative recipe definitions to an aiohttp router.""" + + _METHOD_MAP = { + "GET": "add_get", + "POST": "add_post", + "PUT": "add_put", + "DELETE": "add_delete", + } + + def __init__(self, app: web.Application) -> None: + self._app = app + + def register_routes(self, handler_lookup: Mapping[str, Callable[[web.Request], object]]) -> None: + for definition in ROUTE_DEFINITIONS: + handler = handler_lookup[definition.handler_name] + self._bind_route(definition.method, definition.path, handler) + + def _bind_route(self, method: str, path: str, handler: Callable) -> None: + add_method_name = self._METHOD_MAP[method.upper()] + add_method = getattr(self._app.router, add_method_name) + add_method(path, handler) + diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 21214d99..dcd0751a 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -1,7 +1,6 @@ import os import time import base64 -import jinja2 import numpy as np from PIL import Image import io @@ -12,20 +11,18 @@ import tempfile import json import asyncio import sys + +from .base_recipe_routes import BaseRecipeRoutes +from .recipe_route_registrar import RecipeRouteRegistrar from ..utils.exif_utils import ExifUtils from ..recipes import RecipeParserFactory from ..utils.constants import CARD_PREVIEW_WIDTH - -from ..services.settings_manager import settings -from ..services.server_i18n import server_i18n from ..config import config +from ..services.downloader import get_downloader # Check if running in standalone mode standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" -from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import -from ..services.downloader import get_downloader - # Only import MetadataRegistry in non-standalone mode if not standalone_mode: # Import metadata_collector functions and classes conditionally @@ -35,111 +32,35 @@ if not standalone_mode: logger = logging.getLogger(__name__) -class RecipeRoutes: - """API route handlers for Recipe management""" - def __init__(self): - # Initialize service references as None, will be set during async init - self.recipe_scanner = None - self.civitai_client = None - self.template_env = jinja2.Environment( - loader=jinja2.FileSystemLoader(config.templates_path), - autoescape=True - ) - - # Pre-warm the cache - self._init_cache_task = None - - async def init_services(self): - """Initialize services from ServiceRegistry""" - self.recipe_scanner = await ServiceRegistry.get_recipe_scanner() - self.civitai_client = await ServiceRegistry.get_civitai_client() +class RecipeRoutes(BaseRecipeRoutes): + """API route handlers for Recipe management.""" @classmethod def setup_routes(cls, app: web.Application): - """Register API routes""" + """Register API routes using the declarative registrar.""" + routes = cls() - app.router.add_get('/loras/recipes', routes.handle_recipes_page) + registrar = RecipeRouteRegistrar(app) + registrar.register_routes(routes.to_route_mapping()) + routes.register_startup_hooks(app) - app.router.add_get('/api/lm/recipes', routes.get_recipes) - app.router.add_get('/api/lm/recipe/{recipe_id}', routes.get_recipe_detail) - app.router.add_post('/api/lm/recipes/analyze-image', routes.analyze_recipe_image) - app.router.add_post('/api/lm/recipes/analyze-local-image', routes.analyze_local_image) - app.router.add_post('/api/lm/recipes/save', routes.save_recipe) - app.router.add_delete('/api/lm/recipe/{recipe_id}', routes.delete_recipe) - - # Add new filter-related endpoints - app.router.add_get('/api/lm/recipes/top-tags', routes.get_top_tags) - app.router.add_get('/api/lm/recipes/base-models', routes.get_base_models) - - # Add new sharing endpoints - app.router.add_get('/api/lm/recipe/{recipe_id}/share', routes.share_recipe) - app.router.add_get('/api/lm/recipe/{recipe_id}/share/download', routes.download_shared_recipe) - - # Add new endpoint for getting recipe syntax - app.router.add_get('/api/lm/recipe/{recipe_id}/syntax', routes.get_recipe_syntax) - - # Add new endpoint for updating recipe metadata (name, tags and source_path) - app.router.add_put('/api/lm/recipe/{recipe_id}/update', routes.update_recipe) - - # Add new endpoint for reconnecting deleted LoRAs - app.router.add_post('/api/lm/recipe/lora/reconnect', routes.reconnect_lora) - - # Add new endpoint for finding duplicate recipes - app.router.add_get('/api/lm/recipes/find-duplicates', routes.find_duplicates) - - # Add new endpoint for bulk deletion of recipes - app.router.add_post('/api/lm/recipes/bulk-delete', routes.bulk_delete) - - # Start cache initialization - app.on_startup.append(routes._init_cache) - - app.router.add_post('/api/lm/recipes/save-from-widget', routes.save_recipe_from_widget) - - # Add route to get recipes for a specific Lora - app.router.add_get('/api/lm/recipes/for-lora', routes.get_recipes_for_lora) - - # Add new endpoint for scanning and rebuilding the recipe cache - app.router.add_get('/api/lm/recipes/scan', routes.scan_recipes) - - async def _init_cache(self, app): - """Initialize cache on startup""" + async def render_page(self, request: web.Request) -> web.Response: + """Handle GET /loras/recipes request.""" try: - # Initialize services first - await self.init_services() - - # Now that services are initialized, get the lora scanner - lora_scanner = self.recipe_scanner._lora_scanner - - # Get lora cache to ensure it's initialized - lora_cache = await lora_scanner.get_cached_data() - - # Verify hash index is built - if hasattr(lora_scanner, '_hash_index'): - hash_index_size = len(lora_scanner._hash_index._hash_to_path) if hasattr(lora_scanner._hash_index, '_hash_to_path') else 0 - - # Now that lora scanner is initialized, initialize recipe cache - await self.recipe_scanner.get_cached_data(force_refresh=True) - except Exception as e: - logger.error(f"Error pre-warming recipe cache: {e}", exc_info=True) + await self.ensure_dependencies_ready() - async def handle_recipes_page(self, request: web.Request) -> web.Response: - """Handle GET /loras/recipes request""" - try: - # Ensure services are initialized - await self.init_services() - # 获取用户语言设置 - user_language = settings.get('language', 'en') - + user_language = self.settings.get('language', 'en') + # 设置服务端i18n语言 - server_i18n.set_locale(user_language) - + self.server_i18n.set_locale(user_language) + # 为模板环境添加i18n过滤器 if not hasattr(self.template_env, '_i18n_filter_added'): - self.template_env.filters['t'] = server_i18n.create_template_filter() + self._ensure_i18n_filter() self.template_env._i18n_filter_added = True - + # Skip initialization check and directly try to get cached data try: # Recipe scanner will initialize cache if needed @@ -148,10 +69,10 @@ class RecipeRoutes: rendered = template.render( recipes=[], # Frontend will load recipes via API is_initializing=False, - settings=settings, + settings=self.settings, request=request, # 添加服务端翻译函数 - t=server_i18n.get_translation, + t=self.server_i18n.get_translation, ) except Exception as cache_error: logger.error(f"Error loading recipe cache data: {cache_error}") @@ -159,10 +80,10 @@ class RecipeRoutes: template = self.template_env.get_template('recipes.html') rendered = template.render( is_initializing=True, - settings=settings, + settings=self.settings, request=request, # 添加服务端翻译函数 - t=server_i18n.get_translation, + t=self.server_i18n.get_translation, ) logger.info("Recipe cache error, returning initialization page") @@ -178,11 +99,11 @@ class RecipeRoutes: status=500 ) - async def get_recipes(self, request: web.Request) -> web.Response: - """API endpoint for getting paginated recipes""" + async def list_recipes(self, request: web.Request) -> web.Response: + """API endpoint for getting paginated recipes.""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() # Get query parameters with defaults page = int(request.query.get('page', '1')) @@ -250,11 +171,11 @@ class RecipeRoutes: logger.error(f"Error retrieving recipes: {e}", exc_info=True) return web.json_response({"error": str(e)}, status=500) - async def get_recipe_detail(self, request: web.Request) -> web.Response: - """Get detailed information about a specific recipe""" + async def get_recipe(self, request: web.Request) -> web.Response: + """Get detailed information about a specific recipe.""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() recipe_id = request.match_info['recipe_id'] @@ -305,12 +226,12 @@ class RecipeRoutes: from datetime import datetime return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S') - async def analyze_recipe_image(self, request: web.Request) -> web.Response: - """Analyze an uploaded image or URL for recipe metadata""" + async def analyze_uploaded_image(self, request: web.Request) -> web.Response: + """Analyze an uploaded image or URL for recipe metadata.""" temp_path = None try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() # Check if request contains multipart data (image) or JSON data (url) content_type = request.headers.get('Content-Type', '') @@ -480,7 +401,7 @@ class RecipeRoutes: """Analyze a local image file for recipe metadata""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() # Get JSON data from request data = await request.json() @@ -573,7 +494,7 @@ class RecipeRoutes: """Save a recipe to the recipes folder""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() reader = await request.multipart() @@ -779,7 +700,7 @@ class RecipeRoutes: """Delete a recipe by ID""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() recipe_id = request.match_info['recipe_id'] @@ -829,7 +750,7 @@ class RecipeRoutes: """Get top tags used in recipes""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() # Get limit parameter with default limit = int(request.query.get('limit', '20')) @@ -864,7 +785,7 @@ class RecipeRoutes: """Get base models used in recipes""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() # Get all recipes from cache cache = await self.recipe_scanner.get_cached_data() @@ -895,7 +816,7 @@ class RecipeRoutes: """Process a recipe image for sharing by adding metadata to EXIF""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() recipe_id = request.match_info['recipe_id'] @@ -957,7 +878,7 @@ class RecipeRoutes: """Serve a processed recipe image for download""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() recipe_id = request.match_info['recipe_id'] @@ -1016,7 +937,7 @@ class RecipeRoutes: """Save a recipe from the LoRAs widget""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() # Get metadata using the metadata collector instead of workflow parsing raw_metadata = get_metadata() @@ -1216,7 +1137,7 @@ class RecipeRoutes: """Generate recipe syntax for LoRAs in the recipe, looking up proper file names using hash_index""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() recipe_id = request.match_info['recipe_id'] @@ -1299,7 +1220,7 @@ class RecipeRoutes: """Update recipe metadata (name and tags)""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() recipe_id = request.match_info['recipe_id'] data = await request.json() @@ -1329,7 +1250,7 @@ class RecipeRoutes: """Reconnect a deleted LoRA in a recipe to a local LoRA file""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() # Parse request data data = await request.json() @@ -1438,7 +1359,7 @@ class RecipeRoutes: """Get recipes that use a specific Lora""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() lora_hash = request.query.get('hash') @@ -1487,7 +1408,7 @@ class RecipeRoutes: """API endpoint for scanning and rebuilding the recipe cache""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() # Force refresh the recipe cache logger.info("Manually triggering recipe cache rebuild") @@ -1508,7 +1429,7 @@ class RecipeRoutes: """Find all duplicate recipes based on fingerprints""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() # Get all duplicate recipes duplicate_groups = await self.recipe_scanner.find_all_duplicate_recipes() @@ -1566,7 +1487,7 @@ class RecipeRoutes: """Delete multiple recipes by ID""" try: # Ensure services are initialized - await self.init_services() + await self.ensure_dependencies_ready() # Parse request data data = await request.json() @@ -1650,3 +1571,9 @@ class RecipeRoutes: 'success': False, 'error': str(e) }, status=500) + + # Legacy method aliases retained for compatibility with existing imports. + handle_recipes_page = render_page + get_recipes = list_recipes + get_recipe_detail = get_recipe + analyze_recipe_image = analyze_uploaded_image From 3220cfb79c1a6fb01d4a7131339c10b565d1f85a Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Mon, 22 Sep 2025 12:41:37 +0800 Subject: [PATCH 10/24] test(recipe-routes): add scaffolding baseline --- docs/architecture/recipe_routes.md | 50 ++++ py/routes/recipe_routes.py | 5 - tests/routes/test_recipe_route_scaffolding.py | 229 ++++++++++++++++++ 3 files changed, 279 insertions(+), 5 deletions(-) create mode 100644 docs/architecture/recipe_routes.md create mode 100644 tests/routes/test_recipe_route_scaffolding.py diff --git a/docs/architecture/recipe_routes.md b/docs/architecture/recipe_routes.md new file mode 100644 index 00000000..28684fad --- /dev/null +++ b/docs/architecture/recipe_routes.md @@ -0,0 +1,50 @@ +# Recipe route scaffolding + +The recipe HTTP stack is being migrated to mirror the shared model routing +architecture. The first phase extracts the registrar/controller scaffolding so +future handler sets can plug into a stable surface area. The stack now mirrors +the same separation of concerns described in +`docs/architecture/model_routes.md`: + +```mermaid +graph TD + subgraph HTTP + A[RecipeRouteRegistrar] -->|binds| B[BaseRecipeRoutes handler owner] + end + subgraph Application + B --> C[Recipe handler set] + C --> D[Async handlers] + D --> E[Services / scanners] + end +``` + +## Responsibilities + +| Layer | Module(s) | Responsibility | +| --- | --- | --- | +| Registrar | `py/routes/recipe_route_registrar.py` | Declarative list of every recipe endpoint and helper that binds them to an `aiohttp` application. | +| Base controller | `py/routes/base_recipe_routes.py` | Lazily resolves shared services, registers the server-side i18n filter exactly once, pre-warms caches on startup, and exposes a `{handler_name: coroutine}` mapping used by the registrar. | +| Handler set (upcoming) | `py/routes/handlers/recipe_handlers.py` (planned) | Will group HTTP handlers by concern (page rendering, listings, mutations, queries, sharing) and surface them to `BaseRecipeRoutes.get_handler_owner()`. | + +`RecipeRoutes` subclasses the base controller to keep compatibility with the +existing monolithic handlers. Once the handler set is extracted the subclass +will simply provide the concrete owner returned by `get_handler_owner()`. + +## High-level test baseline + +The new smoke suite in `tests/routes/test_recipe_route_scaffolding.py` +guarantees the registrar/controller contract remains intact: + +* `BaseRecipeRoutes.attach_dependencies` resolves registry services only once + and protects the i18n filter from duplicate registration. +* Startup hooks are appended exactly once so cache pre-warming and dependency + resolution run during application boot. +* `BaseRecipeRoutes.to_route_mapping()` uses the handler owner as the source of + callables, enabling the upcoming handler set without touching the registrar. +* `RecipeRouteRegistrar` binds every declarative route to the aiohttp router. +* `RecipeRoutes.setup_routes` wires the registrar and startup hooks together so + future refactors can swap in the handler set without editing callers. + +These guardrails mirror the expectations in the model route architecture and +provide confidence that future refactors can focus on handlers and use cases +without breaking HTTP wiring. diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index dcd0751a..55208552 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -56,11 +56,6 @@ class RecipeRoutes(BaseRecipeRoutes): # 设置服务端i18n语言 self.server_i18n.set_locale(user_language) - # 为模板环境添加i18n过滤器 - if not hasattr(self.template_env, '_i18n_filter_added'): - self._ensure_i18n_filter() - self.template_env._i18n_filter_added = True - # Skip initialization check and directly try to get cached data try: # Recipe scanner will initialize cache if needed diff --git a/tests/routes/test_recipe_route_scaffolding.py b/tests/routes/test_recipe_route_scaffolding.py new file mode 100644 index 00000000..1f0e723d --- /dev/null +++ b/tests/routes/test_recipe_route_scaffolding.py @@ -0,0 +1,229 @@ +"""Smoke tests for the recipe routing scaffolding. + +The cases keep the registrar/controller contract aligned with +``docs/architecture/recipe_routes.md`` so future refactors can focus on handler +logic. +""" + +from __future__ import annotations + +import asyncio +import importlib.util +import sys +import types +from collections import Counter +from pathlib import Path +from typing import Any, Awaitable, Callable, Dict + +import pytest +from aiohttp import web + + +REPO_ROOT = Path(__file__).resolve().parents[2] +PY_PACKAGE_PATH = REPO_ROOT / "py" + +spec = importlib.util.spec_from_file_location( + "py_local", + PY_PACKAGE_PATH / "__init__.py", + submodule_search_locations=[str(PY_PACKAGE_PATH)], +) +py_local = importlib.util.module_from_spec(spec) +assert spec.loader is not None +spec.loader.exec_module(py_local) +sys.modules.setdefault("py_local", py_local) + +base_routes_module = importlib.import_module("py_local.routes.base_recipe_routes") +recipe_routes_module = importlib.import_module("py_local.routes.recipe_routes") +registrar_module = importlib.import_module("py_local.routes.recipe_route_registrar") + + +@pytest.fixture(autouse=True) +def reset_service_registry(monkeypatch: pytest.MonkeyPatch): + """Ensure each test starts from a clean registry state.""" + + services_module = importlib.import_module("py_local.services.service_registry") + registry = services_module.ServiceRegistry + previous_services = dict(registry._services) + previous_locks = dict(registry._locks) + registry._services.clear() + registry._locks.clear() + try: + yield + finally: + registry._services = previous_services + registry._locks = previous_locks + + +def _make_stub_scanner(): + class _StubScanner: + def __init__(self): + self._cache = types.SimpleNamespace() + + async def _lora_get_cached_data(): # pragma: no cover - smoke hook + return None + + self._lora_scanner = types.SimpleNamespace( + get_cached_data=_lora_get_cached_data, + _hash_index=types.SimpleNamespace(_hash_to_path={}), + ) + + async def get_cached_data(self, force_refresh: bool = False): + return self._cache + + return _StubScanner() + + +def test_attach_dependencies_resolves_services_once(monkeypatch: pytest.MonkeyPatch): + base_module = base_routes_module + services_module = importlib.import_module("py_local.services.service_registry") + registry = services_module.ServiceRegistry + server_i18n = importlib.import_module("py_local.services.server_i18n").server_i18n + + scanner = _make_stub_scanner() + civitai_client = object() + filter_calls = Counter() + + async def fake_get_recipe_scanner(): + return scanner + + async def fake_get_civitai_client(): + return civitai_client + + def fake_create_filter(): + filter_calls["create_filter"] += 1 + return object() + + monkeypatch.setattr(registry, "get_recipe_scanner", fake_get_recipe_scanner) + monkeypatch.setattr(registry, "get_civitai_client", fake_get_civitai_client) + monkeypatch.setattr(server_i18n, "create_template_filter", fake_create_filter) + + async def scenario(): + routes = base_module.BaseRecipeRoutes() + + await routes.attach_dependencies() + await routes.attach_dependencies() # idempotent + + assert routes.recipe_scanner is scanner + assert routes.lora_scanner is scanner._lora_scanner + assert routes.civitai_client is civitai_client + assert routes.template_env.filters["t"] is not None + assert filter_calls["create_filter"] == 1 + + asyncio.run(scenario()) + + +def test_register_startup_hooks_appends_once(): + routes = base_routes_module.BaseRecipeRoutes() + + app = web.Application() + routes.register_startup_hooks(app) + routes.register_startup_hooks(app) + + startup_bound_to_routes = [ + callback for callback in app.on_startup if getattr(callback, "__self__", None) is routes + ] + + assert routes.attach_dependencies in startup_bound_to_routes + assert routes.prewarm_cache in startup_bound_to_routes + assert len(startup_bound_to_routes) == 2 + + +def test_to_route_mapping_uses_handler_owner(monkeypatch: pytest.MonkeyPatch): + class DummyOwner: + async def render_page(self, request): + return web.Response(text="ok") + + async def list_recipes(self, request): # pragma: no cover - invoked via mapping + return web.json_response({}) + + class DummyRoutes(base_routes_module.BaseRecipeRoutes): + def get_handler_owner(self): # noqa: D401 - simple override for test + return DummyOwner() + + monkeypatch.setattr( + base_routes_module.BaseRecipeRoutes, + "_HANDLER_NAMES", + ("render_page", "list_recipes"), + ) + + routes = DummyRoutes() + mapping = routes.to_route_mapping() + + assert set(mapping.keys()) == {"render_page", "list_recipes"} + assert asyncio.iscoroutinefunction(mapping["render_page"]) + # Cached mapping reused on subsequent calls + assert routes.to_route_mapping() is mapping + + +def test_recipe_route_registrar_binds_every_route(): + class FakeRouter: + def __init__(self): + self.calls: list[tuple[str, str, Callable[..., Awaitable[Any]]]] = [] + + def add_get(self, path, handler): + self.calls.append(("GET", path, handler)) + + def add_post(self, path, handler): + self.calls.append(("POST", path, handler)) + + def add_put(self, path, handler): + self.calls.append(("PUT", path, handler)) + + def add_delete(self, path, handler): + self.calls.append(("DELETE", path, handler)) + + class FakeApp: + def __init__(self): + self.router = FakeRouter() + + app = FakeApp() + registrar = registrar_module.RecipeRouteRegistrar(app) + + handler_mapping = { + definition.handler_name: object() + for definition in registrar_module.ROUTE_DEFINITIONS + } + + registrar.register_routes(handler_mapping) + + assert { + (method, path) + for method, path, _ in app.router.calls + } == {(d.method, d.path) for d in registrar_module.ROUTE_DEFINITIONS} + + +def test_recipe_routes_setup_routes_uses_registrar(monkeypatch: pytest.MonkeyPatch): + registered_mappings: list[Dict[str, Callable[..., Awaitable[Any]]]] = [] + + class DummyRegistrar: + def __init__(self, app): + self.app = app + + def register_routes(self, mapping): + registered_mappings.append(mapping) + + monkeypatch.setattr(recipe_routes_module, "RecipeRouteRegistrar", DummyRegistrar) + + expected_mapping = {name: object() for name in ("render_page", "list_recipes")} + + def fake_to_route_mapping(self): + return expected_mapping + + monkeypatch.setattr(base_routes_module.BaseRecipeRoutes, "to_route_mapping", fake_to_route_mapping) + monkeypatch.setattr( + base_routes_module.BaseRecipeRoutes, + "_HANDLER_NAMES", + tuple(expected_mapping.keys()), + ) + + app = web.Application() + recipe_routes_module.RecipeRoutes.setup_routes(app) + + assert registered_mappings == [expected_mapping] + recipe_callbacks = { + cb + for cb in app.on_startup + if isinstance(getattr(cb, "__self__", None), recipe_routes_module.RecipeRoutes) + } + assert {type(cb.__self__) for cb in recipe_callbacks} == {recipe_routes_module.RecipeRoutes} + assert {cb.__name__ for cb in recipe_callbacks} == {"attach_dependencies", "prewarm_cache"} From d033a374dd72fe81821e7b9d520e8a53e7c5c302 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Mon, 22 Sep 2025 12:57:37 +0800 Subject: [PATCH 11/24] refactor(routes): split recipe handlers into dedicated classes --- py/routes/base_recipe_routes.py | 73 +- py/routes/handlers/recipe_handlers.py | 1347 ++++++++++++++ py/routes/recipe_routes.py | 1561 +---------------- tests/routes/test_recipe_route_scaffolding.py | 35 +- 4 files changed, 1440 insertions(+), 1576 deletions(-) create mode 100644 py/routes/handlers/recipe_handlers.py diff --git a/py/routes/base_recipe_routes.py b/py/routes/base_recipe_routes.py index e2b726da..59d4e7ec 100644 --- a/py/routes/base_recipe_routes.py +++ b/py/routes/base_recipe_routes.py @@ -11,6 +11,15 @@ from ..config import config from ..services.server_i18n import server_i18n from ..services.service_registry import ServiceRegistry from ..services.settings_manager import settings +from .handlers.recipe_handlers import ( + RecipeAnalysisHandler, + RecipeHandlerSet, + RecipeListingHandler, + RecipeManagementHandler, + RecipePageView, + RecipeQueryHandler, + RecipeSharingHandler, +) from .recipe_route_registrar import ROUTE_DEFINITIONS logger = logging.getLogger(__name__) @@ -23,6 +32,8 @@ class BaseRecipeRoutes: definition.handler_name for definition in ROUTE_DEFINITIONS ) + template_name: str = "recipes.html" + def __init__(self) -> None: self.recipe_scanner = None self.lora_scanner = None @@ -36,6 +47,7 @@ class BaseRecipeRoutes: self._i18n_registered = False self._startup_hooks_registered = False + self._handler_set: RecipeHandlerSet | None = None self._handler_mapping: dict[str, Callable] | None = None async def attach_dependencies(self, app: web.Application | None = None) -> None: @@ -81,10 +93,9 @@ class BaseRecipeRoutes: """Return a mapping of handler name to coroutine for registrar binding.""" if self._handler_mapping is None: - owner = self.get_handler_owner() - self._handler_mapping = { - name: getattr(owner, name) for name in self._HANDLER_NAMES - } + handler_set = self._create_handler_set() + self._handler_set = handler_set + self._handler_mapping = handler_set.to_route_mapping() return self._handler_mapping # Internal helpers ------------------------------------------------- @@ -105,5 +116,57 @@ class BaseRecipeRoutes: def get_handler_owner(self): """Return the object supplying bound handler coroutines.""" - return self + if self._handler_set is None: + self._handler_set = self._create_handler_set() + return self._handler_set + + def _create_handler_set(self) -> RecipeHandlerSet: + recipe_scanner_getter = lambda: self.recipe_scanner + civitai_client_getter = lambda: self.civitai_client + + page_view = RecipePageView( + ensure_dependencies_ready=self.ensure_dependencies_ready, + settings_service=self.settings, + server_i18n=self.server_i18n, + template_env=self.template_env, + template_name=self.template_name, + recipe_scanner_getter=recipe_scanner_getter, + logger=logger, + ) + listing = RecipeListingHandler( + ensure_dependencies_ready=self.ensure_dependencies_ready, + recipe_scanner_getter=recipe_scanner_getter, + logger=logger, + ) + query = RecipeQueryHandler( + ensure_dependencies_ready=self.ensure_dependencies_ready, + recipe_scanner_getter=recipe_scanner_getter, + format_recipe_file_url=listing.format_recipe_file_url, + logger=logger, + ) + management = RecipeManagementHandler( + ensure_dependencies_ready=self.ensure_dependencies_ready, + recipe_scanner_getter=recipe_scanner_getter, + logger=logger, + ) + analysis = RecipeAnalysisHandler( + ensure_dependencies_ready=self.ensure_dependencies_ready, + recipe_scanner_getter=recipe_scanner_getter, + civitai_client_getter=civitai_client_getter, + logger=logger, + ) + sharing = RecipeSharingHandler( + ensure_dependencies_ready=self.ensure_dependencies_ready, + recipe_scanner_getter=recipe_scanner_getter, + logger=logger, + ) + + return RecipeHandlerSet( + page_view=page_view, + listing=listing, + query=query, + management=management, + analysis=analysis, + sharing=sharing, + ) diff --git a/py/routes/handlers/recipe_handlers.py b/py/routes/handlers/recipe_handlers.py new file mode 100644 index 00000000..8d8f96bf --- /dev/null +++ b/py/routes/handlers/recipe_handlers.py @@ -0,0 +1,1347 @@ +"""Dedicated handler objects for recipe-related routes.""" +from __future__ import annotations + +import asyncio +import base64 +import io +import json +import logging +import os +import tempfile +import time +from dataclasses import dataclass +from typing import Any, Awaitable, Callable, Dict, Mapping, Optional + +import numpy as np +from aiohttp import web +from PIL import Image + +from ...config import config +from ...recipes import RecipeParserFactory +from ...services.downloader import get_downloader +from ...services.server_i18n import server_i18n as default_server_i18n +from ...services.settings_manager import SettingsManager +from ...utils.constants import CARD_PREVIEW_WIDTH +from ...utils.exif_utils import ExifUtils + +# Check if running in standalone mode +standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" + +if not standalone_mode: + from ...metadata_collector import get_metadata + from ...metadata_collector.metadata_processor import MetadataProcessor + from ...metadata_collector.metadata_registry import MetadataRegistry +else: # pragma: no cover - optional dependency path + get_metadata = None # type: ignore[assignment] + MetadataProcessor = None # type: ignore[assignment] + MetadataRegistry = None # type: ignore[assignment] + +Logger = logging.Logger +EnsureDependenciesCallable = Callable[[], Awaitable[None]] +RecipeScannerGetter = Callable[[], Any] +CivitaiClientGetter = Callable[[], Any] + + +@dataclass(frozen=True) +class RecipeHandlerSet: + """Group of handlers providing recipe route implementations.""" + + page_view: "RecipePageView" + listing: "RecipeListingHandler" + query: "RecipeQueryHandler" + management: "RecipeManagementHandler" + analysis: "RecipeAnalysisHandler" + sharing: "RecipeSharingHandler" + + def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]: + """Expose handler coroutines keyed by registrar handler names.""" + + return { + "render_page": self.page_view.render_page, + "list_recipes": self.listing.list_recipes, + "get_recipe": self.listing.get_recipe, + "analyze_uploaded_image": self.analysis.analyze_uploaded_image, + "analyze_local_image": self.analysis.analyze_local_image, + "save_recipe": self.management.save_recipe, + "delete_recipe": self.management.delete_recipe, + "get_top_tags": self.query.get_top_tags, + "get_base_models": self.query.get_base_models, + "share_recipe": self.sharing.share_recipe, + "download_shared_recipe": self.sharing.download_shared_recipe, + "get_recipe_syntax": self.query.get_recipe_syntax, + "update_recipe": self.management.update_recipe, + "reconnect_lora": self.management.reconnect_lora, + "find_duplicates": self.query.find_duplicates, + "bulk_delete": self.management.bulk_delete, + "save_recipe_from_widget": self.management.save_recipe_from_widget, + "get_recipes_for_lora": self.query.get_recipes_for_lora, + "scan_recipes": self.query.scan_recipes, + } + + +class RecipePageView: + """Render the recipe shell page.""" + + def __init__( + self, + *, + ensure_dependencies_ready: EnsureDependenciesCallable, + settings_service: SettingsManager, + server_i18n=default_server_i18n, + template_env, + template_name: str, + recipe_scanner_getter: RecipeScannerGetter, + logger: Logger, + ) -> None: + self._ensure_dependencies_ready = ensure_dependencies_ready + self._settings = settings_service + self._server_i18n = server_i18n + self._template_env = template_env + self._template_name = template_name + self._recipe_scanner_getter = recipe_scanner_getter + self._logger = logger + + async def render_page(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: # pragma: no cover - defensive guard + raise RuntimeError("Recipe scanner not available") + + user_language = self._settings.get("language", "en") + self._server_i18n.set_locale(user_language) + + try: + await recipe_scanner.get_cached_data(force_refresh=False) + rendered = self._template_env.get_template(self._template_name).render( + recipes=[], + is_initializing=False, + settings=self._settings, + request=request, + t=self._server_i18n.get_translation, + ) + except Exception as cache_error: # pragma: no cover - logging path + self._logger.error("Error loading recipe cache data: %s", cache_error) + rendered = self._template_env.get_template(self._template_name).render( + is_initializing=True, + settings=self._settings, + request=request, + t=self._server_i18n.get_translation, + ) + return web.Response(text=rendered, content_type="text/html") + except Exception as exc: # pragma: no cover - logging path + self._logger.error("Error handling recipes request: %s", exc, exc_info=True) + return web.Response(text="Error loading recipes page", status=500) + + +class RecipeListingHandler: + """Provide listing and detail APIs for recipes.""" + + def __init__( + self, + *, + ensure_dependencies_ready: EnsureDependenciesCallable, + recipe_scanner_getter: RecipeScannerGetter, + logger: Logger, + ) -> None: + self._ensure_dependencies_ready = ensure_dependencies_ready + self._recipe_scanner_getter = recipe_scanner_getter + self._logger = logger + + async def list_recipes(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + page = int(request.query.get("page", "1")) + page_size = int(request.query.get("page_size", "20")) + sort_by = request.query.get("sort_by", "date") + search = request.query.get("search") + + search_options = { + "title": request.query.get("search_title", "true").lower() == "true", + "tags": request.query.get("search_tags", "true").lower() == "true", + "lora_name": request.query.get("search_lora_name", "true").lower() == "true", + "lora_model": request.query.get("search_lora_model", "true").lower() == "true", + } + + filters: Dict[str, list[str]] = {} + base_models = request.query.get("base_models") + if base_models: + filters["base_model"] = base_models.split(",") + + tags = request.query.get("tags") + if tags: + filters["tags"] = tags.split(",") + + lora_hash = request.query.get("lora_hash") + + result = await recipe_scanner.get_paginated_data( + page=page, + page_size=page_size, + sort_by=sort_by, + search=search, + filters=filters, + search_options=search_options, + lora_hash=lora_hash, + ) + + for item in result.get("items", []): + file_path = item.get("file_path") + if file_path: + item["file_url"] = self.format_recipe_file_url(file_path) + else: + item.setdefault("file_url", "/loras_static/images/no-preview.png") + item.setdefault("loras", []) + item.setdefault("base_model", "") + + return web.json_response(result) + except Exception as exc: + self._logger.error("Error retrieving recipes: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + async def get_recipe(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + recipe_id = request.match_info["recipe_id"] + recipe = await recipe_scanner.get_recipe_by_id(recipe_id) + + if not recipe: + return web.json_response({"error": "Recipe not found"}, status=404) + return web.json_response(recipe) + except Exception as exc: + self._logger.error("Error retrieving recipe details: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + def format_recipe_file_url(self, file_path: str) -> str: + try: + recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, "/") + normalized_path = file_path.replace(os.sep, "/") + if normalized_path.startswith(recipes_dir): + relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, "/") + return f"/loras_static/root1/preview/{relative_path}" + + file_name = os.path.basename(file_path) + return f"/loras_static/root1/preview/recipes/{file_name}" + except Exception as exc: # pragma: no cover - logging path + self._logger.error("Error formatting recipe file URL: %s", exc, exc_info=True) + return "/loras_static/images/no-preview.png" + + +class RecipeQueryHandler: + """Provide read-only insights on recipe data.""" + + def __init__( + self, + *, + ensure_dependencies_ready: EnsureDependenciesCallable, + recipe_scanner_getter: RecipeScannerGetter, + format_recipe_file_url: Callable[[str], str], + logger: Logger, + ) -> None: + self._ensure_dependencies_ready = ensure_dependencies_ready + self._recipe_scanner_getter = recipe_scanner_getter + self._format_recipe_file_url = format_recipe_file_url + self._logger = logger + + async def get_top_tags(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + limit = int(request.query.get("limit", "20")) + cache = await recipe_scanner.get_cached_data() + + tag_counts: Dict[str, int] = {} + for recipe in getattr(cache, "raw_data", []): + for tag in recipe.get("tags", []) or []: + tag_counts[tag] = tag_counts.get(tag, 0) + 1 + + sorted_tags = [{"tag": tag, "count": count} for tag, count in tag_counts.items()] + sorted_tags.sort(key=lambda entry: entry["count"], reverse=True) + return web.json_response({"success": True, "tags": sorted_tags[:limit]}) + except Exception as exc: + self._logger.error("Error retrieving top tags: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_base_models(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + cache = await recipe_scanner.get_cached_data() + + base_model_counts: Dict[str, int] = {} + for recipe in getattr(cache, "raw_data", []): + base_model = recipe.get("base_model") + if base_model: + base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 + + sorted_models = [{"name": model, "count": count} for model, count in base_model_counts.items()] + sorted_models.sort(key=lambda entry: entry["count"], reverse=True) + return web.json_response({"success": True, "base_models": sorted_models}) + except Exception as exc: + self._logger.error("Error retrieving base models: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_recipes_for_lora(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + lora_hash = request.query.get("hash") + if not lora_hash: + return web.json_response({"success": False, "error": "Lora hash is required"}, status=400) + + cache = await recipe_scanner.get_cached_data() + matching_recipes = [] + for recipe in getattr(cache, "raw_data", []): + for lora in recipe.get("loras", []): + if lora.get("hash", "").lower() == lora_hash.lower(): + matching_recipes.append(recipe) + break + + lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) + for recipe in matching_recipes: + for lora in recipe.get("loras", []): + hash_value = (lora.get("hash") or "").lower() + if hash_value and lora_scanner is not None: + lora["inLibrary"] = lora_scanner.has_hash(hash_value) + lora["preview_url"] = lora_scanner.get_preview_url_by_hash(hash_value) + lora["localPath"] = lora_scanner.get_path_by_hash(hash_value) + if recipe.get("file_path"): + recipe["file_url"] = self._format_recipe_file_url(recipe["file_path"]) + else: + recipe["file_url"] = "/loras_static/images/no-preview.png" + + return web.json_response({"success": True, "recipes": matching_recipes}) + except Exception as exc: + self._logger.error("Error getting recipes for Lora: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def scan_recipes(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + self._logger.info("Manually triggering recipe cache rebuild") + await recipe_scanner.get_cached_data(force_refresh=True) + return web.json_response({"success": True, "message": "Recipe cache refreshed successfully"}) + except Exception as exc: + self._logger.error("Error refreshing recipe cache: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def find_duplicates(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + duplicate_groups = await recipe_scanner.find_all_duplicate_recipes() + response_data = [] + + for fingerprint, recipe_ids in duplicate_groups.items(): + if len(recipe_ids) <= 1: + continue + + recipes = [] + for recipe_id in recipe_ids: + recipe = await recipe_scanner.get_recipe_by_id(recipe_id) + if recipe: + recipes.append( + { + "id": recipe.get("id"), + "title": recipe.get("title"), + "file_url": recipe.get("file_url") + or self._format_recipe_file_url(recipe.get("file_path", "")), + "modified": recipe.get("modified"), + "created_date": recipe.get("created_date"), + "lora_count": len(recipe.get("loras", [])), + } + ) + + if len(recipes) >= 2: + recipes.sort(key=lambda entry: entry.get("modified", 0), reverse=True) + response_data.append( + { + "fingerprint": fingerprint, + "count": len(recipes), + "recipes": recipes, + } + ) + + response_data.sort(key=lambda entry: entry["count"], reverse=True) + return web.json_response({"success": True, "duplicate_groups": response_data}) + except Exception as exc: + self._logger.error("Error finding duplicate recipes: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_recipe_syntax(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + recipe_id = request.match_info["recipe_id"] + cache = await recipe_scanner.get_cached_data() + recipe = next( + (r for r in getattr(cache, "raw_data", []) if str(r.get("id", "")) == recipe_id), + None, + ) + if not recipe: + return web.json_response({"error": "Recipe not found"}, status=404) + + loras = recipe.get("loras", []) + if not loras: + return web.json_response({"error": "No LoRAs found in this recipe"}, status=400) + + lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) + hash_index = getattr(lora_scanner, "_hash_index", None) + + lora_syntax_parts = [] + for lora in loras: + if lora.get("isDeleted", False): + continue + hash_value = (lora.get("hash") or "").lower() + if not hash_value or lora_scanner is None or not lora_scanner.has_hash(hash_value): + continue + + file_name = None + if hash_value and hash_index is not None and hasattr(hash_index, "_hash_to_path"): + file_path = hash_index._hash_to_path.get(hash_value) + if file_path: + file_name = os.path.splitext(os.path.basename(file_path))[0] + + if not file_name and lora.get("modelVersionId") and lora_scanner is not None: + all_loras = await lora_scanner.get_cached_data() + for cached_lora in getattr(all_loras, "raw_data", []): + civitai_info = cached_lora.get("civitai") + if civitai_info and civitai_info.get("id") == lora.get("modelVersionId"): + file_name = os.path.splitext(os.path.basename(cached_lora["path"]))[0] + break + + if not file_name: + file_name = lora.get("file_name", "unknown-lora") + + strength = lora.get("strength", 1.0) + lora_syntax_parts.append(f"") + + return web.json_response({"success": True, "syntax": " ".join(lora_syntax_parts)}) + except Exception as exc: + self._logger.error("Error generating recipe syntax: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + +class RecipeManagementHandler: + """Handle create/update/delete style recipe operations.""" + + def __init__( + self, + *, + ensure_dependencies_ready: EnsureDependenciesCallable, + recipe_scanner_getter: RecipeScannerGetter, + logger: Logger, + exif_utils=ExifUtils, + card_preview_width: int = CARD_PREVIEW_WIDTH, + metadata_collector: Optional[Callable[[], Any]] = get_metadata, + metadata_processor_cls: Optional[type] = MetadataProcessor, + metadata_registry_cls: Optional[type] = MetadataRegistry, + standalone_mode: bool = standalone_mode, + ) -> None: + self._ensure_dependencies_ready = ensure_dependencies_ready + self._recipe_scanner_getter = recipe_scanner_getter + self._logger = logger + self._exif_utils = exif_utils + self._card_preview_width = card_preview_width + self._metadata_collector = metadata_collector + self._metadata_processor_cls = metadata_processor_cls + self._metadata_registry_cls = metadata_registry_cls + self._standalone_mode = standalone_mode + + async def save_recipe(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + reader = await request.multipart() + + image: Optional[bytes] = None + image_base64: Optional[str] = None + name: Optional[str] = None + tags: list[str] = [] + metadata: Dict[str, Any] | None = None + + while True: + field = await reader.next() + if field is None: + break + + if field.name == "image": + image_chunks = bytearray() + while True: + chunk = await field.read_chunk() + if not chunk: + break + image_chunks.extend(chunk) + image = bytes(image_chunks) + elif field.name == "image_base64": + image_base64 = await field.text() + elif field.name == "name": + name = await field.text() + elif field.name == "tags": + tags_text = await field.text() + try: + parsed_tags = json.loads(tags_text) + tags = parsed_tags if isinstance(parsed_tags, list) else [] + except Exception: + tags = [] + elif field.name == "metadata": + metadata_text = await field.text() + try: + metadata = json.loads(metadata_text) + except Exception: + metadata = {} + + missing_fields = [] + if not name: + missing_fields.append("name") + if not metadata: + missing_fields.append("metadata") + if missing_fields: + return web.json_response( + {"error": f"Missing required fields: {', '.join(missing_fields)}"}, + status=400, + ) + + if image is None: + if image_base64: + try: + if "," in image_base64: + image_base64 = image_base64.split(",", 1)[1] + image = base64.b64decode(image_base64) + except Exception as exc: + return web.json_response({"error": f"Invalid base64 image data: {exc}"}, status=400) + else: + return web.json_response({"error": "No image data provided"}, status=400) + + recipes_dir = recipe_scanner.recipes_dir + os.makedirs(recipes_dir, exist_ok=True) + + import uuid + + recipe_id = str(uuid.uuid4()) + optimized_image, extension = self._exif_utils.optimize_image( + image_data=image, + target_width=self._card_preview_width, + format="webp", + quality=85, + preserve_metadata=True, + ) + + image_filename = f"{recipe_id}{extension}" + image_path = os.path.join(recipes_dir, image_filename) + with open(image_path, "wb") as file_obj: + file_obj.write(optimized_image) + + current_time = time.time() + loras_data = [] + for lora in metadata.get("loras", []): + loras_data.append( + { + "file_name": lora.get("file_name", "") + or ( + os.path.splitext(os.path.basename(lora.get("localPath", "")))[0] + if lora.get("localPath") + else "" + ), + "hash": (lora.get("hash") or "").lower(), + "strength": float(lora.get("weight", 1.0)), + "modelVersionId": lora.get("id", 0), + "modelName": lora.get("name", ""), + "modelVersionName": lora.get("version", ""), + "isDeleted": lora.get("isDeleted", False), + "exclude": lora.get("exclude", False), + } + ) + + gen_params = metadata.get("gen_params", {}) + if not gen_params and "raw_metadata" in metadata: + raw_metadata = metadata.get("raw_metadata", {}) + gen_params = { + "prompt": raw_metadata.get("prompt", ""), + "negative_prompt": raw_metadata.get("negative_prompt", ""), + "checkpoint": raw_metadata.get("checkpoint", {}), + "steps": raw_metadata.get("steps", ""), + "sampler": raw_metadata.get("sampler", ""), + "cfg_scale": raw_metadata.get("cfg_scale", ""), + "seed": raw_metadata.get("seed", ""), + "size": raw_metadata.get("size", ""), + "clip_skip": raw_metadata.get("clip_skip", ""), + } + + from ...utils.utils import calculate_recipe_fingerprint + + fingerprint = calculate_recipe_fingerprint(loras_data) + + recipe_data = { + "id": recipe_id, + "file_path": image_path, + "title": name, + "modified": current_time, + "created_date": current_time, + "base_model": metadata.get("base_model", ""), + "loras": loras_data, + "gen_params": gen_params, + "fingerprint": fingerprint, + } + + if tags: + recipe_data["tags"] = tags + + if metadata.get("source_path"): + recipe_data["source_path"] = metadata.get("source_path") + + json_filename = f"{recipe_id}.recipe.json" + json_path = os.path.join(recipes_dir, json_filename) + with open(json_path, "w", encoding="utf-8") as file_obj: + json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) + + self._exif_utils.append_recipe_metadata(image_path, recipe_data) + + matching_recipes = [] + if fingerprint: + matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(fingerprint) + if recipe_id in matching_recipes: + matching_recipes.remove(recipe_id) + + cache = getattr(recipe_scanner, "_cache", None) + if cache is not None: + cache.raw_data.append(recipe_data) + asyncio.create_task(cache.resort()) + self._logger.info("Added recipe %s to cache", recipe_id) + + return web.json_response( + { + "success": True, + "recipe_id": recipe_id, + "image_path": image_path, + "json_path": json_path, + "matching_recipes": matching_recipes, + } + ) + except Exception as exc: + self._logger.error("Error saving recipe: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + async def delete_recipe(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + recipe_id = request.match_info["recipe_id"] + recipes_dir = recipe_scanner.recipes_dir + if not recipes_dir or not os.path.exists(recipes_dir): + return web.json_response({"error": "Recipes directory not found"}, status=404) + + recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json") + if not os.path.exists(recipe_json_path): + return web.json_response({"error": "Recipe not found"}, status=404) + + with open(recipe_json_path, "r", encoding="utf-8") as file_obj: + recipe_data = json.load(file_obj) + + image_path = recipe_data.get("file_path") + os.remove(recipe_json_path) + self._logger.info("Deleted recipe JSON file: %s", recipe_json_path) + + if image_path and os.path.exists(image_path): + os.remove(image_path) + self._logger.info("Deleted recipe image: %s", image_path) + + cache = getattr(recipe_scanner, "_cache", None) + if cache is not None: + cache.raw_data = [ + item for item in cache.raw_data if str(item.get("id", "")) != recipe_id + ] + asyncio.create_task(cache.resort()) + self._logger.info("Removed recipe %s from cache", recipe_id) + + return web.json_response({"success": True, "message": "Recipe deleted successfully"}) + except Exception as exc: + self._logger.error("Error deleting recipe: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + async def update_recipe(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + recipe_id = request.match_info["recipe_id"] + data = await request.json() + + if not any( + key in data for key in ("title", "tags", "source_path", "preview_nsfw_level") + ): + return web.json_response( + { + "error": ( + "At least one field to update must be provided (title or tags or " + "source_path or preview_nsfw_level)" + ) + }, + status=400, + ) + + success = await recipe_scanner.update_recipe_metadata(recipe_id, data) + if not success: + return web.json_response({"error": "Recipe not found or update failed"}, status=404) + + return web.json_response({"success": True, "recipe_id": recipe_id, "updates": data}) + except Exception as exc: + self._logger.error("Error updating recipe: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + async def reconnect_lora(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + data = await request.json() + required_fields = ["recipe_id", "lora_index", "target_name"] + for field in required_fields: + if field not in data: + return web.json_response({"error": f"Missing required field: {field}"}, status=400) + + recipe_id = data["recipe_id"] + lora_index = int(data["lora_index"]) + target_name = data["target_name"] + + recipe_path = os.path.join(recipe_scanner.recipes_dir, f"{recipe_id}.recipe.json") + if not os.path.exists(recipe_path): + return web.json_response({"error": "Recipe not found"}, status=404) + + lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) + target_lora = None if lora_scanner is None else await lora_scanner.get_model_info_by_name(target_name) + if not target_lora: + return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404) + + with open(recipe_path, "r", encoding="utf-8") as file_obj: + recipe_data = json.load(file_obj) + + loras = recipe_data.get("loras", []) + lora = loras[lora_index] if lora_index < len(loras) else None + if lora is None: + return web.json_response({"error": "LoRA index out of range in recipe"}, status=404) + + lora["isDeleted"] = False + lora["exclude"] = False + lora["file_name"] = target_name + if "sha256" in target_lora: + lora["hash"] = target_lora["sha256"].lower() + if target_lora.get("civitai"): + lora["modelName"] = target_lora["civitai"]["model"]["name"] + lora["modelVersionName"] = target_lora["civitai"]["name"] + lora["modelVersionId"] = target_lora["civitai"]["id"] + + from ...utils.utils import calculate_recipe_fingerprint + + recipe_data["fingerprint"] = calculate_recipe_fingerprint(recipe_data.get("loras", [])) + + with open(recipe_path, "w", encoding="utf-8") as file_obj: + json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) + + updated_lora = dict(lora) + updated_lora["inLibrary"] = True + updated_lora["preview_url"] = config.get_preview_static_url(target_lora["preview_url"]) + updated_lora["localPath"] = target_lora["file_path"] + + cache = getattr(recipe_scanner, "_cache", None) + if cache is not None: + for cache_item in cache.raw_data: + if cache_item.get("id") == recipe_id: + cache_item["loras"] = recipe_data["loras"] + cache_item["fingerprint"] = recipe_data["fingerprint"] + asyncio.create_task(cache.resort()) + break + + image_path = recipe_data.get("file_path") + if image_path and os.path.exists(image_path): + self._exif_utils.append_recipe_metadata(image_path, recipe_data) + + matching_recipes = [] + if "fingerprint" in recipe_data: + matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(recipe_data["fingerprint"]) + if recipe_id in matching_recipes: + matching_recipes.remove(recipe_id) + + return web.json_response( + { + "success": True, + "recipe_id": recipe_id, + "updated_lora": updated_lora, + "matching_recipes": matching_recipes, + } + ) + except Exception as exc: + self._logger.error("Error reconnecting LoRA: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + async def bulk_delete(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + data = await request.json() + recipe_ids = data.get("recipe_ids", []) + if not recipe_ids: + return web.json_response( + {"success": False, "error": "No recipe IDs provided"}, + status=400, + ) + + recipes_dir = recipe_scanner.recipes_dir + if not recipes_dir or not os.path.exists(recipes_dir): + return web.json_response( + {"success": False, "error": "Recipes directory not found"}, + status=404, + ) + + deleted_recipes: list[str] = [] + failed_recipes: list[Dict[str, Any]] = [] + + for recipe_id in recipe_ids: + recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json") + if not os.path.exists(recipe_json_path): + failed_recipes.append({"id": recipe_id, "reason": "Recipe not found"}) + continue + + try: + with open(recipe_json_path, "r", encoding="utf-8") as file_obj: + recipe_data = json.load(file_obj) + image_path = recipe_data.get("file_path") + os.remove(recipe_json_path) + if image_path and os.path.exists(image_path): + os.remove(image_path) + deleted_recipes.append(recipe_id) + except Exception as exc: + failed_recipes.append({"id": recipe_id, "reason": str(exc)}) + + cache = getattr(recipe_scanner, "_cache", None) + if deleted_recipes and cache is not None: + cache.raw_data = [item for item in cache.raw_data if item.get("id") not in deleted_recipes] + asyncio.create_task(cache.resort()) + self._logger.info("Removed %s recipes from cache", len(deleted_recipes)) + + return web.json_response( + { + "success": True, + "deleted": deleted_recipes, + "failed": failed_recipes, + "total_deleted": len(deleted_recipes), + "total_failed": len(failed_recipes), + } + ) + except Exception as exc: + self._logger.error("Error performing bulk delete: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def save_recipe_from_widget(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + if self._metadata_collector is None or self._metadata_processor_cls is None: + return web.json_response({"error": "Metadata collection not available"}, status=400) + + raw_metadata = self._metadata_collector() + metadata_dict = self._metadata_processor_cls.to_dict(raw_metadata) + if not metadata_dict: + return web.json_response({"error": "No generation metadata found"}, status=400) + + if not self._standalone_mode and self._metadata_registry_cls is not None: + metadata_registry = self._metadata_registry_cls() + latest_image = metadata_registry.get_first_decoded_image() + else: + latest_image = None + + if latest_image is None: + return web.json_response( + {"error": "No recent images found to use for recipe. Try generating an image first."}, + status=400, + ) + + self._logger.debug("Image type: %s", type(latest_image)) + + try: + if isinstance(latest_image, tuple): + tensor_image = latest_image[0] if latest_image else None + if tensor_image is None: + return web.json_response({"error": "Empty image tuple received"}, status=400) + else: + tensor_image = latest_image + + if hasattr(tensor_image, "shape"): + shape_info = tensor_image.shape + self._logger.debug("Tensor shape: %s, dtype: %s", shape_info, tensor_image.dtype) + + import torch # type: ignore[import-not-found] + + if isinstance(tensor_image, torch.Tensor): + image_np = tensor_image.cpu().numpy() + else: + image_np = np.array(tensor_image) + + while len(image_np.shape) > 3: + image_np = image_np[0] + + if image_np.dtype in (np.float32, np.float64) and image_np.max() <= 1.0: + image_np = (image_np * 255).astype(np.uint8) + + if len(image_np.shape) == 3 and image_np.shape[2] == 3: + pil_image = Image.fromarray(image_np) + img_byte_arr = io.BytesIO() + pil_image.save(img_byte_arr, format="PNG") + image_bytes = img_byte_arr.getvalue() + else: + return web.json_response( + {"error": f"Cannot handle this data shape: {image_np.shape}, {image_np.dtype}"}, + status=400, + ) + except Exception as exc: + self._logger.error("Error processing image data: %s", exc, exc_info=True) + return web.json_response({"error": f"Error processing image: {exc}"}, status=400) + + lora_stack = metadata_dict.get("loras", "") + import re + + lora_matches = re.findall(r"]+)>", lora_stack) + if not lora_matches: + return web.json_response({"error": "No LoRAs found in the generation metadata"}, status=400) + + loras_for_name = lora_matches[:3] + recipe_name_parts = [] + for name, strength in loras_for_name: + recipe_name_parts.append(f"{name.strip()}-{float(strength):.2f}") + recipe_name = "_".join(recipe_name_parts) + + recipe_name = recipe_name or "recipe" + + recipes_dir = recipe_scanner.recipes_dir + os.makedirs(recipes_dir, exist_ok=True) + + import uuid + + recipe_id = str(uuid.uuid4()) + image_filename = f"{recipe_id}.png" + image_path = os.path.join(recipes_dir, image_filename) + with open(image_path, "wb") as file_obj: + file_obj.write(image_bytes) + + loras_data = [] + lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) + base_model_counts: Dict[str, int] = {} + + for name, strength in lora_matches: + lora_info = None + if lora_scanner is not None: + lora_info = await lora_scanner.get_model_info_by_name(name) + lora_data = { + "file_name": name, + "strength": float(strength), + "hash": (lora_info.get("sha256") or "").lower() if lora_info else "", + "modelVersionId": lora_info.get("civitai", {}).get("id") if lora_info else 0, + "modelName": lora_info.get("civitai", {}).get("model", {}).get("name") if lora_info else "", + "modelVersionName": lora_info.get("civitai", {}).get("name") if lora_info else "", + "isDeleted": False, + "exclude": False, + } + loras_data.append(lora_data) + + if lora_info and "base_model" in lora_info: + base_model = lora_info["base_model"] + base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 + + most_common_base_model = "" + if base_model_counts: + most_common_base_model = max(base_model_counts.items(), key=lambda item: item[1])[0] + + recipe_data = { + "id": recipe_id, + "file_path": image_path, + "title": recipe_name, + "modified": time.time(), + "created_date": time.time(), + "base_model": most_common_base_model, + "loras": loras_data, + "checkpoint": metadata_dict.get("checkpoint", ""), + "gen_params": { + key: value + for key, value in metadata_dict.items() + if key not in ["checkpoint", "loras"] + }, + "loras_stack": lora_stack, + } + + json_filename = f"{recipe_id}.recipe.json" + json_path = os.path.join(recipes_dir, json_filename) + with open(json_path, "w", encoding="utf-8") as file_obj: + json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) + + self._exif_utils.append_recipe_metadata(image_path, recipe_data) + + cache = getattr(recipe_scanner, "_cache", None) + if cache is not None: + cache.raw_data.append(recipe_data) + asyncio.create_task(cache.resort()) + self._logger.info("Added recipe %s to cache", recipe_id) + + return web.json_response( + { + "success": True, + "recipe_id": recipe_id, + "image_path": image_path, + "json_path": json_path, + "recipe_name": recipe_name, + } + ) + except Exception as exc: + self._logger.error("Error saving recipe from widget: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + +class RecipeAnalysisHandler: + """Analyze images to extract recipe metadata.""" + + def __init__( + self, + *, + ensure_dependencies_ready: EnsureDependenciesCallable, + recipe_scanner_getter: RecipeScannerGetter, + civitai_client_getter: CivitaiClientGetter, + logger: Logger, + exif_utils=ExifUtils, + recipe_parser_factory=RecipeParserFactory, + downloader_factory=get_downloader, + ) -> None: + self._ensure_dependencies_ready = ensure_dependencies_ready + self._recipe_scanner_getter = recipe_scanner_getter + self._civitai_client_getter = civitai_client_getter + self._logger = logger + self._exif_utils = exif_utils + self._recipe_parser_factory = recipe_parser_factory + self._downloader_factory = downloader_factory + + async def analyze_uploaded_image(self, request: web.Request) -> web.Response: + temp_path: Optional[str] = None + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + civitai_client = self._civitai_client_getter() + if recipe_scanner is None or civitai_client is None: + raise RuntimeError("Required services unavailable") + + content_type = request.headers.get("Content-Type", "") + is_url_mode = False + metadata: Optional[Dict[str, Any]] = None + + if "multipart/form-data" in content_type: + reader = await request.multipart() + field = await reader.next() + if field is None or field.name != "image": + return web.json_response({"error": "No image field found", "loras": []}, status=400) + + with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file: + while True: + chunk = await field.read_chunk() + if not chunk: + break + temp_file.write(chunk) + temp_path = temp_file.name + elif "application/json" in content_type: + data = await request.json() + url = data.get("url") + is_url_mode = True + if not url: + return web.json_response({"error": "No URL provided", "loras": []}, status=400) + + import re + + civitai_image_match = re.match(r"https://civitai\.com/images/(\d+)", url) + if civitai_image_match: + image_id = civitai_image_match.group(1) + image_info = await civitai_client.get_image_info(image_id) + if not image_info: + return web.json_response( + {"error": "Failed to fetch image information from Civitai", "loras": []}, + status=400, + ) + image_url = image_info.get("url") + if not image_url: + return web.json_response( + {"error": "No image URL found in Civitai response", "loras": []}, + status=400, + ) + + downloader = await self._downloader_factory() + with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file: + temp_path = temp_file.name + + success, result = await downloader.download_file( + image_url, + temp_path, + use_auth=False, + ) + if not success: + return web.json_response( + {"error": f"Failed to download image from URL: {result}", "loras": []}, + status=400, + ) + metadata = image_info.get("meta") if "meta" in image_info else None + else: + return web.json_response({"error": "Unsupported content type", "loras": []}, status=400) + + if metadata is None and temp_path: + metadata = self._exif_utils.extract_image_metadata(temp_path) + + if not metadata: + response: Dict[str, Any] = {"error": "No metadata found in this image", "loras": []} + if is_url_mode and temp_path: + with open(temp_path, "rb") as image_file: + response["image_base64"] = base64.b64encode(image_file.read()).decode("utf-8") + return web.json_response(response, status=200) + + parser = self._recipe_parser_factory.create_parser(metadata) + if parser is None: + response = {"error": "No parser found for this image", "loras": []} + if is_url_mode and temp_path: + with open(temp_path, "rb") as image_file: + response["image_base64"] = base64.b64encode(image_file.read()).decode("utf-8") + return web.json_response(response, status=200) + + result = await parser.parse_metadata(metadata, recipe_scanner=recipe_scanner) + + if is_url_mode and temp_path: + with open(temp_path, "rb") as image_file: + result["image_base64"] = base64.b64encode(image_file.read()).decode("utf-8") + + if "error" in result and not result.get("loras"): + return web.json_response(result, status=200) + + from ...utils.utils import calculate_recipe_fingerprint + + fingerprint = calculate_recipe_fingerprint(result.get("loras", [])) + result["fingerprint"] = fingerprint + + matching_recipes = [] + if fingerprint: + matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(fingerprint) + + result["matching_recipes"] = matching_recipes + return web.json_response(result) + except Exception as exc: + self._logger.error("Error analyzing recipe image: %s", exc, exc_info=True) + return web.json_response({"error": str(exc), "loras": []}, status=500) + finally: + if temp_path and os.path.exists(temp_path): + try: + os.unlink(temp_path) + except Exception as cleanup_exc: # pragma: no cover - logging path + self._logger.error("Error deleting temporary file: %s", cleanup_exc) + + async def analyze_local_image(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + data = await request.json() + file_path = data.get("path") + if not file_path: + return web.json_response({"error": "No file path provided", "loras": []}, status=400) + + file_path = os.path.normpath(file_path.strip('"').strip("'")) + if not os.path.isfile(file_path): + return web.json_response({"error": "File not found", "loras": []}, status=404) + + metadata = self._exif_utils.extract_image_metadata(file_path) + if not metadata: + with open(file_path, "rb") as image_file: + image_base64 = base64.b64encode(image_file.read()).decode("utf-8") + return web.json_response( + {"error": "No metadata found in this image", "loras": [], "image_base64": image_base64}, + status=200, + ) + + parser = self._recipe_parser_factory.create_parser(metadata) + if parser is None: + with open(file_path, "rb") as image_file: + image_base64 = base64.b64encode(image_file.read()).decode("utf-8") + return web.json_response( + {"error": "No parser found for this image", "loras": [], "image_base64": image_base64}, + status=200, + ) + + result = await parser.parse_metadata(metadata, recipe_scanner=recipe_scanner) + with open(file_path, "rb") as image_file: + result["image_base64"] = base64.b64encode(image_file.read()).decode("utf-8") + + if "error" in result and not result.get("loras"): + return web.json_response(result, status=200) + + from ...utils.utils import calculate_recipe_fingerprint + + fingerprint = calculate_recipe_fingerprint(result.get("loras", [])) + result["fingerprint"] = fingerprint + + matching_recipes = [] + if fingerprint: + matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(fingerprint) + result["matching_recipes"] = matching_recipes + + return web.json_response(result) + except Exception as exc: + self._logger.error("Error analyzing local image: %s", exc, exc_info=True) + return web.json_response({"error": str(exc), "loras": []}, status=500) + + +class RecipeSharingHandler: + """Serve endpoints related to recipe sharing.""" + + def __init__( + self, + *, + ensure_dependencies_ready: EnsureDependenciesCallable, + recipe_scanner_getter: RecipeScannerGetter, + logger: Logger, + ) -> None: + self._ensure_dependencies_ready = ensure_dependencies_ready + self._recipe_scanner_getter = recipe_scanner_getter + self._logger = logger + self._shared_recipes: Dict[str, Dict[str, Any]] = {} + + async def share_recipe(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + recipe_id = request.match_info["recipe_id"] + cache = await recipe_scanner.get_cached_data() + recipe = next( + (r for r in getattr(cache, "raw_data", []) if str(r.get("id", "")) == recipe_id), + None, + ) + if not recipe: + return web.json_response({"error": "Recipe not found"}, status=404) + + image_path = recipe.get("file_path") + if not image_path or not os.path.exists(image_path): + return web.json_response({"error": "Recipe image not found"}, status=404) + + import shutil + + ext = os.path.splitext(image_path)[1] + with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as temp_file: + temp_path = temp_file.name + shutil.copy2(image_path, temp_path) + processed_path = temp_path + + timestamp = int(time.time()) + url_path = f"/api/recipe/{recipe_id}/share/download?t={timestamp}" + self._shared_recipes[recipe_id] = { + "path": processed_path, + "timestamp": timestamp, + "expires": time.time() + 300, + } + self._cleanup_shared_recipes() + + filename = f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}{ext}" + return web.json_response({"success": True, "download_url": url_path, "filename": filename}) + except Exception as exc: + self._logger.error("Error sharing recipe: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + async def download_shared_recipe(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + recipe_id = request.match_info["recipe_id"] + shared_info = self._shared_recipes.get(recipe_id) + if not shared_info: + return web.json_response({"error": "Shared recipe not found or expired"}, status=404) + + file_path = shared_info["path"] + if not os.path.exists(file_path): + return web.json_response({"error": "Shared recipe file not found"}, status=404) + + cache = await recipe_scanner.get_cached_data() + recipe = next( + (r for r in getattr(cache, "raw_data", []) if str(r.get("id", "")) == recipe_id), + None, + ) + filename_base = ( + f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}" + if recipe + else recipe_id + ) + ext = os.path.splitext(file_path)[1] + download_filename = f"{filename_base}{ext}" + + return web.FileResponse( + file_path, + headers={"Content-Disposition": f'attachment; filename="{download_filename}"'}, + ) + except Exception as exc: + self._logger.error("Error downloading shared recipe: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + def _cleanup_shared_recipes(self) -> None: + current_time = time.time() + expired_ids = [ + recipe_id + for recipe_id, info in self._shared_recipes.items() + if current_time > info.get("expires", 0) + ] + + for recipe_id in expired_ids: + try: + file_path = self._shared_recipes[recipe_id]["path"] + if os.path.exists(file_path): + os.unlink(file_path) + except Exception as exc: # pragma: no cover - logging path + self._logger.error("Error cleaning up shared recipe %s: %s", recipe_id, exc) + finally: + self._shared_recipes.pop(recipe_id, None) diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 55208552..2c233d01 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -1,41 +1,16 @@ -import os -import time -import base64 -import numpy as np -from PIL import Image -import io -import logging +"""Concrete recipe route configuration.""" + from aiohttp import web -from typing import Dict -import tempfile -import json -import asyncio -import sys from .base_recipe_routes import BaseRecipeRoutes from .recipe_route_registrar import RecipeRouteRegistrar -from ..utils.exif_utils import ExifUtils -from ..recipes import RecipeParserFactory -from ..utils.constants import CARD_PREVIEW_WIDTH -from ..config import config -from ..services.downloader import get_downloader - -# Check if running in standalone mode -standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" - -# Only import MetadataRegistry in non-standalone mode -if not standalone_mode: - # Import metadata_collector functions and classes conditionally - from ..metadata_collector import get_metadata # Add MetadataCollector import - from ..metadata_collector.metadata_processor import MetadataProcessor # Add MetadataProcessor import - from ..metadata_collector.metadata_registry import MetadataRegistry - -logger = logging.getLogger(__name__) class RecipeRoutes(BaseRecipeRoutes): """API route handlers for Recipe management.""" + template_name = "recipes.html" + @classmethod def setup_routes(cls, app: web.Application): """Register API routes using the declarative registrar.""" @@ -44,1531 +19,3 @@ class RecipeRoutes(BaseRecipeRoutes): registrar = RecipeRouteRegistrar(app) registrar.register_routes(routes.to_route_mapping()) routes.register_startup_hooks(app) - - async def render_page(self, request: web.Request) -> web.Response: - """Handle GET /loras/recipes request.""" - try: - await self.ensure_dependencies_ready() - - # 获取用户语言设置 - user_language = self.settings.get('language', 'en') - - # 设置服务端i18n语言 - self.server_i18n.set_locale(user_language) - - # Skip initialization check and directly try to get cached data - try: - # Recipe scanner will initialize cache if needed - await self.recipe_scanner.get_cached_data(force_refresh=False) - template = self.template_env.get_template('recipes.html') - rendered = template.render( - recipes=[], # Frontend will load recipes via API - is_initializing=False, - settings=self.settings, - request=request, - # 添加服务端翻译函数 - t=self.server_i18n.get_translation, - ) - except Exception as cache_error: - logger.error(f"Error loading recipe cache data: {cache_error}") - # Still keep error handling - show initializing page on error - template = self.template_env.get_template('recipes.html') - rendered = template.render( - is_initializing=True, - settings=self.settings, - request=request, - # 添加服务端翻译函数 - t=self.server_i18n.get_translation, - ) - logger.info("Recipe cache error, returning initialization page") - - return web.Response( - text=rendered, - content_type='text/html' - ) - - except Exception as e: - logger.error(f"Error handling recipes request: {e}", exc_info=True) - return web.Response( - text="Error loading recipes page", - status=500 - ) - - async def list_recipes(self, request: web.Request) -> web.Response: - """API endpoint for getting paginated recipes.""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - # Get query parameters with defaults - page = int(request.query.get('page', '1')) - page_size = int(request.query.get('page_size', '20')) - sort_by = request.query.get('sort_by', 'date') - search = request.query.get('search', None) - - # Get search options (renamed for better clarity) - search_title = request.query.get('search_title', 'true').lower() == 'true' - search_tags = request.query.get('search_tags', 'true').lower() == 'true' - search_lora_name = request.query.get('search_lora_name', 'true').lower() == 'true' - search_lora_model = request.query.get('search_lora_model', 'true').lower() == 'true' - - # Get filter parameters - base_models = request.query.get('base_models', None) - tags = request.query.get('tags', None) - - # New parameter: get LoRA hash filter - lora_hash = request.query.get('lora_hash', None) - - # Parse filter parameters - filters = {} - if base_models: - filters['base_model'] = base_models.split(',') - if tags: - filters['tags'] = tags.split(',') - - # Add search options to filters - search_options = { - 'title': search_title, - 'tags': search_tags, - 'lora_name': search_lora_name, - 'lora_model': search_lora_model - } - - # Get paginated data with the new lora_hash parameter - result = await self.recipe_scanner.get_paginated_data( - page=page, - page_size=page_size, - sort_by=sort_by, - search=search, - filters=filters, - search_options=search_options, - lora_hash=lora_hash - ) - - # Format the response data with static URLs for file paths - for item in result['items']: - # Always ensure file_url is set - if 'file_path' in item: - item['file_url'] = self._format_recipe_file_url(item['file_path']) - else: - item['file_url'] = '/loras_static/images/no-preview.png' - - # 确保 loras 数组存在 - if 'loras' not in item: - item['loras'] = [] - - # 确保有 base_model 字段 - if 'base_model' not in item: - item['base_model'] = "" - - return web.json_response(result) - except Exception as e: - logger.error(f"Error retrieving recipes: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def get_recipe(self, request: web.Request) -> web.Response: - """Get detailed information about a specific recipe.""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - recipe_id = request.match_info['recipe_id'] - - # Use the new get_recipe_by_id method from recipe_scanner - recipe = await self.recipe_scanner.get_recipe_by_id(recipe_id) - - if not recipe: - return web.json_response({"error": "Recipe not found"}, status=404) - - return web.json_response(recipe) - except Exception as e: - logger.error(f"Error retrieving recipe details: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - def _format_recipe_file_url(self, file_path: str) -> str: - """Format file path for recipe image as a URL""" - try: - # Return the file URL directly for the first lora root's preview - recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/') - if file_path.replace(os.sep, '/').startswith(recipes_dir): - relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/') - return f"/loras_static/root1/preview/{relative_path}" - - # If not in recipes dir, try to create a valid URL from the file path - file_name = os.path.basename(file_path) - return f"/loras_static/root1/preview/recipes/{file_name}" - except Exception as e: - logger.error(f"Error formatting recipe file URL: {e}", exc_info=True) - return '/loras_static/images/no-preview.png' # Return default image on error - - def _format_recipe_data(self, recipe: Dict) -> Dict: - """Format recipe data for API response""" - formatted = {**recipe} # Copy all fields - - # Format file paths to URLs - if 'file_path' in formatted: - formatted['file_url'] = self._format_recipe_file_url(formatted['file_path']) - - # Format dates for display - for date_field in ['created_date', 'modified']: - if date_field in formatted: - formatted[f"{date_field}_formatted"] = self._format_timestamp(formatted[date_field]) - - return formatted - - def _format_timestamp(self, timestamp: float) -> str: - """Format timestamp for display""" - from datetime import datetime - return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S') - - async def analyze_uploaded_image(self, request: web.Request) -> web.Response: - """Analyze an uploaded image or URL for recipe metadata.""" - temp_path = None - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - # Check if request contains multipart data (image) or JSON data (url) - content_type = request.headers.get('Content-Type', '') - - is_url_mode = False - metadata = None # Initialize metadata variable - - if 'multipart/form-data' in content_type: - # Handle image upload - reader = await request.multipart() - field = await reader.next() - - if field.name != 'image': - return web.json_response({ - "error": "No image field found", - "loras": [] - }, status=400) - - # Create a temporary file to store the uploaded image - with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file: - while True: - chunk = await field.read_chunk() - if not chunk: - break - temp_file.write(chunk) - temp_path = temp_file.name - - elif 'application/json' in content_type: - # Handle URL input - data = await request.json() - url = data.get('url') - is_url_mode = True - - if not url: - return web.json_response({ - "error": "No URL provided", - "loras": [] - }, status=400) - - # Check if this is a Civitai image URL - import re - civitai_image_match = re.match(r'https://civitai\.com/images/(\d+)', url) - - if civitai_image_match: - # Extract image ID and fetch image info using get_image_info - image_id = civitai_image_match.group(1) - image_info = await self.civitai_client.get_image_info(image_id) - - if not image_info: - return web.json_response({ - "error": "Failed to fetch image information from Civitai", - "loras": [] - }, status=400) - - # Get image URL from response - image_url = image_info.get('url') - if not image_url: - return web.json_response({ - "error": "No image URL found in Civitai response", - "loras": [] - }, status=400) - - # Download image using unified downloader - downloader = await get_downloader() - # Create a temporary file to save the downloaded image - with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file: - temp_path = temp_file.name - - success, result = await downloader.download_file( - image_url, - temp_path, - use_auth=False # Image downloads typically don't need auth - ) - - if not success: - return web.json_response({ - "error": f"Failed to download image from URL: {result}", - "loras": [] - }, status=400) - - # Use meta field from image_info as metadata - if 'meta' in image_info: - metadata = image_info['meta'] - - # If metadata wasn't obtained from Civitai API, extract it from the image - if metadata is None: - # Extract metadata from the image using ExifUtils - metadata = ExifUtils.extract_image_metadata(temp_path) - - # If no metadata found, return a more specific error - if not metadata: - result = { - "error": "No metadata found in this image", - "loras": [] # Return empty loras array to prevent client-side errors - } - - # For URL mode, include the image data as base64 - if is_url_mode and temp_path: - with open(temp_path, "rb") as image_file: - result["image_base64"] = base64.b64encode(image_file.read()).decode('utf-8') - - return web.json_response(result, status=200) - - # Use the parser factory to get the appropriate parser - parser = RecipeParserFactory.create_parser(metadata) - - if parser is None: - result = { - "error": "No parser found for this image", - "loras": [] # Return empty loras array to prevent client-side errors - } - - # For URL mode, include the image data as base64 - if is_url_mode and temp_path: - with open(temp_path, "rb") as image_file: - result["image_base64"] = base64.b64encode(image_file.read()).decode('utf-8') - - return web.json_response(result, status=200) - - # Parse the metadata - result = await parser.parse_metadata( - metadata, - recipe_scanner=self.recipe_scanner - ) - - # For URL mode, include the image data as base64 - if is_url_mode and temp_path: - with open(temp_path, "rb") as image_file: - result["image_base64"] = base64.b64encode(image_file.read()).decode('utf-8') - - # Check for errors - if "error" in result and not result.get("loras"): - return web.json_response(result, status=200) - - # Calculate fingerprint from parsed loras - from ..utils.utils import calculate_recipe_fingerprint - fingerprint = calculate_recipe_fingerprint(result.get("loras", [])) - - # Add fingerprint to result - result["fingerprint"] = fingerprint - - # Find matching recipes with the same fingerprint - matching_recipes = [] - if fingerprint: - matching_recipes = await self.recipe_scanner.find_recipes_by_fingerprint(fingerprint) - - # Add matching recipes to result - result["matching_recipes"] = matching_recipes - - return web.json_response(result) - - except Exception as e: - logger.error(f"Error analyzing recipe image: {e}", exc_info=True) - return web.json_response({ - "error": str(e), - "loras": [] # Return empty loras array to prevent client-side errors - }, status=500) - finally: - # Clean up the temporary file in the finally block - if temp_path and os.path.exists(temp_path): - try: - os.unlink(temp_path) - except Exception as e: - logger.error(f"Error deleting temporary file: {e}") - - async def analyze_local_image(self, request: web.Request) -> web.Response: - """Analyze a local image file for recipe metadata""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - # Get JSON data from request - data = await request.json() - file_path = data.get('path') - - if not file_path: - return web.json_response({ - 'error': 'No file path provided', - 'loras': [] - }, status=400) - - # Normalize file path for cross-platform compatibility - file_path = os.path.normpath(file_path.strip('"').strip("'")) - - # Validate that the file exists - if not os.path.isfile(file_path): - return web.json_response({ - 'error': 'File not found', - 'loras': [] - }, status=404) - - # Extract metadata from the image using ExifUtils - metadata = ExifUtils.extract_image_metadata(file_path) - - # If no metadata found, return error - if not metadata: - # Get base64 image data - with open(file_path, "rb") as image_file: - image_base64 = base64.b64encode(image_file.read()).decode('utf-8') - - return web.json_response({ - "error": "No metadata found in this image", - "loras": [], # Return empty loras array to prevent client-side errors - "image_base64": image_base64 - }, status=200) - - # Use the parser factory to get the appropriate parser - parser = RecipeParserFactory.create_parser(metadata) - - if parser is None: - # Get base64 image data - with open(file_path, "rb") as image_file: - image_base64 = base64.b64encode(image_file.read()).decode('utf-8') - - return web.json_response({ - "error": "No parser found for this image", - "loras": [], # Return empty loras array to prevent client-side errors - "image_base64": image_base64 - }, status=200) - - # Parse the metadata - result = await parser.parse_metadata( - metadata, - recipe_scanner=self.recipe_scanner - ) - - # Add base64 image data to result - with open(file_path, "rb") as image_file: - result["image_base64"] = base64.b64encode(image_file.read()).decode('utf-8') - - # Check for errors - if "error" in result and not result.get("loras"): - return web.json_response(result, status=200) - - # Calculate fingerprint from parsed loras - from ..utils.utils import calculate_recipe_fingerprint - fingerprint = calculate_recipe_fingerprint(result.get("loras", [])) - - # Add fingerprint to result - result["fingerprint"] = fingerprint - - # Find matching recipes with the same fingerprint - matching_recipes = [] - if fingerprint: - matching_recipes = await self.recipe_scanner.find_recipes_by_fingerprint(fingerprint) - - # Add matching recipes to result - result["matching_recipes"] = matching_recipes - - return web.json_response(result) - - except Exception as e: - logger.error(f"Error analyzing local image: {e}", exc_info=True) - return web.json_response({ - 'error': str(e), - 'loras': [] # Return empty loras array to prevent client-side errors - }, status=500) - - async def save_recipe(self, request: web.Request) -> web.Response: - """Save a recipe to the recipes folder""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - reader = await request.multipart() - - # Process form data - image = None - image_base64 = None - image_url = None - name = None - tags = [] - metadata = None - - while True: - field = await reader.next() - if field is None: - break - - if field.name == 'image': - # Read image data - image_data = b'' - while True: - chunk = await field.read_chunk() - if not chunk: - break - image_data += chunk - image = image_data - - elif field.name == 'image_base64': - # Get base64 image data - image_base64 = await field.text() - - elif field.name == 'image_url': - # Get image URL - image_url = await field.text() - - elif field.name == 'name': - name = await field.text() - - elif field.name == 'tags': - tags_text = await field.text() - try: - tags = json.loads(tags_text) - except: - tags = [] - - elif field.name == 'metadata': - metadata_text = await field.text() - try: - metadata = json.loads(metadata_text) - except: - metadata = {} - - missing_fields = [] - if not name: - missing_fields.append("name") - if not metadata: - missing_fields.append("metadata") - if missing_fields: - return web.json_response({"error": f"Missing required fields: {', '.join(missing_fields)}"}, status=400) - - # Handle different image sources - if not image: - if image_base64: - # Convert base64 to binary - try: - # Remove potential data URL prefix - if ',' in image_base64: - image_base64 = image_base64.split(',', 1)[1] - image = base64.b64decode(image_base64) - except Exception as e: - return web.json_response({"error": f"Invalid base64 image data: {str(e)}"}, status=400) - else: - return web.json_response({"error": "No image data provided"}, status=400) - - # Create recipes directory if it doesn't exist - recipes_dir = self.recipe_scanner.recipes_dir - os.makedirs(recipes_dir, exist_ok=True) - - # Generate UUID for the recipe - import uuid - recipe_id = str(uuid.uuid4()) - - # Optimize the image (resize and convert to WebP) - optimized_image, extension = ExifUtils.optimize_image( - image_data=image, - target_width=CARD_PREVIEW_WIDTH, - format='webp', - quality=85, - preserve_metadata=True - ) - - # Save the optimized image - image_filename = f"{recipe_id}{extension}" - image_path = os.path.join(recipes_dir, image_filename) - with open(image_path, 'wb') as f: - f.write(optimized_image) - - # Create the recipe data structure - current_time = time.time() - - # Format loras data according to the recipe.json format - loras_data = [] - for lora in metadata.get("loras", []): - # Modified: Always include deleted LoRAs in the recipe metadata - # Even if they're marked to be excluded, we still keep their identifying information - # The exclude flag will only be used to determine if they should be included in recipe syntax - - # Convert frontend lora format to recipe format - lora_entry = { - "file_name": lora.get("file_name", "") or os.path.splitext(os.path.basename(lora.get("localPath", "")))[0] if lora.get("localPath") else "", - "hash": lora.get("hash", "").lower() if lora.get("hash") else "", - "strength": float(lora.get("weight", 1.0)), - "modelVersionId": lora.get("id", 0), - "modelName": lora.get("name", ""), - "modelVersionName": lora.get("version", ""), - "isDeleted": lora.get("isDeleted", False), # Preserve deletion status in saved recipe - "exclude": lora.get("exclude", False) # Add exclude flag to the recipe - } - loras_data.append(lora_entry) - - # Format gen_params according to the recipe.json format - gen_params = metadata.get("gen_params", {}) - if not gen_params and "raw_metadata" in metadata: - # Extract from raw metadata if available - raw_metadata = metadata.get("raw_metadata", {}) - gen_params = { - "prompt": raw_metadata.get("prompt", ""), - "negative_prompt": raw_metadata.get("negative_prompt", ""), - "checkpoint": raw_metadata.get("checkpoint", {}), - "steps": raw_metadata.get("steps", ""), - "sampler": raw_metadata.get("sampler", ""), - "cfg_scale": raw_metadata.get("cfg_scale", ""), - "seed": raw_metadata.get("seed", ""), - "size": raw_metadata.get("size", ""), - "clip_skip": raw_metadata.get("clip_skip", "") - } - - # Calculate recipe fingerprint - from ..utils.utils import calculate_recipe_fingerprint - fingerprint = calculate_recipe_fingerprint(loras_data) - - # Create the recipe data structure - recipe_data = { - "id": recipe_id, - "file_path": image_path, - "title": name, - "modified": current_time, - "created_date": current_time, - "base_model": metadata.get("base_model", ""), - "loras": loras_data, - "gen_params": gen_params, - "fingerprint": fingerprint - } - - # Add tags if provided - if tags: - recipe_data["tags"] = tags - - # Add source_path if provided in metadata - if metadata.get("source_path"): - recipe_data["source_path"] = metadata.get("source_path") - - # Save the recipe JSON - json_filename = f"{recipe_id}.recipe.json" - json_path = os.path.join(recipes_dir, json_filename) - with open(json_path, 'w', encoding='utf-8') as f: - json.dump(recipe_data, f, indent=4, ensure_ascii=False) - - # Add recipe metadata to the image - ExifUtils.append_recipe_metadata(image_path, recipe_data) - - # Check for duplicates - matching_recipes = [] - if fingerprint: - matching_recipes = await self.recipe_scanner.find_recipes_by_fingerprint(fingerprint) - # Remove current recipe from matches - if recipe_id in matching_recipes: - matching_recipes.remove(recipe_id) - - # Simplified cache update approach - # Instead of trying to update the cache directly, just set it to None - # to force a refresh on the next get_cached_data call - if self.recipe_scanner._cache is not None: - # Add the recipe to the raw data if the cache exists - # This is a simple direct update without locks or timeouts - self.recipe_scanner._cache.raw_data.append(recipe_data) - # Schedule a background task to resort the cache - asyncio.create_task(self.recipe_scanner._cache.resort()) - logger.info(f"Added recipe {recipe_id} to cache") - - return web.json_response({ - 'success': True, - 'recipe_id': recipe_id, - 'image_path': image_path, - 'json_path': json_path, - 'matching_recipes': matching_recipes - }) - - except Exception as e: - logger.error(f"Error saving recipe: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def delete_recipe(self, request: web.Request) -> web.Response: - """Delete a recipe by ID""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - recipe_id = request.match_info['recipe_id'] - - # Get recipes directory - recipes_dir = self.recipe_scanner.recipes_dir - if not recipes_dir or not os.path.exists(recipes_dir): - return web.json_response({"error": "Recipes directory not found"}, status=404) - - # Find recipe JSON file - recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json") - if not os.path.exists(recipe_json_path): - return web.json_response({"error": "Recipe not found"}, status=404) - - # Load recipe data to get image path - with open(recipe_json_path, 'r', encoding='utf-8') as f: - recipe_data = json.load(f) - - # Get image path - image_path = recipe_data.get('file_path') - - # Delete recipe JSON file - os.remove(recipe_json_path) - logger.info(f"Deleted recipe JSON file: {recipe_json_path}") - - # Delete recipe image if it exists - if image_path and os.path.exists(image_path): - os.remove(image_path) - logger.info(f"Deleted recipe image: {image_path}") - - # Simplified cache update approach - if self.recipe_scanner._cache is not None: - # Remove the recipe from raw_data if it exists - self.recipe_scanner._cache.raw_data = [ - r for r in self.recipe_scanner._cache.raw_data - if str(r.get('id', '')) != recipe_id - ] - # Schedule a background task to resort the cache - asyncio.create_task(self.recipe_scanner._cache.resort()) - logger.info(f"Removed recipe {recipe_id} from cache") - - return web.json_response({"success": True, "message": "Recipe deleted successfully"}) - except Exception as e: - logger.error(f"Error deleting recipe: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def get_top_tags(self, request: web.Request) -> web.Response: - """Get top tags used in recipes""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - # Get limit parameter with default - limit = int(request.query.get('limit', '20')) - - # Get all recipes from cache - cache = await self.recipe_scanner.get_cached_data() - - # Count tag occurrences - tag_counts = {} - for recipe in cache.raw_data: - if 'tags' in recipe and recipe['tags']: - for tag in recipe['tags']: - tag_counts[tag] = tag_counts.get(tag, 0) + 1 - - # Sort tags by count and limit results - sorted_tags = [{'tag': tag, 'count': count} for tag, count in tag_counts.items()] - sorted_tags.sort(key=lambda x: x['count'], reverse=True) - top_tags = sorted_tags[:limit] - - return web.json_response({ - 'success': True, - 'tags': top_tags - }) - except Exception as e: - logger.error(f"Error retrieving top tags: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_base_models(self, request: web.Request) -> web.Response: - """Get base models used in recipes""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - # Get all recipes from cache - cache = await self.recipe_scanner.get_cached_data() - - # Count base model occurrences - base_model_counts = {} - for recipe in cache.raw_data: - if 'base_model' in recipe and recipe['base_model']: - base_model = recipe['base_model'] - base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 - - # Sort base models by count - sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()] - sorted_models.sort(key=lambda x: x['count'], reverse=True) - - return web.json_response({ - 'success': True, - 'base_models': sorted_models - }) - except Exception as e: - logger.error(f"Error retrieving base models: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e)} - , status=500) - - async def share_recipe(self, request: web.Request) -> web.Response: - """Process a recipe image for sharing by adding metadata to EXIF""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - recipe_id = request.match_info['recipe_id'] - - # Get all recipes from cache - cache = await self.recipe_scanner.get_cached_data() - - # Find the specific recipe - recipe = next((r for r in cache.raw_data if str(r.get('id', '')) == recipe_id), None) - - if not recipe: - return web.json_response({"error": "Recipe not found"}, status=404) - - # Get the image path - image_path = recipe.get('file_path') - if not image_path or not os.path.exists(image_path): - return web.json_response({"error": "Recipe image not found"}, status=404) - - # Create a temporary copy of the image to modify - import tempfile - import shutil - - # Create temp file with same extension - ext = os.path.splitext(image_path)[1] - with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as temp_file: - temp_path = temp_file.name - - # Copy the original image to temp file - shutil.copy2(image_path, temp_path) - processed_path = temp_path - - # Create a URL for the processed image - # Use a timestamp to prevent caching - timestamp = int(time.time()) - url_path = f"/api/recipe/{recipe_id}/share/download?t={timestamp}" - - # Store the temp path in a dictionary to serve later - if not hasattr(self, '_shared_recipes'): - self._shared_recipes = {} - - self._shared_recipes[recipe_id] = { - 'path': processed_path, - 'timestamp': timestamp, - 'expires': time.time() + 300 # Expire after 5 minutes - } - - # Clean up old entries - self._cleanup_shared_recipes() - - return web.json_response({ - 'success': True, - 'download_url': url_path, - 'filename': f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}{ext}" - }) - except Exception as e: - logger.error(f"Error sharing recipe: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def download_shared_recipe(self, request: web.Request) -> web.Response: - """Serve a processed recipe image for download""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - recipe_id = request.match_info['recipe_id'] - - # Check if we have this shared recipe - if not hasattr(self, '_shared_recipes') or recipe_id not in self._shared_recipes: - return web.json_response({"error": "Shared recipe not found or expired"}, status=404) - - shared_info = self._shared_recipes[recipe_id] - file_path = shared_info['path'] - - if not os.path.exists(file_path): - return web.json_response({"error": "Shared recipe file not found"}, status=404) - - # Get recipe to determine filename - cache = await self.recipe_scanner.get_cached_data() - recipe = next((r for r in cache.raw_data if str(r.get('id', '')) == recipe_id), None) - - # Set filename for download - filename = f"recipe_{recipe.get('title', '').replace(' ', '_').lower() if recipe else recipe_id}" - ext = os.path.splitext(file_path)[1] - download_filename = f"{filename}{ext}" - - # Serve the file - return web.FileResponse( - file_path, - headers={ - 'Content-Disposition': f'attachment; filename="{download_filename}"' - } - ) - except Exception as e: - logger.error(f"Error downloading shared recipe: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - def _cleanup_shared_recipes(self): - """Clean up expired shared recipes""" - if not hasattr(self, '_shared_recipes'): - return - - current_time = time.time() - expired_ids = [rid for rid, info in self._shared_recipes.items() - if current_time > info.get('expires', 0)] - - for rid in expired_ids: - try: - # Delete the temporary file - file_path = self._shared_recipes[rid]['path'] - if os.path.exists(file_path): - os.unlink(file_path) - - # Remove from dictionary - del self._shared_recipes[rid] - except Exception as e: - logger.error(f"Error cleaning up shared recipe {rid}: {e}") - - async def save_recipe_from_widget(self, request: web.Request) -> web.Response: - """Save a recipe from the LoRAs widget""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - # Get metadata using the metadata collector instead of workflow parsing - raw_metadata = get_metadata() - metadata_dict = MetadataProcessor.to_dict(raw_metadata) - - # Check if we have valid metadata - if not metadata_dict: - return web.json_response({"error": "No generation metadata found"}, status=400) - - # Get the most recent image from metadata registry instead of temp directory - if not standalone_mode: - metadata_registry = MetadataRegistry() - latest_image = metadata_registry.get_first_decoded_image() - else: - latest_image = None - - if latest_image is None: - return web.json_response({"error": "No recent images found to use for recipe. Try generating an image first."}, status=400) - - # Convert the image data to bytes - handle tuple and tensor cases - logger.debug(f"Image type: {type(latest_image)}") - - try: - # Handle the tuple case first - if isinstance(latest_image, tuple): - # Extract the tensor from the tuple - if len(latest_image) > 0: - tensor_image = latest_image[0] - else: - return web.json_response({"error": "Empty image tuple received"}, status=400) - else: - tensor_image = latest_image - - # Get the shape info for debugging - if hasattr(tensor_image, 'shape'): - shape_info = tensor_image.shape - logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}") - - import torch - - # Convert tensor to numpy array - if isinstance(tensor_image, torch.Tensor): - image_np = tensor_image.cpu().numpy() - else: - image_np = np.array(tensor_image) - - # Handle different tensor shapes - # Case: (1, 1, H, W, 3) or (1, H, W, 3) - batch or multi-batch - if len(image_np.shape) > 3: - # Remove batch dimensions until we get to (H, W, 3) - while len(image_np.shape) > 3: - image_np = image_np[0] - - # If values are in [0, 1] range, convert to [0, 255] - if image_np.dtype == np.float32 or image_np.dtype == np.float64: - if image_np.max() <= 1.0: - image_np = (image_np * 255).astype(np.uint8) - - # Ensure image is in the right format (HWC with RGB channels) - if len(image_np.shape) == 3 and image_np.shape[2] == 3: - pil_image = Image.fromarray(image_np) - img_byte_arr = io.BytesIO() - pil_image.save(img_byte_arr, format='PNG') - image = img_byte_arr.getvalue() - else: - return web.json_response({"error": f"Cannot handle this data shape: {image_np.shape}, {image_np.dtype}"}, status=400) - except Exception as e: - logger.error(f"Error processing image data: {str(e)}", exc_info=True) - return web.json_response({"error": f"Error processing image: {str(e)}"}, status=400) - - # Get the lora stack from the metadata - lora_stack = metadata_dict.get("loras", "") - - # Parse the lora stack format: " ..." - import re - lora_matches = re.findall(r']+)>', lora_stack) - - # Check if any loras were found - if not lora_matches: - return web.json_response({"error": "No LoRAs found in the generation metadata"}, status=400) - - # Generate recipe name from the first 3 loras (or less if fewer are available) - loras_for_name = lora_matches[:3] # Take at most 3 loras for the name - - recipe_name_parts = [] - for lora_name, lora_strength in loras_for_name: - # Get the basename without path or extension - basename = os.path.basename(lora_name) - basename = os.path.splitext(basename)[0] - recipe_name_parts.append(f"{basename}:{lora_strength}") - - recipe_name = " ".join(recipe_name_parts) - - # Create recipes directory if it doesn't exist - recipes_dir = self.recipe_scanner.recipes_dir - os.makedirs(recipes_dir, exist_ok=True) - - # Generate UUID for the recipe - import uuid - recipe_id = str(uuid.uuid4()) - - # Optimize the image (resize and convert to WebP) - optimized_image, extension = ExifUtils.optimize_image( - image_data=image, - target_width=CARD_PREVIEW_WIDTH, - format='webp', - quality=85, - preserve_metadata=True - ) - - # Save the optimized image - image_filename = f"{recipe_id}{extension}" - image_path = os.path.join(recipes_dir, image_filename) - with open(image_path, 'wb') as f: - f.write(optimized_image) - - # Format loras data from the lora stack - loras_data = [] - - for lora_name, lora_strength in lora_matches: - try: - # Get lora info from scanner - lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora_name) - - # Create lora entry - lora_entry = { - "file_name": lora_name, - "hash": lora_info.get("sha256", "").lower() if lora_info else "", - "strength": float(lora_strength), - "modelVersionId": lora_info.get("civitai", {}).get("id", 0) if lora_info else 0, - "modelName": lora_info.get("civitai", {}).get("model", {}).get("name", "") if lora_info else lora_name, - "modelVersionName": lora_info.get("civitai", {}).get("name", "") if lora_info else "", - "isDeleted": False - } - loras_data.append(lora_entry) - except Exception as e: - logger.warning(f"Error processing LoRA {lora_name}: {e}") - - # Get base model from lora scanner for the available loras - base_model_counts = {} - for lora in loras_data: - lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora.get("file_name", "")) - if lora_info and "base_model" in lora_info: - base_model = lora_info["base_model"] - base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 - - # Get most common base model - most_common_base_model = "" - if base_model_counts: - most_common_base_model = max(base_model_counts.items(), key=lambda x: x[1])[0] - - # Create the recipe data structure - recipe_data = { - "id": recipe_id, - "file_path": image_path, - "title": recipe_name, # Use generated recipe name - "modified": time.time(), - "created_date": time.time(), - "base_model": most_common_base_model, - "loras": loras_data, - "checkpoint": metadata_dict.get("checkpoint", ""), - "gen_params": {key: value for key, value in metadata_dict.items() - if key not in ['checkpoint', 'loras']}, - "loras_stack": lora_stack # Include the original lora stack string - } - - # Save the recipe JSON - json_filename = f"{recipe_id}.recipe.json" - json_path = os.path.join(recipes_dir, json_filename) - with open(json_path, 'w', encoding='utf-8') as f: - json.dump(recipe_data, f, indent=4, ensure_ascii=False) - - # Add recipe metadata to the image - ExifUtils.append_recipe_metadata(image_path, recipe_data) - - # Update cache - if self.recipe_scanner._cache is not None: - # Add the recipe to the raw data if the cache exists - self.recipe_scanner._cache.raw_data.append(recipe_data) - # Schedule a background task to resort the cache - asyncio.create_task(self.recipe_scanner._cache.resort()) - logger.info(f"Added recipe {recipe_id} to cache") - - return web.json_response({ - 'success': True, - 'recipe_id': recipe_id, - 'image_path': image_path, - 'json_path': json_path, - 'recipe_name': recipe_name # Include the generated recipe name in the response - }) - - except Exception as e: - logger.error(f"Error saving recipe from widget: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def get_recipe_syntax(self, request: web.Request) -> web.Response: - """Generate recipe syntax for LoRAs in the recipe, looking up proper file names using hash_index""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - recipe_id = request.match_info['recipe_id'] - - # Get all recipes from cache - cache = await self.recipe_scanner.get_cached_data() - - # Find the specific recipe - recipe = next((r for r in cache.raw_data if str(r.get('id', '')) == recipe_id), None) - - if not recipe: - return web.json_response({"error": "Recipe not found"}, status=404) - - # Get the loras from the recipe - loras = recipe.get('loras', []) - - if not loras: - return web.json_response({"error": "No LoRAs found in this recipe"}, status=400) - - # Generate recipe syntax for all LoRAs that: - # 1. Are in the library (not deleted) OR - # 2. Are deleted but not marked for exclusion - lora_syntax_parts = [] - - # Access the hash_index from lora_scanner - hash_index = self.recipe_scanner._lora_scanner._hash_index - - for lora in loras: - # Skip loras that are deleted AND marked for exclusion - if lora.get("isDeleted", False): - continue - - if not self.recipe_scanner._lora_scanner.has_hash(lora.get("hash", "")): - continue - - # Get the strength - strength = lora.get("strength", 1.0) - - # Try to find the actual file name for this lora - file_name = None - hash_value = lora.get("hash", "").lower() - - if hash_value and hasattr(hash_index, "_hash_to_path"): - # Look up the file path from the hash - file_path = hash_index._hash_to_path.get(hash_value) - - if file_path: - # Extract the file name without extension from the path - file_name = os.path.splitext(os.path.basename(file_path))[0] - - # If hash lookup failed, fall back to modelVersionId lookup - if not file_name and lora.get("modelVersionId"): - # Search for files with matching modelVersionId - all_loras = await self.recipe_scanner._lora_scanner.get_cached_data() - for cached_lora in all_loras.raw_data: - if not cached_lora.get("civitai"): - continue - if cached_lora.get("civitai", {}).get("id") == lora.get("modelVersionId"): - file_name = os.path.splitext(os.path.basename(cached_lora["path"]))[0] - break - - # If all lookups failed, use the file_name from the recipe - if not file_name: - file_name = lora.get("file_name", "unknown-lora") - - # Add to syntax parts - lora_syntax_parts.append(f"") - - # Join the LoRA syntax parts - lora_syntax = " ".join(lora_syntax_parts) - - return web.json_response({ - 'success': True, - 'syntax': lora_syntax - }) - except Exception as e: - logger.error(f"Error generating recipe syntax: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def update_recipe(self, request: web.Request) -> web.Response: - """Update recipe metadata (name and tags)""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - recipe_id = request.match_info['recipe_id'] - data = await request.json() - - # Validate required fields - if 'title' not in data and 'tags' not in data and 'source_path' not in data and 'preview_nsfw_level' not in data: - return web.json_response({ - "error": "At least one field to update must be provided (title or tags or source_path or preview_nsfw_level)" - }, status=400) - - # Use the recipe scanner's update method - success = await self.recipe_scanner.update_recipe_metadata(recipe_id, data) - - if not success: - return web.json_response({"error": "Recipe not found or update failed"}, status=404) - - return web.json_response({ - "success": True, - "recipe_id": recipe_id, - "updates": data - }) - except Exception as e: - logger.error(f"Error updating recipe: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def reconnect_lora(self, request: web.Request) -> web.Response: - """Reconnect a deleted LoRA in a recipe to a local LoRA file""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - # Parse request data - data = await request.json() - - # Validate required fields - required_fields = ['recipe_id', 'lora_index', 'target_name'] - for field in required_fields: - if field not in data: - return web.json_response({ - "error": f"Missing required field: {field}" - }, status=400) - - recipe_id = data['recipe_id'] - lora_index = int(data['lora_index']) - target_name = data['target_name'] - - # Get recipe scanner - scanner = self.recipe_scanner - lora_scanner = scanner._lora_scanner - - # Check if recipe exists - recipe_path = os.path.join(scanner.recipes_dir, f"{recipe_id}.recipe.json") - if not os.path.exists(recipe_path): - return web.json_response({"error": "Recipe not found"}, status=404) - - # Find target LoRA by name - target_lora = await lora_scanner.get_model_info_by_name(target_name) - if not target_lora: - return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404) - - # Load recipe data - with open(recipe_path, 'r', encoding='utf-8') as f: - recipe_data = json.load(f) - - lora = recipe_data.get("loras", [])[lora_index] if lora_index < len(recipe_data.get('loras', [])) else None - - if lora is None: - return web.json_response({"error": "LoRA index out of range in recipe"}, status=404) - - # Update LoRA data - lora['isDeleted'] = False - lora['exclude'] = False - lora['file_name'] = target_name - - # Update with information from the target LoRA - if 'sha256' in target_lora: - lora['hash'] = target_lora['sha256'].lower() - if target_lora.get("civitai"): - lora['modelName'] = target_lora['civitai']['model']['name'] - lora['modelVersionName'] = target_lora['civitai']['name'] - lora['modelVersionId'] = target_lora['civitai']['id'] - - updated_lora = dict(lora) # Make a copy for response - - # Recalculate recipe fingerprint after updating LoRA - from ..utils.utils import calculate_recipe_fingerprint - recipe_data['fingerprint'] = calculate_recipe_fingerprint(recipe_data.get('loras', [])) - - # Save updated recipe - with open(recipe_path, 'w', encoding='utf-8') as f: - json.dump(recipe_data, f, indent=4, ensure_ascii=False) - - updated_lora['inLibrary'] = True - updated_lora['preview_url'] = config.get_preview_static_url(target_lora['preview_url']) - updated_lora['localPath'] = target_lora['file_path'] - - # Update in cache if it exists - if scanner._cache is not None: - for cache_item in scanner._cache.raw_data: - if cache_item.get('id') == recipe_id: - # Replace loras array with updated version - cache_item['loras'] = recipe_data['loras'] - # Update fingerprint in cache - cache_item['fingerprint'] = recipe_data['fingerprint'] - - # Resort the cache - asyncio.create_task(scanner._cache.resort()) - break - - # Update EXIF metadata if image exists - image_path = recipe_data.get('file_path') - if image_path and os.path.exists(image_path): - from ..utils.exif_utils import ExifUtils - ExifUtils.append_recipe_metadata(image_path, recipe_data) - - # Find other recipes with the same fingerprint - matching_recipes = [] - if 'fingerprint' in recipe_data: - matching_recipes = await scanner.find_recipes_by_fingerprint(recipe_data['fingerprint']) - # Remove current recipe from matches - if recipe_id in matching_recipes: - matching_recipes.remove(recipe_id) - - return web.json_response({ - "success": True, - "recipe_id": recipe_id, - "updated_lora": updated_lora, - "matching_recipes": matching_recipes - }) - - except Exception as e: - logger.error(f"Error reconnecting LoRA: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def get_recipes_for_lora(self, request: web.Request) -> web.Response: - """Get recipes that use a specific Lora""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - lora_hash = request.query.get('hash') - - # Hash is required - if not lora_hash: - return web.json_response({'success': False, 'error': 'Lora hash is required'}, status=400) - - # Log the search parameters - logger.debug(f"Getting recipes for Lora by hash: {lora_hash}") - - # Get all recipes from cache - cache = await self.recipe_scanner.get_cached_data() - - # Filter recipes that use this Lora by hash - matching_recipes = [] - for recipe in cache.raw_data: - # Check if any of the recipe's loras match this hash - loras = recipe.get('loras', []) - for lora in loras: - if lora.get('hash', '').lower() == lora_hash.lower(): - matching_recipes.append(recipe) - break # No need to check other loras in this recipe - - # Process the recipes similar to get_paginated_data to ensure all needed data is available - for recipe in matching_recipes: - # Add inLibrary information for each lora - if 'loras' in recipe: - for lora in recipe['loras']: - if 'hash' in lora and lora['hash']: - lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_hash(lora['hash'].lower()) - lora['preview_url'] = self.recipe_scanner._lora_scanner.get_preview_url_by_hash(lora['hash'].lower()) - lora['localPath'] = self.recipe_scanner._lora_scanner.get_path_by_hash(lora['hash'].lower()) - - # Ensure file_url is set (needed by frontend) - if 'file_path' in recipe: - recipe['file_url'] = self._format_recipe_file_url(recipe['file_path']) - else: - recipe['file_url'] = '/loras_static/images/no-preview.png' - - return web.json_response({'success': True, 'recipes': matching_recipes}) - except Exception as e: - logger.error(f"Error getting recipes for Lora: {str(e)}") - return web.json_response({'success': False, 'error': str(e)}, status=500) - - async def scan_recipes(self, request: web.Request) -> web.Response: - """API endpoint for scanning and rebuilding the recipe cache""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - # Force refresh the recipe cache - logger.info("Manually triggering recipe cache rebuild") - await self.recipe_scanner.get_cached_data(force_refresh=True) - - return web.json_response({ - 'success': True, - 'message': 'Recipe cache refreshed successfully' - }) - except Exception as e: - logger.error(f"Error refreshing recipe cache: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def find_duplicates(self, request: web.Request) -> web.Response: - """Find all duplicate recipes based on fingerprints""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - # Get all duplicate recipes - duplicate_groups = await self.recipe_scanner.find_all_duplicate_recipes() - - # Create response data with additional recipe information - response_data = [] - - for fingerprint, recipe_ids in duplicate_groups.items(): - # Skip groups with only one recipe (not duplicates) - if len(recipe_ids) <= 1: - continue - - # Get recipe details for each recipe in the group - recipes = [] - for recipe_id in recipe_ids: - recipe = await self.recipe_scanner.get_recipe_by_id(recipe_id) - if recipe: - # Add only needed fields to keep response size manageable - recipes.append({ - 'id': recipe.get('id'), - 'title': recipe.get('title'), - 'file_url': recipe.get('file_url') or self._format_recipe_file_url(recipe.get('file_path', '')), - 'modified': recipe.get('modified'), - 'created_date': recipe.get('created_date'), - 'lora_count': len(recipe.get('loras', [])), - }) - - # Only include groups with at least 2 valid recipes - if len(recipes) >= 2: - # Sort recipes by modified date (newest first) - recipes.sort(key=lambda x: x.get('modified', 0), reverse=True) - - response_data.append({ - 'fingerprint': fingerprint, - 'count': len(recipes), - 'recipes': recipes - }) - - # Sort groups by count (highest first) - response_data.sort(key=lambda x: x['count'], reverse=True) - - return web.json_response({ - 'success': True, - 'duplicate_groups': response_data - }) - - except Exception as e: - logger.error(f"Error finding duplicate recipes: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def bulk_delete(self, request: web.Request) -> web.Response: - """Delete multiple recipes by ID""" - try: - # Ensure services are initialized - await self.ensure_dependencies_ready() - - # Parse request data - data = await request.json() - recipe_ids = data.get('recipe_ids', []) - - if not recipe_ids: - return web.json_response({ - 'success': False, - 'error': 'No recipe IDs provided' - }, status=400) - - # Get recipes directory - recipes_dir = self.recipe_scanner.recipes_dir - if not recipes_dir or not os.path.exists(recipes_dir): - return web.json_response({ - 'success': False, - 'error': 'Recipes directory not found' - }, status=404) - - # Track deleted and failed recipes - deleted_recipes = [] - failed_recipes = [] - - # Process each recipe ID - for recipe_id in recipe_ids: - # Find recipe JSON file - recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json") - - if not os.path.exists(recipe_json_path): - failed_recipes.append({ - 'id': recipe_id, - 'reason': 'Recipe not found' - }) - continue - - try: - # Load recipe data to get image path - with open(recipe_json_path, 'r', encoding='utf-8') as f: - recipe_data = json.load(f) - - # Get image path - image_path = recipe_data.get('file_path') - - # Delete recipe JSON file - os.remove(recipe_json_path) - - # Delete recipe image if it exists - if image_path and os.path.exists(image_path): - os.remove(image_path) - - deleted_recipes.append(recipe_id) - - except Exception as e: - failed_recipes.append({ - 'id': recipe_id, - 'reason': str(e) - }) - - # Update cache if any recipes were deleted - if deleted_recipes and self.recipe_scanner._cache is not None: - # Remove deleted recipes from raw_data - self.recipe_scanner._cache.raw_data = [ - r for r in self.recipe_scanner._cache.raw_data - if r.get('id') not in deleted_recipes - ] - # Resort the cache - asyncio.create_task(self.recipe_scanner._cache.resort()) - logger.info(f"Removed {len(deleted_recipes)} recipes from cache") - - return web.json_response({ - 'success': True, - 'deleted': deleted_recipes, - 'failed': failed_recipes, - 'total_deleted': len(deleted_recipes), - 'total_failed': len(failed_recipes) - }) - - except Exception as e: - logger.error(f"Error performing bulk delete: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - # Legacy method aliases retained for compatibility with existing imports. - handle_recipes_page = render_page - get_recipes = list_recipes - get_recipe_detail = get_recipe - analyze_recipe_image = analyze_uploaded_image diff --git a/tests/routes/test_recipe_route_scaffolding.py b/tests/routes/test_recipe_route_scaffolding.py index 1f0e723d..59765d36 100644 --- a/tests/routes/test_recipe_route_scaffolding.py +++ b/tests/routes/test_recipe_route_scaffolding.py @@ -128,31 +128,38 @@ def test_register_startup_hooks_appends_once(): assert len(startup_bound_to_routes) == 2 -def test_to_route_mapping_uses_handler_owner(monkeypatch: pytest.MonkeyPatch): - class DummyOwner: - async def render_page(self, request): - return web.Response(text="ok") +def test_to_route_mapping_uses_handler_set(): + class DummyHandlerSet: + def __init__(self): + self.calls = 0 - async def list_recipes(self, request): # pragma: no cover - invoked via mapping - return web.json_response({}) + def to_route_mapping(self): + self.calls += 1 + + async def render_page(request): # pragma: no cover - simple coroutine + return web.Response(text="ok") + + return {"render_page": render_page} class DummyRoutes(base_routes_module.BaseRecipeRoutes): - def get_handler_owner(self): # noqa: D401 - simple override for test - return DummyOwner() + def __init__(self): + super().__init__() + self.created = 0 - monkeypatch.setattr( - base_routes_module.BaseRecipeRoutes, - "_HANDLER_NAMES", - ("render_page", "list_recipes"), - ) + def _create_handler_set(self): # noqa: D401 - simple override for test + self.created += 1 + return DummyHandlerSet() routes = DummyRoutes() mapping = routes.to_route_mapping() - assert set(mapping.keys()) == {"render_page", "list_recipes"} + assert set(mapping.keys()) == {"render_page"} assert asyncio.iscoroutinefunction(mapping["render_page"]) # Cached mapping reused on subsequent calls assert routes.to_route_mapping() is mapping + # Handler set cached for get_handler_owner callers + assert isinstance(routes.get_handler_owner(), DummyHandlerSet) + assert routes.created == 1 def test_recipe_route_registrar_binds_every_route(): From 097a68ad1862e26cf7f423d19569c8a126e5a64d Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Mon, 22 Sep 2025 13:25:21 +0800 Subject: [PATCH 12/24] refactor(recipes): introduce dedicated services for handlers --- py/routes/base_recipe_routes.py | 45 + py/routes/handlers/recipe_handlers.py | 905 ++++----------------- py/services/recipes/__init__.py | 23 + py/services/recipes/analysis_service.py | 289 +++++++ py/services/recipes/errors.py | 22 + py/services/recipes/persistence_service.py | 467 +++++++++++ py/services/recipes/sharing_service.py | 113 +++ tests/services/test_recipe_services.py | 146 ++++ 8 files changed, 1274 insertions(+), 736 deletions(-) create mode 100644 py/services/recipes/__init__.py create mode 100644 py/services/recipes/analysis_service.py create mode 100644 py/services/recipes/errors.py create mode 100644 py/services/recipes/persistence_service.py create mode 100644 py/services/recipes/sharing_service.py create mode 100644 tests/services/test_recipe_services.py diff --git a/py/routes/base_recipe_routes.py b/py/routes/base_recipe_routes.py index 59d4e7ec..4447bb7b 100644 --- a/py/routes/base_recipe_routes.py +++ b/py/routes/base_recipe_routes.py @@ -2,15 +2,25 @@ from __future__ import annotations import logging +import os from typing import Callable, Mapping import jinja2 from aiohttp import web from ..config import config +from ..recipes import RecipeParserFactory +from ..services.downloader import get_downloader +from ..services.recipes import ( + RecipeAnalysisService, + RecipePersistenceService, + RecipeSharingService, +) from ..services.server_i18n import server_i18n from ..services.service_registry import ServiceRegistry from ..services.settings_manager import settings +from ..utils.constants import CARD_PREVIEW_WIDTH +from ..utils.exif_utils import ExifUtils from .handlers.recipe_handlers import ( RecipeAnalysisHandler, RecipeHandlerSet, @@ -124,6 +134,37 @@ class BaseRecipeRoutes: recipe_scanner_getter = lambda: self.recipe_scanner civitai_client_getter = lambda: self.civitai_client + standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" + if not standalone_mode: + from ..metadata_collector import get_metadata # type: ignore[import-not-found] + from ..metadata_collector.metadata_processor import ( # type: ignore[import-not-found] + MetadataProcessor, + ) + from ..metadata_collector.metadata_registry import ( # type: ignore[import-not-found] + MetadataRegistry, + ) + else: # pragma: no cover - optional dependency path + get_metadata = None # type: ignore[assignment] + MetadataProcessor = None # type: ignore[assignment] + MetadataRegistry = None # type: ignore[assignment] + + analysis_service = RecipeAnalysisService( + exif_utils=ExifUtils, + recipe_parser_factory=RecipeParserFactory, + downloader_factory=get_downloader, + metadata_collector=get_metadata, + metadata_processor_cls=MetadataProcessor, + metadata_registry_cls=MetadataRegistry, + standalone_mode=standalone_mode, + logger=logger, + ) + persistence_service = RecipePersistenceService( + exif_utils=ExifUtils, + card_preview_width=CARD_PREVIEW_WIDTH, + logger=logger, + ) + sharing_service = RecipeSharingService(logger=logger) + page_view = RecipePageView( ensure_dependencies_ready=self.ensure_dependencies_ready, settings_service=self.settings, @@ -148,17 +189,21 @@ class BaseRecipeRoutes: ensure_dependencies_ready=self.ensure_dependencies_ready, recipe_scanner_getter=recipe_scanner_getter, logger=logger, + persistence_service=persistence_service, + analysis_service=analysis_service, ) analysis = RecipeAnalysisHandler( ensure_dependencies_ready=self.ensure_dependencies_ready, recipe_scanner_getter=recipe_scanner_getter, civitai_client_getter=civitai_client_getter, logger=logger, + analysis_service=analysis_service, ) sharing = RecipeSharingHandler( ensure_dependencies_ready=self.ensure_dependencies_ready, recipe_scanner_getter=recipe_scanner_getter, logger=logger, + sharing_service=sharing_service, ) return RecipeHandlerSet( diff --git a/py/routes/handlers/recipe_handlers.py b/py/routes/handlers/recipe_handlers.py index 8d8f96bf..35f4c088 100644 --- a/py/routes/handlers/recipe_handlers.py +++ b/py/routes/handlers/recipe_handlers.py @@ -1,40 +1,25 @@ """Dedicated handler objects for recipe-related routes.""" from __future__ import annotations -import asyncio -import base64 -import io import json import logging import os -import tempfile -import time from dataclasses import dataclass from typing import Any, Awaitable, Callable, Dict, Mapping, Optional -import numpy as np from aiohttp import web -from PIL import Image from ...config import config -from ...recipes import RecipeParserFactory -from ...services.downloader import get_downloader from ...services.server_i18n import server_i18n as default_server_i18n from ...services.settings_manager import SettingsManager -from ...utils.constants import CARD_PREVIEW_WIDTH -from ...utils.exif_utils import ExifUtils - -# Check if running in standalone mode -standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" - -if not standalone_mode: - from ...metadata_collector import get_metadata - from ...metadata_collector.metadata_processor import MetadataProcessor - from ...metadata_collector.metadata_registry import MetadataRegistry -else: # pragma: no cover - optional dependency path - get_metadata = None # type: ignore[assignment] - MetadataProcessor = None # type: ignore[assignment] - MetadataRegistry = None # type: ignore[assignment] +from ...services.recipes import ( + RecipeAnalysisService, + RecipeDownloadError, + RecipeNotFoundError, + RecipePersistenceService, + RecipeSharingService, + RecipeValidationError, +) Logger = logging.Logger EnsureDependenciesCallable = Callable[[], Awaitable[None]] @@ -457,22 +442,14 @@ class RecipeManagementHandler: ensure_dependencies_ready: EnsureDependenciesCallable, recipe_scanner_getter: RecipeScannerGetter, logger: Logger, - exif_utils=ExifUtils, - card_preview_width: int = CARD_PREVIEW_WIDTH, - metadata_collector: Optional[Callable[[], Any]] = get_metadata, - metadata_processor_cls: Optional[type] = MetadataProcessor, - metadata_registry_cls: Optional[type] = MetadataRegistry, - standalone_mode: bool = standalone_mode, + persistence_service: RecipePersistenceService, + analysis_service: RecipeAnalysisService, ) -> None: self._ensure_dependencies_ready = ensure_dependencies_ready self._recipe_scanner_getter = recipe_scanner_getter self._logger = logger - self._exif_utils = exif_utils - self._card_preview_width = card_preview_width - self._metadata_collector = metadata_collector - self._metadata_processor_cls = metadata_processor_cls - self._metadata_registry_cls = metadata_registry_cls - self._standalone_mode = standalone_mode + self._persistence_service = persistence_service + self._analysis_service = analysis_service async def save_recipe(self, request: web.Request) -> web.Response: try: @@ -482,171 +459,19 @@ class RecipeManagementHandler: raise RuntimeError("Recipe scanner unavailable") reader = await request.multipart() + payload = await self._parse_save_payload(reader) - image: Optional[bytes] = None - image_base64: Optional[str] = None - name: Optional[str] = None - tags: list[str] = [] - metadata: Dict[str, Any] | None = None - - while True: - field = await reader.next() - if field is None: - break - - if field.name == "image": - image_chunks = bytearray() - while True: - chunk = await field.read_chunk() - if not chunk: - break - image_chunks.extend(chunk) - image = bytes(image_chunks) - elif field.name == "image_base64": - image_base64 = await field.text() - elif field.name == "name": - name = await field.text() - elif field.name == "tags": - tags_text = await field.text() - try: - parsed_tags = json.loads(tags_text) - tags = parsed_tags if isinstance(parsed_tags, list) else [] - except Exception: - tags = [] - elif field.name == "metadata": - metadata_text = await field.text() - try: - metadata = json.loads(metadata_text) - except Exception: - metadata = {} - - missing_fields = [] - if not name: - missing_fields.append("name") - if not metadata: - missing_fields.append("metadata") - if missing_fields: - return web.json_response( - {"error": f"Missing required fields: {', '.join(missing_fields)}"}, - status=400, - ) - - if image is None: - if image_base64: - try: - if "," in image_base64: - image_base64 = image_base64.split(",", 1)[1] - image = base64.b64decode(image_base64) - except Exception as exc: - return web.json_response({"error": f"Invalid base64 image data: {exc}"}, status=400) - else: - return web.json_response({"error": "No image data provided"}, status=400) - - recipes_dir = recipe_scanner.recipes_dir - os.makedirs(recipes_dir, exist_ok=True) - - import uuid - - recipe_id = str(uuid.uuid4()) - optimized_image, extension = self._exif_utils.optimize_image( - image_data=image, - target_width=self._card_preview_width, - format="webp", - quality=85, - preserve_metadata=True, - ) - - image_filename = f"{recipe_id}{extension}" - image_path = os.path.join(recipes_dir, image_filename) - with open(image_path, "wb") as file_obj: - file_obj.write(optimized_image) - - current_time = time.time() - loras_data = [] - for lora in metadata.get("loras", []): - loras_data.append( - { - "file_name": lora.get("file_name", "") - or ( - os.path.splitext(os.path.basename(lora.get("localPath", "")))[0] - if lora.get("localPath") - else "" - ), - "hash": (lora.get("hash") or "").lower(), - "strength": float(lora.get("weight", 1.0)), - "modelVersionId": lora.get("id", 0), - "modelName": lora.get("name", ""), - "modelVersionName": lora.get("version", ""), - "isDeleted": lora.get("isDeleted", False), - "exclude": lora.get("exclude", False), - } - ) - - gen_params = metadata.get("gen_params", {}) - if not gen_params and "raw_metadata" in metadata: - raw_metadata = metadata.get("raw_metadata", {}) - gen_params = { - "prompt": raw_metadata.get("prompt", ""), - "negative_prompt": raw_metadata.get("negative_prompt", ""), - "checkpoint": raw_metadata.get("checkpoint", {}), - "steps": raw_metadata.get("steps", ""), - "sampler": raw_metadata.get("sampler", ""), - "cfg_scale": raw_metadata.get("cfg_scale", ""), - "seed": raw_metadata.get("seed", ""), - "size": raw_metadata.get("size", ""), - "clip_skip": raw_metadata.get("clip_skip", ""), - } - - from ...utils.utils import calculate_recipe_fingerprint - - fingerprint = calculate_recipe_fingerprint(loras_data) - - recipe_data = { - "id": recipe_id, - "file_path": image_path, - "title": name, - "modified": current_time, - "created_date": current_time, - "base_model": metadata.get("base_model", ""), - "loras": loras_data, - "gen_params": gen_params, - "fingerprint": fingerprint, - } - - if tags: - recipe_data["tags"] = tags - - if metadata.get("source_path"): - recipe_data["source_path"] = metadata.get("source_path") - - json_filename = f"{recipe_id}.recipe.json" - json_path = os.path.join(recipes_dir, json_filename) - with open(json_path, "w", encoding="utf-8") as file_obj: - json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) - - self._exif_utils.append_recipe_metadata(image_path, recipe_data) - - matching_recipes = [] - if fingerprint: - matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(fingerprint) - if recipe_id in matching_recipes: - matching_recipes.remove(recipe_id) - - cache = getattr(recipe_scanner, "_cache", None) - if cache is not None: - cache.raw_data.append(recipe_data) - asyncio.create_task(cache.resort()) - self._logger.info("Added recipe %s to cache", recipe_id) - - return web.json_response( - { - "success": True, - "recipe_id": recipe_id, - "image_path": image_path, - "json_path": json_path, - "matching_recipes": matching_recipes, - } + result = await self._persistence_service.save_recipe( + recipe_scanner=recipe_scanner, + image_bytes=payload["image_bytes"], + image_base64=payload["image_base64"], + name=payload["name"], + tags=payload["tags"], + metadata=payload["metadata"], ) + return web.json_response(result.payload, status=result.status) + except RecipeValidationError as exc: + return web.json_response({"error": str(exc)}, status=400) except Exception as exc: self._logger.error("Error saving recipe: %s", exc, exc_info=True) return web.json_response({"error": str(exc)}, status=500) @@ -659,34 +484,12 @@ class RecipeManagementHandler: raise RuntimeError("Recipe scanner unavailable") recipe_id = request.match_info["recipe_id"] - recipes_dir = recipe_scanner.recipes_dir - if not recipes_dir or not os.path.exists(recipes_dir): - return web.json_response({"error": "Recipes directory not found"}, status=404) - - recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json") - if not os.path.exists(recipe_json_path): - return web.json_response({"error": "Recipe not found"}, status=404) - - with open(recipe_json_path, "r", encoding="utf-8") as file_obj: - recipe_data = json.load(file_obj) - - image_path = recipe_data.get("file_path") - os.remove(recipe_json_path) - self._logger.info("Deleted recipe JSON file: %s", recipe_json_path) - - if image_path and os.path.exists(image_path): - os.remove(image_path) - self._logger.info("Deleted recipe image: %s", image_path) - - cache = getattr(recipe_scanner, "_cache", None) - if cache is not None: - cache.raw_data = [ - item for item in cache.raw_data if str(item.get("id", "")) != recipe_id - ] - asyncio.create_task(cache.resort()) - self._logger.info("Removed recipe %s from cache", recipe_id) - - return web.json_response({"success": True, "message": "Recipe deleted successfully"}) + result = await self._persistence_service.delete_recipe( + recipe_scanner=recipe_scanner, recipe_id=recipe_id + ) + return web.json_response(result.payload, status=result.status) + except RecipeNotFoundError as exc: + return web.json_response({"error": str(exc)}, status=404) except Exception as exc: self._logger.error("Error deleting recipe: %s", exc, exc_info=True) return web.json_response({"error": str(exc)}, status=500) @@ -700,25 +503,14 @@ class RecipeManagementHandler: recipe_id = request.match_info["recipe_id"] data = await request.json() - - if not any( - key in data for key in ("title", "tags", "source_path", "preview_nsfw_level") - ): - return web.json_response( - { - "error": ( - "At least one field to update must be provided (title or tags or " - "source_path or preview_nsfw_level)" - ) - }, - status=400, - ) - - success = await recipe_scanner.update_recipe_metadata(recipe_id, data) - if not success: - return web.json_response({"error": "Recipe not found or update failed"}, status=404) - - return web.json_response({"success": True, "recipe_id": recipe_id, "updates": data}) + result = await self._persistence_service.update_recipe( + recipe_scanner=recipe_scanner, recipe_id=recipe_id, updates=data + ) + return web.json_response(result.payload, status=result.status) + except RecipeValidationError as exc: + return web.json_response({"error": str(exc)}, status=400) + except RecipeNotFoundError as exc: + return web.json_response({"error": str(exc)}, status=404) except Exception as exc: self._logger.error("Error updating recipe: %s", exc, exc_info=True) return web.json_response({"error": str(exc)}, status=500) @@ -731,81 +523,21 @@ class RecipeManagementHandler: raise RuntimeError("Recipe scanner unavailable") data = await request.json() - required_fields = ["recipe_id", "lora_index", "target_name"] - for field in required_fields: + for field in ("recipe_id", "lora_index", "target_name"): if field not in data: - return web.json_response({"error": f"Missing required field: {field}"}, status=400) + raise RecipeValidationError(f"Missing required field: {field}") - recipe_id = data["recipe_id"] - lora_index = int(data["lora_index"]) - target_name = data["target_name"] - - recipe_path = os.path.join(recipe_scanner.recipes_dir, f"{recipe_id}.recipe.json") - if not os.path.exists(recipe_path): - return web.json_response({"error": "Recipe not found"}, status=404) - - lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) - target_lora = None if lora_scanner is None else await lora_scanner.get_model_info_by_name(target_name) - if not target_lora: - return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404) - - with open(recipe_path, "r", encoding="utf-8") as file_obj: - recipe_data = json.load(file_obj) - - loras = recipe_data.get("loras", []) - lora = loras[lora_index] if lora_index < len(loras) else None - if lora is None: - return web.json_response({"error": "LoRA index out of range in recipe"}, status=404) - - lora["isDeleted"] = False - lora["exclude"] = False - lora["file_name"] = target_name - if "sha256" in target_lora: - lora["hash"] = target_lora["sha256"].lower() - if target_lora.get("civitai"): - lora["modelName"] = target_lora["civitai"]["model"]["name"] - lora["modelVersionName"] = target_lora["civitai"]["name"] - lora["modelVersionId"] = target_lora["civitai"]["id"] - - from ...utils.utils import calculate_recipe_fingerprint - - recipe_data["fingerprint"] = calculate_recipe_fingerprint(recipe_data.get("loras", [])) - - with open(recipe_path, "w", encoding="utf-8") as file_obj: - json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) - - updated_lora = dict(lora) - updated_lora["inLibrary"] = True - updated_lora["preview_url"] = config.get_preview_static_url(target_lora["preview_url"]) - updated_lora["localPath"] = target_lora["file_path"] - - cache = getattr(recipe_scanner, "_cache", None) - if cache is not None: - for cache_item in cache.raw_data: - if cache_item.get("id") == recipe_id: - cache_item["loras"] = recipe_data["loras"] - cache_item["fingerprint"] = recipe_data["fingerprint"] - asyncio.create_task(cache.resort()) - break - - image_path = recipe_data.get("file_path") - if image_path and os.path.exists(image_path): - self._exif_utils.append_recipe_metadata(image_path, recipe_data) - - matching_recipes = [] - if "fingerprint" in recipe_data: - matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(recipe_data["fingerprint"]) - if recipe_id in matching_recipes: - matching_recipes.remove(recipe_id) - - return web.json_response( - { - "success": True, - "recipe_id": recipe_id, - "updated_lora": updated_lora, - "matching_recipes": matching_recipes, - } + result = await self._persistence_service.reconnect_lora( + recipe_scanner=recipe_scanner, + recipe_id=data["recipe_id"], + lora_index=int(data["lora_index"]), + target_name=data["target_name"], ) + return web.json_response(result.payload, status=result.status) + except RecipeValidationError as exc: + return web.json_response({"error": str(exc)}, status=400) + except RecipeNotFoundError as exc: + return web.json_response({"error": str(exc)}, status=404) except Exception as exc: self._logger.error("Error reconnecting LoRA: %s", exc, exc_info=True) return web.json_response({"error": str(exc)}, status=500) @@ -819,54 +551,14 @@ class RecipeManagementHandler: data = await request.json() recipe_ids = data.get("recipe_ids", []) - if not recipe_ids: - return web.json_response( - {"success": False, "error": "No recipe IDs provided"}, - status=400, - ) - - recipes_dir = recipe_scanner.recipes_dir - if not recipes_dir or not os.path.exists(recipes_dir): - return web.json_response( - {"success": False, "error": "Recipes directory not found"}, - status=404, - ) - - deleted_recipes: list[str] = [] - failed_recipes: list[Dict[str, Any]] = [] - - for recipe_id in recipe_ids: - recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json") - if not os.path.exists(recipe_json_path): - failed_recipes.append({"id": recipe_id, "reason": "Recipe not found"}) - continue - - try: - with open(recipe_json_path, "r", encoding="utf-8") as file_obj: - recipe_data = json.load(file_obj) - image_path = recipe_data.get("file_path") - os.remove(recipe_json_path) - if image_path and os.path.exists(image_path): - os.remove(image_path) - deleted_recipes.append(recipe_id) - except Exception as exc: - failed_recipes.append({"id": recipe_id, "reason": str(exc)}) - - cache = getattr(recipe_scanner, "_cache", None) - if deleted_recipes and cache is not None: - cache.raw_data = [item for item in cache.raw_data if item.get("id") not in deleted_recipes] - asyncio.create_task(cache.resort()) - self._logger.info("Removed %s recipes from cache", len(deleted_recipes)) - - return web.json_response( - { - "success": True, - "deleted": deleted_recipes, - "failed": failed_recipes, - "total_deleted": len(deleted_recipes), - "total_failed": len(failed_recipes), - } + result = await self._persistence_service.bulk_delete( + recipe_scanner=recipe_scanner, recipe_ids=recipe_ids ) + return web.json_response(result.payload, status=result.status) + except RecipeValidationError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except RecipeNotFoundError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=404) except Exception as exc: self._logger.error("Error performing bulk delete: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) @@ -878,164 +570,71 @@ class RecipeManagementHandler: if recipe_scanner is None: raise RuntimeError("Recipe scanner unavailable") - if self._metadata_collector is None or self._metadata_processor_cls is None: - return web.json_response({"error": "Metadata collection not available"}, status=400) - - raw_metadata = self._metadata_collector() - metadata_dict = self._metadata_processor_cls.to_dict(raw_metadata) - if not metadata_dict: - return web.json_response({"error": "No generation metadata found"}, status=400) - - if not self._standalone_mode and self._metadata_registry_cls is not None: - metadata_registry = self._metadata_registry_cls() - latest_image = metadata_registry.get_first_decoded_image() - else: - latest_image = None - - if latest_image is None: - return web.json_response( - {"error": "No recent images found to use for recipe. Try generating an image first."}, - status=400, - ) - - self._logger.debug("Image type: %s", type(latest_image)) - - try: - if isinstance(latest_image, tuple): - tensor_image = latest_image[0] if latest_image else None - if tensor_image is None: - return web.json_response({"error": "Empty image tuple received"}, status=400) - else: - tensor_image = latest_image - - if hasattr(tensor_image, "shape"): - shape_info = tensor_image.shape - self._logger.debug("Tensor shape: %s, dtype: %s", shape_info, tensor_image.dtype) - - import torch # type: ignore[import-not-found] - - if isinstance(tensor_image, torch.Tensor): - image_np = tensor_image.cpu().numpy() - else: - image_np = np.array(tensor_image) - - while len(image_np.shape) > 3: - image_np = image_np[0] - - if image_np.dtype in (np.float32, np.float64) and image_np.max() <= 1.0: - image_np = (image_np * 255).astype(np.uint8) - - if len(image_np.shape) == 3 and image_np.shape[2] == 3: - pil_image = Image.fromarray(image_np) - img_byte_arr = io.BytesIO() - pil_image.save(img_byte_arr, format="PNG") - image_bytes = img_byte_arr.getvalue() - else: - return web.json_response( - {"error": f"Cannot handle this data shape: {image_np.shape}, {image_np.dtype}"}, - status=400, - ) - except Exception as exc: - self._logger.error("Error processing image data: %s", exc, exc_info=True) - return web.json_response({"error": f"Error processing image: {exc}"}, status=400) - - lora_stack = metadata_dict.get("loras", "") - import re - - lora_matches = re.findall(r"]+)>", lora_stack) - if not lora_matches: - return web.json_response({"error": "No LoRAs found in the generation metadata"}, status=400) - - loras_for_name = lora_matches[:3] - recipe_name_parts = [] - for name, strength in loras_for_name: - recipe_name_parts.append(f"{name.strip()}-{float(strength):.2f}") - recipe_name = "_".join(recipe_name_parts) - - recipe_name = recipe_name or "recipe" - - recipes_dir = recipe_scanner.recipes_dir - os.makedirs(recipes_dir, exist_ok=True) - - import uuid - - recipe_id = str(uuid.uuid4()) - image_filename = f"{recipe_id}.png" - image_path = os.path.join(recipes_dir, image_filename) - with open(image_path, "wb") as file_obj: - file_obj.write(image_bytes) - - loras_data = [] - lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) - base_model_counts: Dict[str, int] = {} - - for name, strength in lora_matches: - lora_info = None - if lora_scanner is not None: - lora_info = await lora_scanner.get_model_info_by_name(name) - lora_data = { - "file_name": name, - "strength": float(strength), - "hash": (lora_info.get("sha256") or "").lower() if lora_info else "", - "modelVersionId": lora_info.get("civitai", {}).get("id") if lora_info else 0, - "modelName": lora_info.get("civitai", {}).get("model", {}).get("name") if lora_info else "", - "modelVersionName": lora_info.get("civitai", {}).get("name") if lora_info else "", - "isDeleted": False, - "exclude": False, - } - loras_data.append(lora_data) - - if lora_info and "base_model" in lora_info: - base_model = lora_info["base_model"] - base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 - - most_common_base_model = "" - if base_model_counts: - most_common_base_model = max(base_model_counts.items(), key=lambda item: item[1])[0] - - recipe_data = { - "id": recipe_id, - "file_path": image_path, - "title": recipe_name, - "modified": time.time(), - "created_date": time.time(), - "base_model": most_common_base_model, - "loras": loras_data, - "checkpoint": metadata_dict.get("checkpoint", ""), - "gen_params": { - key: value - for key, value in metadata_dict.items() - if key not in ["checkpoint", "loras"] - }, - "loras_stack": lora_stack, - } - - json_filename = f"{recipe_id}.recipe.json" - json_path = os.path.join(recipes_dir, json_filename) - with open(json_path, "w", encoding="utf-8") as file_obj: - json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) - - self._exif_utils.append_recipe_metadata(image_path, recipe_data) - - cache = getattr(recipe_scanner, "_cache", None) - if cache is not None: - cache.raw_data.append(recipe_data) - asyncio.create_task(cache.resort()) - self._logger.info("Added recipe %s to cache", recipe_id) - - return web.json_response( - { - "success": True, - "recipe_id": recipe_id, - "image_path": image_path, - "json_path": json_path, - "recipe_name": recipe_name, - } + analysis = await self._analysis_service.analyze_widget_metadata( + recipe_scanner=recipe_scanner ) + metadata = analysis.payload.get("metadata") + image_bytes = analysis.payload.get("image_bytes") + if not metadata or image_bytes is None: + raise RecipeValidationError("Unable to extract metadata from widget") + + result = await self._persistence_service.save_recipe_from_widget( + recipe_scanner=recipe_scanner, + metadata=metadata, + image_bytes=image_bytes, + ) + return web.json_response(result.payload, status=result.status) + except RecipeValidationError as exc: + return web.json_response({"error": str(exc)}, status=400) except Exception as exc: self._logger.error("Error saving recipe from widget: %s", exc, exc_info=True) return web.json_response({"error": str(exc)}, status=500) + async def _parse_save_payload(self, reader) -> dict[str, Any]: + image_bytes: Optional[bytes] = None + image_base64: Optional[str] = None + name: Optional[str] = None + tags: list[str] = [] + metadata: Optional[Dict[str, Any]] = None + + while True: + field = await reader.next() + if field is None: + break + if field.name == "image": + image_chunks = bytearray() + while True: + chunk = await field.read_chunk() + if not chunk: + break + image_chunks.extend(chunk) + image_bytes = bytes(image_chunks) + elif field.name == "image_base64": + image_base64 = await field.text() + elif field.name == "name": + name = await field.text() + elif field.name == "tags": + tags_text = await field.text() + try: + parsed_tags = json.loads(tags_text) + tags = parsed_tags if isinstance(parsed_tags, list) else [] + except Exception: + tags = [] + elif field.name == "metadata": + metadata_text = await field.text() + try: + metadata = json.loads(metadata_text) + except Exception: + metadata = {} + + return { + "image_bytes": image_bytes, + "image_base64": image_base64, + "name": name, + "tags": tags, + "metadata": metadata, + } + class RecipeAnalysisHandler: """Analyze images to extract recipe metadata.""" @@ -1047,20 +646,15 @@ class RecipeAnalysisHandler: recipe_scanner_getter: RecipeScannerGetter, civitai_client_getter: CivitaiClientGetter, logger: Logger, - exif_utils=ExifUtils, - recipe_parser_factory=RecipeParserFactory, - downloader_factory=get_downloader, + analysis_service: RecipeAnalysisService, ) -> None: self._ensure_dependencies_ready = ensure_dependencies_ready self._recipe_scanner_getter = recipe_scanner_getter self._civitai_client_getter = civitai_client_getter self._logger = logger - self._exif_utils = exif_utils - self._recipe_parser_factory = recipe_parser_factory - self._downloader_factory = downloader_factory + self._analysis_service = analysis_service async def analyze_uploaded_image(self, request: web.Request) -> web.Response: - temp_path: Optional[str] = None try: await self._ensure_dependencies_ready() recipe_scanner = self._recipe_scanner_getter() @@ -1069,112 +663,42 @@ class RecipeAnalysisHandler: raise RuntimeError("Required services unavailable") content_type = request.headers.get("Content-Type", "") - is_url_mode = False - metadata: Optional[Dict[str, Any]] = None - if "multipart/form-data" in content_type: reader = await request.multipart() field = await reader.next() if field is None or field.name != "image": - return web.json_response({"error": "No image field found", "loras": []}, status=400) + raise RecipeValidationError("No image field found") + image_chunks = bytearray() + while True: + chunk = await field.read_chunk() + if not chunk: + break + image_chunks.extend(chunk) + result = await self._analysis_service.analyze_uploaded_image( + image_bytes=bytes(image_chunks), + recipe_scanner=recipe_scanner, + ) + return web.json_response(result.payload, status=result.status) - with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file: - while True: - chunk = await field.read_chunk() - if not chunk: - break - temp_file.write(chunk) - temp_path = temp_file.name - elif "application/json" in content_type: + if "application/json" in content_type: data = await request.json() - url = data.get("url") - is_url_mode = True - if not url: - return web.json_response({"error": "No URL provided", "loras": []}, status=400) + result = await self._analysis_service.analyze_remote_image( + url=data.get("url"), + recipe_scanner=recipe_scanner, + civitai_client=civitai_client, + ) + return web.json_response(result.payload, status=result.status) - import re - - civitai_image_match = re.match(r"https://civitai\.com/images/(\d+)", url) - if civitai_image_match: - image_id = civitai_image_match.group(1) - image_info = await civitai_client.get_image_info(image_id) - if not image_info: - return web.json_response( - {"error": "Failed to fetch image information from Civitai", "loras": []}, - status=400, - ) - image_url = image_info.get("url") - if not image_url: - return web.json_response( - {"error": "No image URL found in Civitai response", "loras": []}, - status=400, - ) - - downloader = await self._downloader_factory() - with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file: - temp_path = temp_file.name - - success, result = await downloader.download_file( - image_url, - temp_path, - use_auth=False, - ) - if not success: - return web.json_response( - {"error": f"Failed to download image from URL: {result}", "loras": []}, - status=400, - ) - metadata = image_info.get("meta") if "meta" in image_info else None - else: - return web.json_response({"error": "Unsupported content type", "loras": []}, status=400) - - if metadata is None and temp_path: - metadata = self._exif_utils.extract_image_metadata(temp_path) - - if not metadata: - response: Dict[str, Any] = {"error": "No metadata found in this image", "loras": []} - if is_url_mode and temp_path: - with open(temp_path, "rb") as image_file: - response["image_base64"] = base64.b64encode(image_file.read()).decode("utf-8") - return web.json_response(response, status=200) - - parser = self._recipe_parser_factory.create_parser(metadata) - if parser is None: - response = {"error": "No parser found for this image", "loras": []} - if is_url_mode and temp_path: - with open(temp_path, "rb") as image_file: - response["image_base64"] = base64.b64encode(image_file.read()).decode("utf-8") - return web.json_response(response, status=200) - - result = await parser.parse_metadata(metadata, recipe_scanner=recipe_scanner) - - if is_url_mode and temp_path: - with open(temp_path, "rb") as image_file: - result["image_base64"] = base64.b64encode(image_file.read()).decode("utf-8") - - if "error" in result and not result.get("loras"): - return web.json_response(result, status=200) - - from ...utils.utils import calculate_recipe_fingerprint - - fingerprint = calculate_recipe_fingerprint(result.get("loras", [])) - result["fingerprint"] = fingerprint - - matching_recipes = [] - if fingerprint: - matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(fingerprint) - - result["matching_recipes"] = matching_recipes - return web.json_response(result) + raise RecipeValidationError("Unsupported content type") + except RecipeValidationError as exc: + return web.json_response({"error": str(exc), "loras": []}, status=400) + except RecipeDownloadError as exc: + return web.json_response({"error": str(exc), "loras": []}, status=400) + except RecipeNotFoundError as exc: + return web.json_response({"error": str(exc), "loras": []}, status=404) except Exception as exc: self._logger.error("Error analyzing recipe image: %s", exc, exc_info=True) return web.json_response({"error": str(exc), "loras": []}, status=500) - finally: - if temp_path and os.path.exists(temp_path): - try: - os.unlink(temp_path) - except Exception as cleanup_exc: # pragma: no cover - logging path - self._logger.error("Error deleting temporary file: %s", cleanup_exc) async def analyze_local_image(self, request: web.Request) -> web.Response: try: @@ -1184,50 +708,15 @@ class RecipeAnalysisHandler: raise RuntimeError("Recipe scanner unavailable") data = await request.json() - file_path = data.get("path") - if not file_path: - return web.json_response({"error": "No file path provided", "loras": []}, status=400) - - file_path = os.path.normpath(file_path.strip('"').strip("'")) - if not os.path.isfile(file_path): - return web.json_response({"error": "File not found", "loras": []}, status=404) - - metadata = self._exif_utils.extract_image_metadata(file_path) - if not metadata: - with open(file_path, "rb") as image_file: - image_base64 = base64.b64encode(image_file.read()).decode("utf-8") - return web.json_response( - {"error": "No metadata found in this image", "loras": [], "image_base64": image_base64}, - status=200, - ) - - parser = self._recipe_parser_factory.create_parser(metadata) - if parser is None: - with open(file_path, "rb") as image_file: - image_base64 = base64.b64encode(image_file.read()).decode("utf-8") - return web.json_response( - {"error": "No parser found for this image", "loras": [], "image_base64": image_base64}, - status=200, - ) - - result = await parser.parse_metadata(metadata, recipe_scanner=recipe_scanner) - with open(file_path, "rb") as image_file: - result["image_base64"] = base64.b64encode(image_file.read()).decode("utf-8") - - if "error" in result and not result.get("loras"): - return web.json_response(result, status=200) - - from ...utils.utils import calculate_recipe_fingerprint - - fingerprint = calculate_recipe_fingerprint(result.get("loras", [])) - result["fingerprint"] = fingerprint - - matching_recipes = [] - if fingerprint: - matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(fingerprint) - result["matching_recipes"] = matching_recipes - - return web.json_response(result) + result = await self._analysis_service.analyze_local_image( + file_path=data.get("path"), + recipe_scanner=recipe_scanner, + ) + return web.json_response(result.payload, status=result.status) + except RecipeValidationError as exc: + return web.json_response({"error": str(exc), "loras": []}, status=400) + except RecipeNotFoundError as exc: + return web.json_response({"error": str(exc), "loras": []}, status=404) except Exception as exc: self._logger.error("Error analyzing local image: %s", exc, exc_info=True) return web.json_response({"error": str(exc), "loras": []}, status=500) @@ -1242,11 +731,12 @@ class RecipeSharingHandler: ensure_dependencies_ready: EnsureDependenciesCallable, recipe_scanner_getter: RecipeScannerGetter, logger: Logger, + sharing_service: RecipeSharingService, ) -> None: self._ensure_dependencies_ready = ensure_dependencies_ready self._recipe_scanner_getter = recipe_scanner_getter self._logger = logger - self._shared_recipes: Dict[str, Dict[str, Any]] = {} + self._sharing_service = sharing_service async def share_recipe(self, request: web.Request) -> web.Response: try: @@ -1256,42 +746,17 @@ class RecipeSharingHandler: raise RuntimeError("Recipe scanner unavailable") recipe_id = request.match_info["recipe_id"] - cache = await recipe_scanner.get_cached_data() - recipe = next( - (r for r in getattr(cache, "raw_data", []) if str(r.get("id", "")) == recipe_id), - None, + result = await self._sharing_service.share_recipe( + recipe_scanner=recipe_scanner, recipe_id=recipe_id ) - if not recipe: - return web.json_response({"error": "Recipe not found"}, status=404) - - image_path = recipe.get("file_path") - if not image_path or not os.path.exists(image_path): - return web.json_response({"error": "Recipe image not found"}, status=404) - - import shutil - - ext = os.path.splitext(image_path)[1] - with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as temp_file: - temp_path = temp_file.name - shutil.copy2(image_path, temp_path) - processed_path = temp_path - - timestamp = int(time.time()) - url_path = f"/api/recipe/{recipe_id}/share/download?t={timestamp}" - self._shared_recipes[recipe_id] = { - "path": processed_path, - "timestamp": timestamp, - "expires": time.time() + 300, - } - self._cleanup_shared_recipes() - - filename = f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}{ext}" - return web.json_response({"success": True, "download_url": url_path, "filename": filename}) + return web.json_response(result.payload, status=result.status) + except RecipeNotFoundError as exc: + return web.json_response({"error": str(exc)}, status=404) except Exception as exc: self._logger.error("Error sharing recipe: %s", exc, exc_info=True) return web.json_response({"error": str(exc)}, status=500) - async def download_shared_recipe(self, request: web.Request) -> web.Response: + async def download_shared_recipe(self, request: web.Request) -> web.StreamResponse: try: await self._ensure_dependencies_ready() recipe_scanner = self._recipe_scanner_getter() @@ -1299,49 +764,17 @@ class RecipeSharingHandler: raise RuntimeError("Recipe scanner unavailable") recipe_id = request.match_info["recipe_id"] - shared_info = self._shared_recipes.get(recipe_id) - if not shared_info: - return web.json_response({"error": "Shared recipe not found or expired"}, status=404) - - file_path = shared_info["path"] - if not os.path.exists(file_path): - return web.json_response({"error": "Shared recipe file not found"}, status=404) - - cache = await recipe_scanner.get_cached_data() - recipe = next( - (r for r in getattr(cache, "raw_data", []) if str(r.get("id", "")) == recipe_id), - None, + download_info = await self._sharing_service.prepare_download( + recipe_scanner=recipe_scanner, recipe_id=recipe_id ) - filename_base = ( - f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}" - if recipe - else recipe_id - ) - ext = os.path.splitext(file_path)[1] - download_filename = f"{filename_base}{ext}" - return web.FileResponse( - file_path, - headers={"Content-Disposition": f'attachment; filename="{download_filename}"'}, + download_info.file_path, + headers={ + "Content-Disposition": f'attachment; filename="{download_info.download_filename}"' + }, ) + except RecipeNotFoundError as exc: + return web.json_response({"error": str(exc)}, status=404) except Exception as exc: self._logger.error("Error downloading shared recipe: %s", exc, exc_info=True) return web.json_response({"error": str(exc)}, status=500) - - def _cleanup_shared_recipes(self) -> None: - current_time = time.time() - expired_ids = [ - recipe_id - for recipe_id, info in self._shared_recipes.items() - if current_time > info.get("expires", 0) - ] - - for recipe_id in expired_ids: - try: - file_path = self._shared_recipes[recipe_id]["path"] - if os.path.exists(file_path): - os.unlink(file_path) - except Exception as exc: # pragma: no cover - logging path - self._logger.error("Error cleaning up shared recipe %s: %s", recipe_id, exc) - finally: - self._shared_recipes.pop(recipe_id, None) diff --git a/py/services/recipes/__init__.py b/py/services/recipes/__init__.py new file mode 100644 index 00000000..8009b7c3 --- /dev/null +++ b/py/services/recipes/__init__.py @@ -0,0 +1,23 @@ +"""Recipe service layer implementations.""" + +from .analysis_service import RecipeAnalysisService +from .persistence_service import RecipePersistenceService +from .sharing_service import RecipeSharingService +from .errors import ( + RecipeServiceError, + RecipeValidationError, + RecipeNotFoundError, + RecipeDownloadError, + RecipeConflictError, +) + +__all__ = [ + "RecipeAnalysisService", + "RecipePersistenceService", + "RecipeSharingService", + "RecipeServiceError", + "RecipeValidationError", + "RecipeNotFoundError", + "RecipeDownloadError", + "RecipeConflictError", +] diff --git a/py/services/recipes/analysis_service.py b/py/services/recipes/analysis_service.py new file mode 100644 index 00000000..77d80e34 --- /dev/null +++ b/py/services/recipes/analysis_service.py @@ -0,0 +1,289 @@ +"""Services responsible for recipe metadata analysis.""" +from __future__ import annotations + +import base64 +import io +import os +import re +import tempfile +from dataclasses import dataclass +from typing import Any, Callable, Optional + +import numpy as np +from PIL import Image + +from ...utils.utils import calculate_recipe_fingerprint +from .errors import ( + RecipeDownloadError, + RecipeNotFoundError, + RecipeServiceError, + RecipeValidationError, +) + + +@dataclass(frozen=True) +class AnalysisResult: + """Return payload from analysis operations.""" + + payload: dict[str, Any] + status: int = 200 + + +class RecipeAnalysisService: + """Extract recipe metadata from various image sources.""" + + def __init__( + self, + *, + exif_utils, + recipe_parser_factory, + downloader_factory: Callable[[], Any], + metadata_collector: Optional[Callable[[], Any]] = None, + metadata_processor_cls: Optional[type] = None, + metadata_registry_cls: Optional[type] = None, + standalone_mode: bool = False, + logger, + ) -> None: + self._exif_utils = exif_utils + self._recipe_parser_factory = recipe_parser_factory + self._downloader_factory = downloader_factory + self._metadata_collector = metadata_collector + self._metadata_processor_cls = metadata_processor_cls + self._metadata_registry_cls = metadata_registry_cls + self._standalone_mode = standalone_mode + self._logger = logger + + async def analyze_uploaded_image( + self, + *, + image_bytes: bytes | None, + recipe_scanner, + ) -> AnalysisResult: + """Analyze an uploaded image payload.""" + + if not image_bytes: + raise RecipeValidationError("No image data provided") + + temp_path = self._write_temp_file(image_bytes) + try: + metadata = self._exif_utils.extract_image_metadata(temp_path) + if not metadata: + return AnalysisResult({"error": "No metadata found in this image", "loras": []}) + + return await self._parse_metadata( + metadata, + recipe_scanner=recipe_scanner, + image_path=None, + include_image_base64=False, + ) + finally: + self._safe_cleanup(temp_path) + + async def analyze_remote_image( + self, + *, + url: str | None, + recipe_scanner, + civitai_client, + ) -> AnalysisResult: + """Analyze an image accessible via URL, including Civitai integration.""" + + if not url: + raise RecipeValidationError("No URL provided") + + if civitai_client is None: + raise RecipeServiceError("Civitai client unavailable") + + temp_path = self._create_temp_path() + metadata: Optional[dict[str, Any]] = None + try: + civitai_match = re.match(r"https://civitai\.com/images/(\d+)", url) + if civitai_match: + image_info = await civitai_client.get_image_info(civitai_match.group(1)) + if not image_info: + raise RecipeDownloadError("Failed to fetch image information from Civitai") + image_url = image_info.get("url") + if not image_url: + raise RecipeDownloadError("No image URL found in Civitai response") + await self._download_image(image_url, temp_path) + metadata = image_info.get("meta") if "meta" in image_info else None + else: + await self._download_image(url, temp_path) + + if metadata is None: + metadata = self._exif_utils.extract_image_metadata(temp_path) + + if not metadata: + return self._metadata_not_found_response(temp_path) + + return await self._parse_metadata( + metadata, + recipe_scanner=recipe_scanner, + image_path=temp_path, + include_image_base64=True, + ) + finally: + self._safe_cleanup(temp_path) + + async def analyze_local_image( + self, + *, + file_path: str | None, + recipe_scanner, + ) -> AnalysisResult: + """Analyze a file already present on disk.""" + + if not file_path: + raise RecipeValidationError("No file path provided") + + normalized_path = os.path.normpath(file_path.strip('"').strip("'")) + if not os.path.isfile(normalized_path): + raise RecipeNotFoundError("File not found") + + metadata = self._exif_utils.extract_image_metadata(normalized_path) + if not metadata: + return self._metadata_not_found_response(normalized_path) + + return await self._parse_metadata( + metadata, + recipe_scanner=recipe_scanner, + image_path=normalized_path, + include_image_base64=True, + ) + + async def analyze_widget_metadata(self, *, recipe_scanner) -> AnalysisResult: + """Analyse the most recent generation metadata for widget saves.""" + + if self._metadata_collector is None or self._metadata_processor_cls is None: + raise RecipeValidationError("Metadata collection not available") + + raw_metadata = self._metadata_collector() + metadata_dict = self._metadata_processor_cls.to_dict(raw_metadata) + if not metadata_dict: + raise RecipeValidationError("No generation metadata found") + + latest_image = None + if not self._standalone_mode and self._metadata_registry_cls is not None: + metadata_registry = self._metadata_registry_cls() + latest_image = metadata_registry.get_first_decoded_image() + + if latest_image is None: + raise RecipeValidationError( + "No recent images found to use for recipe. Try generating an image first." + ) + + image_bytes = self._convert_tensor_to_png_bytes(latest_image) + if image_bytes is None: + raise RecipeValidationError("Cannot handle this data shape from metadata registry") + + return AnalysisResult( + { + "metadata": metadata_dict, + "image_bytes": image_bytes, + } + ) + + # Internal helpers ------------------------------------------------- + + async def _parse_metadata( + self, + metadata: dict[str, Any], + *, + recipe_scanner, + image_path: Optional[str], + include_image_base64: bool, + ) -> AnalysisResult: + parser = self._recipe_parser_factory.create_parser(metadata) + if parser is None: + payload = {"error": "No parser found for this image", "loras": []} + if include_image_base64 and image_path: + payload["image_base64"] = self._encode_file(image_path) + return AnalysisResult(payload) + + result = await parser.parse_metadata(metadata, recipe_scanner=recipe_scanner) + + if include_image_base64 and image_path: + result["image_base64"] = self._encode_file(image_path) + + if "error" in result and not result.get("loras"): + return AnalysisResult(result) + + fingerprint = calculate_recipe_fingerprint(result.get("loras", [])) + result["fingerprint"] = fingerprint + + matching_recipes: list[str] = [] + if fingerprint: + matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(fingerprint) + result["matching_recipes"] = matching_recipes + + return AnalysisResult(result) + + async def _download_image(self, url: str, temp_path: str) -> None: + downloader = await self._downloader_factory() + success, result = await downloader.download_file(url, temp_path, use_auth=False) + if not success: + raise RecipeDownloadError(f"Failed to download image from URL: {result}") + + def _metadata_not_found_response(self, path: str) -> AnalysisResult: + payload: dict[str, Any] = {"error": "No metadata found in this image", "loras": []} + if os.path.exists(path): + payload["image_base64"] = self._encode_file(path) + return AnalysisResult(payload) + + def _write_temp_file(self, data: bytes) -> str: + with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file: + temp_file.write(data) + return temp_file.name + + def _create_temp_path(self) -> str: + with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file: + return temp_file.name + + def _safe_cleanup(self, path: Optional[str]) -> None: + if path and os.path.exists(path): + try: + os.unlink(path) + except Exception as exc: # pragma: no cover - defensive logging + self._logger.error("Error deleting temporary file: %s", exc) + + def _encode_file(self, path: str) -> str: + with open(path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + def _convert_tensor_to_png_bytes(self, latest_image: Any) -> Optional[bytes]: + try: + if isinstance(latest_image, tuple): + tensor_image = latest_image[0] if latest_image else None + if tensor_image is None: + return None + else: + tensor_image = latest_image + + if hasattr(tensor_image, "shape"): + self._logger.debug( + "Tensor shape: %s, dtype: %s", tensor_image.shape, getattr(tensor_image, "dtype", None) + ) + + import torch # type: ignore[import-not-found] + + if isinstance(tensor_image, torch.Tensor): + image_np = tensor_image.cpu().numpy() + else: + image_np = np.array(tensor_image) + + while len(image_np.shape) > 3: + image_np = image_np[0] + + if image_np.dtype in (np.float32, np.float64) and image_np.max() <= 1.0: + image_np = (image_np * 255).astype(np.uint8) + + if len(image_np.shape) == 3 and image_np.shape[2] == 3: + pil_image = Image.fromarray(image_np) + img_byte_arr = io.BytesIO() + pil_image.save(img_byte_arr, format="PNG") + return img_byte_arr.getvalue() + except Exception as exc: # pragma: no cover - defensive logging path + self._logger.error("Error processing image data: %s", exc, exc_info=True) + return None + + return None diff --git a/py/services/recipes/errors.py b/py/services/recipes/errors.py new file mode 100644 index 00000000..9e5d9720 --- /dev/null +++ b/py/services/recipes/errors.py @@ -0,0 +1,22 @@ +"""Shared exceptions for recipe services.""" +from __future__ import annotations + + +class RecipeServiceError(Exception): + """Base exception for recipe service failures.""" + + +class RecipeValidationError(RecipeServiceError): + """Raised when a request payload fails validation.""" + + +class RecipeNotFoundError(RecipeServiceError): + """Raised when a recipe resource cannot be located.""" + + +class RecipeDownloadError(RecipeServiceError): + """Raised when remote recipe assets cannot be downloaded.""" + + +class RecipeConflictError(RecipeServiceError): + """Raised when a conflicting recipe state is detected.""" diff --git a/py/services/recipes/persistence_service.py b/py/services/recipes/persistence_service.py new file mode 100644 index 00000000..945680df --- /dev/null +++ b/py/services/recipes/persistence_service.py @@ -0,0 +1,467 @@ +"""Services encapsulating recipe persistence workflows.""" +from __future__ import annotations + +import asyncio +import base64 +import json +import os +import re +import time +import uuid +from dataclasses import dataclass +from typing import Any, Dict, Iterable, Optional + +from ...config import config +from ...utils.utils import calculate_recipe_fingerprint +from .errors import RecipeNotFoundError, RecipeValidationError + + +@dataclass(frozen=True) +class PersistenceResult: + """Return payload from persistence operations.""" + + payload: dict[str, Any] + status: int = 200 + + +class RecipePersistenceService: + """Coordinate recipe persistence tasks across storage and caches.""" + + def __init__( + self, + *, + exif_utils, + card_preview_width: int, + logger, + ) -> None: + self._exif_utils = exif_utils + self._card_preview_width = card_preview_width + self._logger = logger + + async def save_recipe( + self, + *, + recipe_scanner, + image_bytes: bytes | None, + image_base64: str | None, + name: str | None, + tags: Iterable[str], + metadata: Optional[dict[str, Any]], + ) -> PersistenceResult: + """Persist a user uploaded recipe.""" + + missing_fields = [] + if not name: + missing_fields.append("name") + if metadata is None: + missing_fields.append("metadata") + if missing_fields: + raise RecipeValidationError( + f"Missing required fields: {', '.join(missing_fields)}" + ) + + resolved_image_bytes = self._resolve_image_bytes(image_bytes, image_base64) + recipes_dir = recipe_scanner.recipes_dir + os.makedirs(recipes_dir, exist_ok=True) + + recipe_id = str(uuid.uuid4()) + optimized_image, extension = self._exif_utils.optimize_image( + image_data=resolved_image_bytes, + target_width=self._card_preview_width, + format="webp", + quality=85, + preserve_metadata=True, + ) + image_filename = f"{recipe_id}{extension}" + image_path = os.path.join(recipes_dir, image_filename) + with open(image_path, "wb") as file_obj: + file_obj.write(optimized_image) + + current_time = time.time() + loras_data = [self._normalise_lora_entry(lora) for lora in metadata.get("loras", [])] + + gen_params = metadata.get("gen_params", {}) + if not gen_params and "raw_metadata" in metadata: + raw_metadata = metadata.get("raw_metadata", {}) + gen_params = { + "prompt": raw_metadata.get("prompt", ""), + "negative_prompt": raw_metadata.get("negative_prompt", ""), + "checkpoint": raw_metadata.get("checkpoint", {}), + "steps": raw_metadata.get("steps", ""), + "sampler": raw_metadata.get("sampler", ""), + "cfg_scale": raw_metadata.get("cfg_scale", ""), + "seed": raw_metadata.get("seed", ""), + "size": raw_metadata.get("size", ""), + "clip_skip": raw_metadata.get("clip_skip", ""), + } + + fingerprint = calculate_recipe_fingerprint(loras_data) + recipe_data: Dict[str, Any] = { + "id": recipe_id, + "file_path": image_path, + "title": name, + "modified": current_time, + "created_date": current_time, + "base_model": metadata.get("base_model", ""), + "loras": loras_data, + "gen_params": gen_params, + "fingerprint": fingerprint, + } + + tags_list = list(tags) + if tags_list: + recipe_data["tags"] = tags_list + + if metadata.get("source_path"): + recipe_data["source_path"] = metadata.get("source_path") + + json_filename = f"{recipe_id}.recipe.json" + json_path = os.path.join(recipes_dir, json_filename) + with open(json_path, "w", encoding="utf-8") as file_obj: + json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) + + self._exif_utils.append_recipe_metadata(image_path, recipe_data) + + matching_recipes = await self._find_matching_recipes(recipe_scanner, fingerprint, exclude_id=recipe_id) + await self._update_cache(recipe_scanner, recipe_data) + + return PersistenceResult( + { + "success": True, + "recipe_id": recipe_id, + "image_path": image_path, + "json_path": json_path, + "matching_recipes": matching_recipes, + } + ) + + async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult: + """Delete an existing recipe.""" + + recipes_dir = recipe_scanner.recipes_dir + if not recipes_dir or not os.path.exists(recipes_dir): + raise RecipeNotFoundError("Recipes directory not found") + + recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json") + if not os.path.exists(recipe_json_path): + raise RecipeNotFoundError("Recipe not found") + + with open(recipe_json_path, "r", encoding="utf-8") as file_obj: + recipe_data = json.load(file_obj) + + image_path = recipe_data.get("file_path") + os.remove(recipe_json_path) + if image_path and os.path.exists(image_path): + os.remove(image_path) + + await self._remove_from_cache(recipe_scanner, recipe_id) + return PersistenceResult({"success": True, "message": "Recipe deleted successfully"}) + + async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: dict[str, Any]) -> PersistenceResult: + """Update persisted metadata for a recipe.""" + + if not any(key in updates for key in ("title", "tags", "source_path", "preview_nsfw_level")): + raise RecipeValidationError( + "At least one field to update must be provided (title or tags or source_path or preview_nsfw_level)" + ) + + success = await recipe_scanner.update_recipe_metadata(recipe_id, updates) + if not success: + raise RecipeNotFoundError("Recipe not found or update failed") + + return PersistenceResult({"success": True, "recipe_id": recipe_id, "updates": updates}) + + async def reconnect_lora( + self, + *, + recipe_scanner, + recipe_id: str, + lora_index: int, + target_name: str, + ) -> PersistenceResult: + """Reconnect a LoRA entry within an existing recipe.""" + + recipe_path = os.path.join(recipe_scanner.recipes_dir, f"{recipe_id}.recipe.json") + if not os.path.exists(recipe_path): + raise RecipeNotFoundError("Recipe not found") + + lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) + target_lora = None if lora_scanner is None else await lora_scanner.get_model_info_by_name(target_name) + if not target_lora: + raise RecipeNotFoundError(f"Local LoRA not found with name: {target_name}") + + with open(recipe_path, "r", encoding="utf-8") as file_obj: + recipe_data = json.load(file_obj) + + loras = recipe_data.get("loras", []) + if lora_index >= len(loras): + raise RecipeNotFoundError("LoRA index out of range in recipe") + + lora = loras[lora_index] + lora["isDeleted"] = False + lora["exclude"] = False + lora["file_name"] = target_name + if "sha256" in target_lora: + lora["hash"] = target_lora["sha256"].lower() + if target_lora.get("civitai"): + lora["modelName"] = target_lora["civitai"]["model"]["name"] + lora["modelVersionName"] = target_lora["civitai"]["name"] + lora["modelVersionId"] = target_lora["civitai"]["id"] + + recipe_data["fingerprint"] = calculate_recipe_fingerprint(recipe_data.get("loras", [])) + + with open(recipe_path, "w", encoding="utf-8") as file_obj: + json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) + + updated_lora = dict(lora) + updated_lora["inLibrary"] = True + updated_lora["preview_url"] = config.get_preview_static_url(target_lora["preview_url"]) + updated_lora["localPath"] = target_lora["file_path"] + + await self._refresh_cache_after_update(recipe_scanner, recipe_id, recipe_data) + + image_path = recipe_data.get("file_path") + if image_path and os.path.exists(image_path): + self._exif_utils.append_recipe_metadata(image_path, recipe_data) + + matching_recipes = [] + if "fingerprint" in recipe_data: + matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(recipe_data["fingerprint"]) + if recipe_id in matching_recipes: + matching_recipes.remove(recipe_id) + + return PersistenceResult( + { + "success": True, + "recipe_id": recipe_id, + "updated_lora": updated_lora, + "matching_recipes": matching_recipes, + } + ) + + async def bulk_delete( + self, + *, + recipe_scanner, + recipe_ids: Iterable[str], + ) -> PersistenceResult: + """Delete multiple recipes in a single request.""" + + recipe_ids = list(recipe_ids) + if not recipe_ids: + raise RecipeValidationError("No recipe IDs provided") + + recipes_dir = recipe_scanner.recipes_dir + if not recipes_dir or not os.path.exists(recipes_dir): + raise RecipeNotFoundError("Recipes directory not found") + + deleted_recipes: list[str] = [] + failed_recipes: list[dict[str, Any]] = [] + + for recipe_id in recipe_ids: + recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json") + if not os.path.exists(recipe_json_path): + failed_recipes.append({"id": recipe_id, "reason": "Recipe not found"}) + continue + + try: + with open(recipe_json_path, "r", encoding="utf-8") as file_obj: + recipe_data = json.load(file_obj) + image_path = recipe_data.get("file_path") + os.remove(recipe_json_path) + if image_path and os.path.exists(image_path): + os.remove(image_path) + deleted_recipes.append(recipe_id) + except Exception as exc: + failed_recipes.append({"id": recipe_id, "reason": str(exc)}) + + if deleted_recipes: + await self._bulk_remove_from_cache(recipe_scanner, deleted_recipes) + + return PersistenceResult( + { + "success": True, + "deleted": deleted_recipes, + "failed": failed_recipes, + "total_deleted": len(deleted_recipes), + "total_failed": len(failed_recipes), + } + ) + + async def save_recipe_from_widget( + self, + *, + recipe_scanner, + metadata: dict[str, Any], + image_bytes: bytes, + ) -> PersistenceResult: + """Save a recipe constructed from widget metadata.""" + + if not metadata: + raise RecipeValidationError("No generation metadata found") + + recipes_dir = recipe_scanner.recipes_dir + os.makedirs(recipes_dir, exist_ok=True) + + recipe_id = str(uuid.uuid4()) + image_filename = f"{recipe_id}.png" + image_path = os.path.join(recipes_dir, image_filename) + with open(image_path, "wb") as file_obj: + file_obj.write(image_bytes) + + lora_stack = metadata.get("loras", "") + lora_matches = re.findall(r"]+)>", lora_stack) + if not lora_matches: + raise RecipeValidationError("No LoRAs found in the generation metadata") + + lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) + loras_data = [] + base_model_counts: Dict[str, int] = {} + + for name, strength in lora_matches: + lora_info = None + if lora_scanner is not None: + lora_info = await lora_scanner.get_model_info_by_name(name) + lora_data = { + "file_name": name, + "strength": float(strength), + "hash": (lora_info.get("sha256") or "").lower() if lora_info else "", + "modelVersionId": lora_info.get("civitai", {}).get("id") if lora_info else 0, + "modelName": lora_info.get("civitai", {}).get("model", {}).get("name") if lora_info else "", + "modelVersionName": lora_info.get("civitai", {}).get("name") if lora_info else "", + "isDeleted": False, + "exclude": False, + } + loras_data.append(lora_data) + + if lora_info and "base_model" in lora_info: + base_model = lora_info["base_model"] + base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 + + recipe_name = self._derive_recipe_name(lora_matches) + most_common_base_model = ( + max(base_model_counts.items(), key=lambda item: item[1])[0] if base_model_counts else "" + ) + + recipe_data = { + "id": recipe_id, + "file_path": image_path, + "title": recipe_name, + "modified": time.time(), + "created_date": time.time(), + "base_model": most_common_base_model, + "loras": loras_data, + "checkpoint": metadata.get("checkpoint", ""), + "gen_params": { + key: value + for key, value in metadata.items() + if key not in ["checkpoint", "loras"] + }, + "loras_stack": lora_stack, + } + + json_filename = f"{recipe_id}.recipe.json" + json_path = os.path.join(recipes_dir, json_filename) + with open(json_path, "w", encoding="utf-8") as file_obj: + json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) + + self._exif_utils.append_recipe_metadata(image_path, recipe_data) + await self._update_cache(recipe_scanner, recipe_data) + + return PersistenceResult( + { + "success": True, + "recipe_id": recipe_id, + "image_path": image_path, + "json_path": json_path, + "recipe_name": recipe_name, + } + ) + + # Helper methods --------------------------------------------------- + + def _resolve_image_bytes(self, image_bytes: bytes | None, image_base64: str | None) -> bytes: + if image_bytes is not None: + return image_bytes + if image_base64: + try: + payload = image_base64.split(",", 1)[1] if "," in image_base64 else image_base64 + return base64.b64decode(payload) + except Exception as exc: # pragma: no cover - validation guard + raise RecipeValidationError(f"Invalid base64 image data: {exc}") from exc + raise RecipeValidationError("No image data provided") + + def _normalise_lora_entry(self, lora: dict[str, Any]) -> dict[str, Any]: + return { + "file_name": lora.get("file_name", "") + or ( + os.path.splitext(os.path.basename(lora.get("localPath", "")))[0] + if lora.get("localPath") + else "" + ), + "hash": (lora.get("hash") or "").lower(), + "strength": float(lora.get("weight", 1.0)), + "modelVersionId": lora.get("id", 0), + "modelName": lora.get("name", ""), + "modelVersionName": lora.get("version", ""), + "isDeleted": lora.get("isDeleted", False), + "exclude": lora.get("exclude", False), + } + + async def _find_matching_recipes( + self, + recipe_scanner, + fingerprint: str | None, + *, + exclude_id: Optional[str] = None, + ) -> list[str]: + if not fingerprint: + return [] + matches = await recipe_scanner.find_recipes_by_fingerprint(fingerprint) + if exclude_id and exclude_id in matches: + matches.remove(exclude_id) + return matches + + async def _update_cache(self, recipe_scanner, recipe_data: dict[str, Any]) -> None: + cache = getattr(recipe_scanner, "_cache", None) + if cache is not None: + cache.raw_data.append(recipe_data) + asyncio.create_task(cache.resort()) + self._logger.info("Added recipe %s to cache", recipe_data.get("id")) + + async def _remove_from_cache(self, recipe_scanner, recipe_id: str) -> None: + cache = getattr(recipe_scanner, "_cache", None) + if cache is not None: + cache.raw_data = [item for item in cache.raw_data if str(item.get("id", "")) != recipe_id] + asyncio.create_task(cache.resort()) + self._logger.info("Removed recipe %s from cache", recipe_id) + + async def _bulk_remove_from_cache(self, recipe_scanner, recipe_ids: Iterable[str]) -> None: + cache = getattr(recipe_scanner, "_cache", None) + if cache is not None: + recipe_ids_set = set(recipe_ids) + cache.raw_data = [item for item in cache.raw_data if item.get("id") not in recipe_ids_set] + asyncio.create_task(cache.resort()) + self._logger.info("Removed %s recipes from cache", len(recipe_ids_set)) + + async def _refresh_cache_after_update( + self, + recipe_scanner, + recipe_id: str, + recipe_data: dict[str, Any], + ) -> None: + cache = getattr(recipe_scanner, "_cache", None) + if cache is not None: + for cache_item in cache.raw_data: + if cache_item.get("id") == recipe_id: + cache_item.update({ + "loras": recipe_data.get("loras", []), + "fingerprint": recipe_data.get("fingerprint"), + }) + asyncio.create_task(cache.resort()) + break + + def _derive_recipe_name(self, lora_matches: list[tuple[str, str]]) -> str: + recipe_name_parts = [f"{name.strip()}-{float(strength):.2f}" for name, strength in lora_matches[:3]] + recipe_name = "_".join(recipe_name_parts) + return recipe_name or "recipe" diff --git a/py/services/recipes/sharing_service.py b/py/services/recipes/sharing_service.py new file mode 100644 index 00000000..7c365bba --- /dev/null +++ b/py/services/recipes/sharing_service.py @@ -0,0 +1,113 @@ +"""Services handling recipe sharing and downloads.""" +from __future__ import annotations + +import os +import shutil +import tempfile +import time +from dataclasses import dataclass +from typing import Any, Dict + +from .errors import RecipeNotFoundError + + +@dataclass(frozen=True) +class SharingResult: + """Return payload for share operations.""" + + payload: dict[str, Any] + status: int = 200 + + +@dataclass(frozen=True) +class DownloadInfo: + """Information required to stream a shared recipe file.""" + + file_path: str + download_filename: str + + +class RecipeSharingService: + """Prepare temporary recipe downloads with TTL cleanup.""" + + def __init__(self, *, ttl_seconds: int = 300, logger) -> None: + self._ttl_seconds = ttl_seconds + self._logger = logger + self._shared_recipes: Dict[str, Dict[str, Any]] = {} + + async def share_recipe(self, *, recipe_scanner, recipe_id: str) -> SharingResult: + """Prepare a temporary downloadable copy of a recipe image.""" + + cache = await recipe_scanner.get_cached_data() + recipe = next( + (r for r in getattr(cache, "raw_data", []) if str(r.get("id", "")) == recipe_id), + None, + ) + if not recipe: + raise RecipeNotFoundError("Recipe not found") + + image_path = recipe.get("file_path") + if not image_path or not os.path.exists(image_path): + raise RecipeNotFoundError("Recipe image not found") + + ext = os.path.splitext(image_path)[1] + with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as temp_file: + temp_path = temp_file.name + + shutil.copy2(image_path, temp_path) + timestamp = int(time.time()) + self._shared_recipes[recipe_id] = { + "path": temp_path, + "timestamp": timestamp, + "expires": time.time() + self._ttl_seconds, + } + self._cleanup_shared_recipes() + + safe_title = recipe.get("title", "").replace(" ", "_").lower() + filename = f"recipe_{safe_title}{ext}" if safe_title else f"recipe_{recipe_id}{ext}" + url_path = f"/api/recipe/{recipe_id}/share/download?t={timestamp}" + return SharingResult({"success": True, "download_url": url_path, "filename": filename}) + + async def prepare_download(self, *, recipe_scanner, recipe_id: str) -> DownloadInfo: + """Return file path and filename for a prepared shared recipe.""" + + shared_info = self._shared_recipes.get(recipe_id) + if not shared_info or time.time() > shared_info.get("expires", 0): + self._cleanup_entry(recipe_id) + raise RecipeNotFoundError("Shared recipe not found or expired") + + file_path = shared_info["path"] + if not os.path.exists(file_path): + self._cleanup_entry(recipe_id) + raise RecipeNotFoundError("Shared recipe file not found") + + cache = await recipe_scanner.get_cached_data() + recipe = next( + (r for r in getattr(cache, "raw_data", []) if str(r.get("id", "")) == recipe_id), + None, + ) + filename_base = ( + f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}" if recipe else recipe_id + ) + ext = os.path.splitext(file_path)[1] + download_filename = f"{filename_base}{ext}" + return DownloadInfo(file_path=file_path, download_filename=download_filename) + + def _cleanup_shared_recipes(self) -> None: + for recipe_id in list(self._shared_recipes.keys()): + shared = self._shared_recipes.get(recipe_id) + if not shared: + continue + if time.time() > shared.get("expires", 0): + self._cleanup_entry(recipe_id) + + def _cleanup_entry(self, recipe_id: str) -> None: + shared_info = self._shared_recipes.pop(recipe_id, None) + if not shared_info: + return + file_path = shared_info.get("path") + if file_path and os.path.exists(file_path): + try: + os.unlink(file_path) + except Exception as exc: # pragma: no cover - defensive logging + self._logger.error("Error cleaning up shared recipe %s: %s", recipe_id, exc) diff --git a/tests/services/test_recipe_services.py b/tests/services/test_recipe_services.py new file mode 100644 index 00000000..e57abf2f --- /dev/null +++ b/tests/services/test_recipe_services.py @@ -0,0 +1,146 @@ +import logging +import os +from types import SimpleNamespace + +import pytest + +from py.services.recipes.analysis_service import RecipeAnalysisService +from py.services.recipes.errors import RecipeDownloadError, RecipeNotFoundError +from py.services.recipes.persistence_service import RecipePersistenceService + + +class DummyExifUtils: + def optimize_image(self, image_data, target_width, format, quality, preserve_metadata): + return image_data, ".webp" + + def append_recipe_metadata(self, image_path, recipe_data): + self.appended = (image_path, recipe_data) + + def extract_image_metadata(self, path): + return {} + + +@pytest.mark.asyncio +async def test_analyze_remote_image_download_failure_cleans_temp(tmp_path, monkeypatch): + exif_utils = DummyExifUtils() + + class DummyFactory: + def create_parser(self, metadata): + return None + + async def downloader_factory(): + class Downloader: + async def download_file(self, url, path, use_auth=False): + return False, "failure" + + return Downloader() + + service = RecipeAnalysisService( + exif_utils=exif_utils, + recipe_parser_factory=DummyFactory(), + downloader_factory=downloader_factory, + metadata_collector=None, + metadata_processor_cls=None, + metadata_registry_cls=None, + standalone_mode=False, + logger=logging.getLogger("test"), + ) + + temp_path = tmp_path / "temp.jpg" + + def create_temp_path(): + temp_path.write_bytes(b"") + return str(temp_path) + + monkeypatch.setattr(service, "_create_temp_path", create_temp_path) + + with pytest.raises(RecipeDownloadError): + await service.analyze_remote_image( + url="https://example.com/image.jpg", + recipe_scanner=SimpleNamespace(), + civitai_client=SimpleNamespace(), + ) + + assert not temp_path.exists(), "temporary file should be cleaned after failure" + + +@pytest.mark.asyncio +async def test_analyze_local_image_missing_file(tmp_path): + async def downloader_factory(): + return SimpleNamespace() + + service = RecipeAnalysisService( + exif_utils=DummyExifUtils(), + recipe_parser_factory=SimpleNamespace(create_parser=lambda metadata: None), + downloader_factory=downloader_factory, + metadata_collector=None, + metadata_processor_cls=None, + metadata_registry_cls=None, + standalone_mode=False, + logger=logging.getLogger("test"), + ) + + with pytest.raises(RecipeNotFoundError): + await service.analyze_local_image( + file_path=str(tmp_path / "missing.png"), + recipe_scanner=SimpleNamespace(), + ) + + +@pytest.mark.asyncio +async def test_save_recipe_reports_duplicates(tmp_path): + exif_utils = DummyExifUtils() + + class DummyCache: + def __init__(self): + self.raw_data = [] + + async def resort(self): + pass + + class DummyScanner: + def __init__(self, root): + self.recipes_dir = str(root) + self._cache = DummyCache() + self.last_fingerprint = None + + async def find_recipes_by_fingerprint(self, fingerprint): + self.last_fingerprint = fingerprint + return ["existing"] + + scanner = DummyScanner(tmp_path) + service = RecipePersistenceService( + exif_utils=exif_utils, + card_preview_width=512, + logger=logging.getLogger("test"), + ) + + metadata = { + "base_model": "sd", + "loras": [ + { + "file_name": "sample", + "hash": "abc123", + "weight": 0.5, + "id": 1, + "name": "Sample", + "version": "v1", + "isDeleted": False, + "exclude": False, + } + ], + } + + result = await service.save_recipe( + recipe_scanner=scanner, + image_bytes=b"image-bytes", + image_base64=None, + name="My Recipe", + tags=["tag"], + metadata=metadata, + ) + + assert result.payload["matching_recipes"] == ["existing"] + assert scanner.last_fingerprint is not None + assert os.path.exists(result.payload["json_path"]) + assert scanner._cache.raw_data From 42872e6d2de7f12f003deafa92aeafbee31d9737 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Mon, 22 Sep 2025 13:45:40 +0800 Subject: [PATCH 13/24] feat(recipes): expose recipe scanner mutation apis --- py/routes/handlers/recipe_handlers.py | 67 +----- py/services/recipe_cache.py | 146 +++++++----- py/services/recipe_scanner.py | 255 +++++++++++++++++++-- py/services/recipes/persistence_service.py | 91 +------- py/services/recipes/sharing_service.py | 12 +- tests/services/test_recipe_scanner.py | 185 +++++++++++++++ tests/services/test_recipe_services.py | 4 + 7 files changed, 532 insertions(+), 228 deletions(-) create mode 100644 tests/services/test_recipe_scanner.py diff --git a/py/routes/handlers/recipe_handlers.py b/py/routes/handlers/recipe_handlers.py index 35f4c088..aa912477 100644 --- a/py/routes/handlers/recipe_handlers.py +++ b/py/routes/handlers/recipe_handlers.py @@ -290,27 +290,7 @@ class RecipeQueryHandler: if not lora_hash: return web.json_response({"success": False, "error": "Lora hash is required"}, status=400) - cache = await recipe_scanner.get_cached_data() - matching_recipes = [] - for recipe in getattr(cache, "raw_data", []): - for lora in recipe.get("loras", []): - if lora.get("hash", "").lower() == lora_hash.lower(): - matching_recipes.append(recipe) - break - - lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) - for recipe in matching_recipes: - for lora in recipe.get("loras", []): - hash_value = (lora.get("hash") or "").lower() - if hash_value and lora_scanner is not None: - lora["inLibrary"] = lora_scanner.has_hash(hash_value) - lora["preview_url"] = lora_scanner.get_preview_url_by_hash(hash_value) - lora["localPath"] = lora_scanner.get_path_by_hash(hash_value) - if recipe.get("file_path"): - recipe["file_url"] = self._format_recipe_file_url(recipe["file_path"]) - else: - recipe["file_url"] = "/loras_static/images/no-preview.png" - + matching_recipes = await recipe_scanner.get_recipes_for_lora(lora_hash) return web.json_response({"success": True, "recipes": matching_recipes}) except Exception as exc: self._logger.error("Error getting recipes for Lora: %s", exc) @@ -384,50 +364,15 @@ class RecipeQueryHandler: raise RuntimeError("Recipe scanner unavailable") recipe_id = request.match_info["recipe_id"] - cache = await recipe_scanner.get_cached_data() - recipe = next( - (r for r in getattr(cache, "raw_data", []) if str(r.get("id", "")) == recipe_id), - None, - ) - if not recipe: + try: + syntax_parts = await recipe_scanner.get_recipe_syntax_tokens(recipe_id) + except RecipeNotFoundError: return web.json_response({"error": "Recipe not found"}, status=404) - loras = recipe.get("loras", []) - if not loras: + if not syntax_parts: return web.json_response({"error": "No LoRAs found in this recipe"}, status=400) - lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) - hash_index = getattr(lora_scanner, "_hash_index", None) - - lora_syntax_parts = [] - for lora in loras: - if lora.get("isDeleted", False): - continue - hash_value = (lora.get("hash") or "").lower() - if not hash_value or lora_scanner is None or not lora_scanner.has_hash(hash_value): - continue - - file_name = None - if hash_value and hash_index is not None and hasattr(hash_index, "_hash_to_path"): - file_path = hash_index._hash_to_path.get(hash_value) - if file_path: - file_name = os.path.splitext(os.path.basename(file_path))[0] - - if not file_name and lora.get("modelVersionId") and lora_scanner is not None: - all_loras = await lora_scanner.get_cached_data() - for cached_lora in getattr(all_loras, "raw_data", []): - civitai_info = cached_lora.get("civitai") - if civitai_info and civitai_info.get("id") == lora.get("modelVersionId"): - file_name = os.path.splitext(os.path.basename(cached_lora["path"]))[0] - break - - if not file_name: - file_name = lora.get("file_name", "unknown-lora") - - strength = lora.get("strength", 1.0) - lora_syntax_parts.append(f"") - - return web.json_response({"success": True, "syntax": " ".join(lora_syntax_parts)}) + return web.json_response({"success": True, "syntax": " ".join(syntax_parts)}) except Exception as exc: self._logger.error("Error generating recipe syntax: %s", exc, exc_info=True) return web.json_response({"error": str(exc)}, status=500) diff --git a/py/services/recipe_cache.py b/py/services/recipe_cache.py index b1f52246..ac28b3aa 100644 --- a/py/services/recipe_cache.py +++ b/py/services/recipe_cache.py @@ -1,5 +1,5 @@ import asyncio -from typing import List, Dict +from typing import Iterable, List, Dict, Optional from dataclasses import dataclass from operator import itemgetter from natsort import natsorted @@ -10,77 +10,115 @@ class RecipeCache: raw_data: List[Dict] sorted_by_name: List[Dict] sorted_by_date: List[Dict] - + def __post_init__(self): self._lock = asyncio.Lock() async def resort(self, name_only: bool = False): """Resort all cached data views""" async with self._lock: - self.sorted_by_name = natsorted( - self.raw_data, - key=lambda x: x.get('title', '').lower() # Case-insensitive sort - ) - if not name_only: - self.sorted_by_date = sorted( - self.raw_data, - key=itemgetter('created_date', 'file_path'), - reverse=True - ) - - async def update_recipe_metadata(self, recipe_id: str, metadata: Dict) -> bool: + self._resort_locked(name_only=name_only) + + async def update_recipe_metadata(self, recipe_id: str, metadata: Dict, *, resort: bool = True) -> bool: """Update metadata for a specific recipe in all cached data - + Args: recipe_id: The ID of the recipe to update metadata: The new metadata - + Returns: bool: True if the update was successful, False if the recipe wasn't found """ + async with self._lock: + for item in self.raw_data: + if str(item.get('id')) == str(recipe_id): + item.update(metadata) + if resort: + self._resort_locked() + return True + return False # Recipe not found + + async def add_recipe(self, recipe_data: Dict, *, resort: bool = False) -> None: + """Add a new recipe to the cache.""" - # Update in raw_data - for item in self.raw_data: - if item.get('id') == recipe_id: - item.update(metadata) - break - else: - return False # Recipe not found - - # Resort to reflect changes - await self.resort() - return True - - async def add_recipe(self, recipe_data: Dict) -> None: - """Add a new recipe to the cache - - Args: - recipe_data: The recipe data to add - """ async with self._lock: self.raw_data.append(recipe_data) - await self.resort() + if resort: + self._resort_locked() + + async def remove_recipe(self, recipe_id: str, *, resort: bool = False) -> Optional[Dict]: + """Remove a recipe from the cache by ID. - async def remove_recipe(self, recipe_id: str) -> bool: - """Remove a recipe from the cache by ID - Args: recipe_id: The ID of the recipe to remove - + Returns: - bool: True if the recipe was found and removed, False otherwise + The removed recipe data if found, otherwise ``None``. """ - # Find the recipe in raw_data - recipe_index = next((i for i, recipe in enumerate(self.raw_data) - if recipe.get('id') == recipe_id), None) - - if recipe_index is None: - return False - - # Remove from raw_data - self.raw_data.pop(recipe_index) - - # Resort to update sorted lists - await self.resort() - - return True \ No newline at end of file + + async with self._lock: + for index, recipe in enumerate(self.raw_data): + if str(recipe.get('id')) == str(recipe_id): + removed = self.raw_data.pop(index) + if resort: + self._resort_locked() + return removed + return None + + async def bulk_remove(self, recipe_ids: Iterable[str], *, resort: bool = False) -> List[Dict]: + """Remove multiple recipes from the cache.""" + + id_set = {str(recipe_id) for recipe_id in recipe_ids} + if not id_set: + return [] + + async with self._lock: + removed = [item for item in self.raw_data if str(item.get('id')) in id_set] + if not removed: + return [] + + self.raw_data = [item for item in self.raw_data if str(item.get('id')) not in id_set] + if resort: + self._resort_locked() + return removed + + async def replace_recipe(self, recipe_id: str, new_data: Dict, *, resort: bool = False) -> bool: + """Replace cached data for a recipe.""" + + async with self._lock: + for index, recipe in enumerate(self.raw_data): + if str(recipe.get('id')) == str(recipe_id): + self.raw_data[index] = new_data + if resort: + self._resort_locked() + return True + return False + + async def get_recipe(self, recipe_id: str) -> Optional[Dict]: + """Return a shallow copy of a cached recipe.""" + + async with self._lock: + for recipe in self.raw_data: + if str(recipe.get('id')) == str(recipe_id): + return dict(recipe) + return None + + async def snapshot(self) -> List[Dict]: + """Return a copy of all cached recipes.""" + + async with self._lock: + return [dict(item) for item in self.raw_data] + + def _resort_locked(self, *, name_only: bool = False) -> None: + """Sort cached views. Caller must hold ``_lock``.""" + + self.sorted_by_name = natsorted( + self.raw_data, + key=lambda x: x.get('title', '').lower() + ) + if not name_only: + self.sorted_by_date = sorted( + self.raw_data, + key=itemgetter('created_date', 'file_path'), + reverse=True + ) \ No newline at end of file diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index ca5a20ac..9a82b237 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -3,13 +3,14 @@ import logging import asyncio import json import time -from typing import List, Dict, Optional, Any, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from ..config import config from .recipe_cache import RecipeCache from .service_registry import ServiceRegistry from .lora_scanner import LoraScanner from .metadata_service import get_default_metadata_provider -from ..utils.utils import fuzzy_match +from .recipes.errors import RecipeNotFoundError +from ..utils.utils import calculate_recipe_fingerprint, fuzzy_match from natsort import natsorted import sys @@ -46,6 +47,8 @@ class RecipeScanner: self._initialization_lock = asyncio.Lock() self._initialization_task: Optional[asyncio.Task] = None self._is_initializing = False + self._mutation_lock = asyncio.Lock() + self._resort_tasks: Set[asyncio.Task] = set() if lora_scanner: self._lora_scanner = lora_scanner self._initialized = True @@ -191,6 +194,22 @@ class RecipeScanner: # Clean up the event loop loop.close() + def _schedule_resort(self, *, name_only: bool = False) -> None: + """Schedule a background resort of the recipe cache.""" + + if not self._cache: + return + + async def _resort_wrapper() -> None: + try: + await self._cache.resort(name_only=name_only) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Recipe Scanner: error resorting cache: %s", exc, exc_info=True) + + task = asyncio.create_task(_resort_wrapper()) + self._resort_tasks.add(task) + task.add_done_callback(lambda finished: self._resort_tasks.discard(finished)) + @property def recipes_dir(self) -> str: """Get path to recipes directory""" @@ -255,7 +274,45 @@ class RecipeScanner: # Return the cache (may be empty or partially initialized) return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[]) - + + async def refresh_cache(self, force: bool = False) -> RecipeCache: + """Public helper to refresh or return the recipe cache.""" + + return await self.get_cached_data(force_refresh=force) + + async def add_recipe(self, recipe_data: Dict[str, Any]) -> None: + """Add a recipe to the in-memory cache.""" + + if not recipe_data: + return + + cache = await self.get_cached_data() + await cache.add_recipe(recipe_data, resort=False) + self._schedule_resort() + + async def remove_recipe(self, recipe_id: str) -> bool: + """Remove a recipe from the cache by ID.""" + + if not recipe_id: + return False + + cache = await self.get_cached_data() + removed = await cache.remove_recipe(recipe_id, resort=False) + if removed is None: + return False + + self._schedule_resort() + return True + + async def bulk_remove(self, recipe_ids: Iterable[str]) -> int: + """Remove multiple recipes from the cache.""" + + cache = await self.get_cached_data() + removed = await cache.bulk_remove(recipe_ids, resort=False) + if removed: + self._schedule_resort() + return len(removed) + async def scan_all_recipes(self) -> List[Dict]: """Scan all recipe JSON files and return metadata""" recipes = [] @@ -326,7 +383,6 @@ class RecipeScanner: # Calculate and update fingerprint if missing if 'loras' in recipe_data and 'fingerprint' not in recipe_data: - from ..utils.utils import calculate_recipe_fingerprint fingerprint = calculate_recipe_fingerprint(recipe_data['loras']) recipe_data['fingerprint'] = fingerprint @@ -497,9 +553,36 @@ class RecipeScanner: logger.error(f"Error getting base model for lora: {e}") return None + def _enrich_lora_entry(self, lora: Dict[str, Any]) -> Dict[str, Any]: + """Populate convenience fields for a LoRA entry.""" + + if not lora or not self._lora_scanner: + return lora + + hash_value = (lora.get('hash') or '').lower() + if not hash_value: + return lora + + try: + lora['inLibrary'] = self._lora_scanner.has_hash(hash_value) + lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(hash_value) + lora['localPath'] = self._lora_scanner.get_path_by_hash(hash_value) + except Exception as exc: # pragma: no cover - defensive logging + logger.debug("Error enriching lora entry %s: %s", hash_value, exc) + + return lora + + async def get_local_lora(self, name: str) -> Optional[Dict[str, Any]]: + """Lookup a local LoRA model by name.""" + + if not self._lora_scanner or not name: + return None + + return await self._lora_scanner.get_model_info_by_name(name) + async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True): """Get paginated and filtered recipe data - + Args: page: Current page number (1-based) page_size: Number of items per page @@ -598,16 +681,12 @@ class RecipeScanner: # Get paginated items paginated_items = filtered_data[start_idx:end_idx] - + # Add inLibrary information for each lora for item in paginated_items: if 'loras' in item: - for lora in item['loras']: - if 'hash' in lora and lora['hash']: - lora['inLibrary'] = self._lora_scanner.has_hash(lora['hash'].lower()) - lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower()) - lora['localPath'] = self._lora_scanner.get_path_by_hash(lora['hash'].lower()) - + item['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in item['loras']] + result = { 'items': paginated_items, 'total': total_items, @@ -653,13 +732,8 @@ class RecipeScanner: # Add lora metadata if 'loras' in formatted_recipe: - for lora in formatted_recipe['loras']: - if 'hash' in lora and lora['hash']: - lora_hash = lora['hash'].lower() - lora['inLibrary'] = self._lora_scanner.has_hash(lora_hash) - lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash) - lora['localPath'] = self._lora_scanner.get_path_by_hash(lora_hash) - + formatted_recipe['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in formatted_recipe['loras']] + return formatted_recipe def _format_file_url(self, file_path: str) -> str: @@ -717,26 +791,159 @@ class RecipeScanner: # Save updated recipe with open(recipe_json_path, 'w', encoding='utf-8') as f: json.dump(recipe_data, f, indent=4, ensure_ascii=False) - + # Update the cache if it exists if self._cache is not None: - await self._cache.update_recipe_metadata(recipe_id, metadata) - + await self._cache.update_recipe_metadata(recipe_id, metadata, resort=False) + self._schedule_resort() + # If the recipe has an image, update its EXIF metadata from ..utils.exif_utils import ExifUtils image_path = recipe_data.get('file_path') if image_path and os.path.exists(image_path): ExifUtils.append_recipe_metadata(image_path, recipe_data) - + return True except Exception as e: import logging logging.getLogger(__name__).error(f"Error updating recipe metadata: {e}", exc_info=True) return False + async def update_lora_entry( + self, + recipe_id: str, + lora_index: int, + *, + target_name: str, + target_lora: Optional[Dict[str, Any]] = None, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Update a specific LoRA entry within a recipe. + + Returns the updated recipe data and the refreshed LoRA metadata. + """ + + if target_name is None: + raise ValueError("target_name must be provided") + + recipe_json_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json") + if not os.path.exists(recipe_json_path): + raise RecipeNotFoundError("Recipe not found") + + async with self._mutation_lock: + with open(recipe_json_path, 'r', encoding='utf-8') as file_obj: + recipe_data = json.load(file_obj) + + loras = recipe_data.get('loras', []) + if lora_index >= len(loras): + raise RecipeNotFoundError("LoRA index out of range in recipe") + + lora_entry = loras[lora_index] + lora_entry['isDeleted'] = False + lora_entry['exclude'] = False + lora_entry['file_name'] = target_name + + if target_lora is not None: + sha_value = target_lora.get('sha256') or target_lora.get('sha') + if sha_value: + lora_entry['hash'] = sha_value.lower() + + civitai_info = target_lora.get('civitai') or {} + if civitai_info: + lora_entry['modelName'] = civitai_info.get('model', {}).get('name', '') + lora_entry['modelVersionName'] = civitai_info.get('name', '') + lora_entry['modelVersionId'] = civitai_info.get('id') + + recipe_data['fingerprint'] = calculate_recipe_fingerprint(recipe_data.get('loras', [])) + recipe_data['modified'] = time.time() + + with open(recipe_json_path, 'w', encoding='utf-8') as file_obj: + json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) + + cache = await self.get_cached_data() + replaced = await cache.replace_recipe(recipe_id, recipe_data, resort=False) + if not replaced: + await cache.add_recipe(recipe_data, resort=False) + self._schedule_resort() + + updated_lora = dict(lora_entry) + if target_lora is not None: + preview_url = target_lora.get('preview_url') + if preview_url: + updated_lora['preview_url'] = config.get_preview_static_url(preview_url) + if target_lora.get('file_path'): + updated_lora['localPath'] = target_lora['file_path'] + + updated_lora = self._enrich_lora_entry(updated_lora) + return recipe_data, updated_lora + + async def get_recipes_for_lora(self, lora_hash: str) -> List[Dict[str, Any]]: + """Return recipes that reference a given LoRA hash.""" + + if not lora_hash: + return [] + + normalized_hash = lora_hash.lower() + cache = await self.get_cached_data() + matching_recipes: List[Dict[str, Any]] = [] + + for recipe in cache.raw_data: + loras = recipe.get('loras', []) + if any((entry.get('hash') or '').lower() == normalized_hash for entry in loras): + recipe_copy = {**recipe} + recipe_copy['loras'] = [self._enrich_lora_entry(dict(entry)) for entry in loras] + recipe_copy['file_url'] = self._format_file_url(recipe.get('file_path')) + matching_recipes.append(recipe_copy) + + return matching_recipes + + async def get_recipe_syntax_tokens(self, recipe_id: str) -> List[str]: + """Build LoRA syntax tokens for a recipe.""" + + cache = await self.get_cached_data() + recipe = await cache.get_recipe(recipe_id) + if recipe is None: + raise RecipeNotFoundError("Recipe not found") + + loras = recipe.get('loras', []) + if not loras: + return [] + + lora_cache = None + if self._lora_scanner is not None: + lora_cache = await self._lora_scanner.get_cached_data() + + syntax_parts: List[str] = [] + for lora in loras: + if lora.get('isDeleted', False): + continue + + file_name = None + hash_value = (lora.get('hash') or '').lower() + if hash_value and self._lora_scanner is not None and hasattr(self._lora_scanner, '_hash_index'): + file_path = self._lora_scanner._hash_index.get_path(hash_value) + if file_path: + file_name = os.path.splitext(os.path.basename(file_path))[0] + + if not file_name and lora.get('modelVersionId') and lora_cache is not None: + for cached_lora in getattr(lora_cache, 'raw_data', []): + civitai_info = cached_lora.get('civitai') + if civitai_info and civitai_info.get('id') == lora.get('modelVersionId'): + cached_path = cached_lora.get('path') or cached_lora.get('file_path') + if cached_path: + file_name = os.path.splitext(os.path.basename(cached_path))[0] + break + + if not file_name: + file_name = lora.get('file_name', 'unknown-lora') + + strength = lora.get('strength', 1.0) + syntax_parts.append(f"") + + return syntax_parts + async def update_lora_filename_by_hash(self, hash_value: str, new_file_name: str) -> Tuple[int, int]: """Update file_name in all recipes that contain a LoRA with the specified hash. - + Args: hash_value: The SHA256 hash value of the LoRA new_file_name: The new file_name to set diff --git a/py/services/recipes/persistence_service.py b/py/services/recipes/persistence_service.py index 945680df..078ac906 100644 --- a/py/services/recipes/persistence_service.py +++ b/py/services/recipes/persistence_service.py @@ -1,7 +1,6 @@ """Services encapsulating recipe persistence workflows.""" from __future__ import annotations -import asyncio import base64 import json import os @@ -123,7 +122,7 @@ class RecipePersistenceService: self._exif_utils.append_recipe_metadata(image_path, recipe_data) matching_recipes = await self._find_matching_recipes(recipe_scanner, fingerprint, exclude_id=recipe_id) - await self._update_cache(recipe_scanner, recipe_data) + await recipe_scanner.add_recipe(recipe_data) return PersistenceResult( { @@ -154,7 +153,7 @@ class RecipePersistenceService: if image_path and os.path.exists(image_path): os.remove(image_path) - await self._remove_from_cache(recipe_scanner, recipe_id) + await recipe_scanner.remove_recipe(recipe_id) return PersistenceResult({"success": True, "message": "Recipe deleted successfully"}) async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: dict[str, Any]) -> PersistenceResult: @@ -185,40 +184,16 @@ class RecipePersistenceService: if not os.path.exists(recipe_path): raise RecipeNotFoundError("Recipe not found") - lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) - target_lora = None if lora_scanner is None else await lora_scanner.get_model_info_by_name(target_name) + target_lora = await recipe_scanner.get_local_lora(target_name) if not target_lora: raise RecipeNotFoundError(f"Local LoRA not found with name: {target_name}") - with open(recipe_path, "r", encoding="utf-8") as file_obj: - recipe_data = json.load(file_obj) - - loras = recipe_data.get("loras", []) - if lora_index >= len(loras): - raise RecipeNotFoundError("LoRA index out of range in recipe") - - lora = loras[lora_index] - lora["isDeleted"] = False - lora["exclude"] = False - lora["file_name"] = target_name - if "sha256" in target_lora: - lora["hash"] = target_lora["sha256"].lower() - if target_lora.get("civitai"): - lora["modelName"] = target_lora["civitai"]["model"]["name"] - lora["modelVersionName"] = target_lora["civitai"]["name"] - lora["modelVersionId"] = target_lora["civitai"]["id"] - - recipe_data["fingerprint"] = calculate_recipe_fingerprint(recipe_data.get("loras", [])) - - with open(recipe_path, "w", encoding="utf-8") as file_obj: - json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) - - updated_lora = dict(lora) - updated_lora["inLibrary"] = True - updated_lora["preview_url"] = config.get_preview_static_url(target_lora["preview_url"]) - updated_lora["localPath"] = target_lora["file_path"] - - await self._refresh_cache_after_update(recipe_scanner, recipe_id, recipe_data) + recipe_data, updated_lora = await recipe_scanner.update_lora_entry( + recipe_id, + lora_index, + target_name=target_name, + target_lora=target_lora, + ) image_path = recipe_data.get("file_path") if image_path and os.path.exists(image_path): @@ -276,7 +251,7 @@ class RecipePersistenceService: failed_recipes.append({"id": recipe_id, "reason": str(exc)}) if deleted_recipes: - await self._bulk_remove_from_cache(recipe_scanner, deleted_recipes) + await recipe_scanner.bulk_remove(deleted_recipes) return PersistenceResult( { @@ -314,14 +289,11 @@ class RecipePersistenceService: if not lora_matches: raise RecipeValidationError("No LoRAs found in the generation metadata") - lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) loras_data = [] base_model_counts: Dict[str, int] = {} for name, strength in lora_matches: - lora_info = None - if lora_scanner is not None: - lora_info = await lora_scanner.get_model_info_by_name(name) + lora_info = await recipe_scanner.get_local_lora(name) lora_data = { "file_name": name, "strength": float(strength), @@ -366,7 +338,7 @@ class RecipePersistenceService: json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) self._exif_utils.append_recipe_metadata(image_path, recipe_data) - await self._update_cache(recipe_scanner, recipe_data) + await recipe_scanner.add_recipe(recipe_data) return PersistenceResult( { @@ -422,45 +394,6 @@ class RecipePersistenceService: matches.remove(exclude_id) return matches - async def _update_cache(self, recipe_scanner, recipe_data: dict[str, Any]) -> None: - cache = getattr(recipe_scanner, "_cache", None) - if cache is not None: - cache.raw_data.append(recipe_data) - asyncio.create_task(cache.resort()) - self._logger.info("Added recipe %s to cache", recipe_data.get("id")) - - async def _remove_from_cache(self, recipe_scanner, recipe_id: str) -> None: - cache = getattr(recipe_scanner, "_cache", None) - if cache is not None: - cache.raw_data = [item for item in cache.raw_data if str(item.get("id", "")) != recipe_id] - asyncio.create_task(cache.resort()) - self._logger.info("Removed recipe %s from cache", recipe_id) - - async def _bulk_remove_from_cache(self, recipe_scanner, recipe_ids: Iterable[str]) -> None: - cache = getattr(recipe_scanner, "_cache", None) - if cache is not None: - recipe_ids_set = set(recipe_ids) - cache.raw_data = [item for item in cache.raw_data if item.get("id") not in recipe_ids_set] - asyncio.create_task(cache.resort()) - self._logger.info("Removed %s recipes from cache", len(recipe_ids_set)) - - async def _refresh_cache_after_update( - self, - recipe_scanner, - recipe_id: str, - recipe_data: dict[str, Any], - ) -> None: - cache = getattr(recipe_scanner, "_cache", None) - if cache is not None: - for cache_item in cache.raw_data: - if cache_item.get("id") == recipe_id: - cache_item.update({ - "loras": recipe_data.get("loras", []), - "fingerprint": recipe_data.get("fingerprint"), - }) - asyncio.create_task(cache.resort()) - break - def _derive_recipe_name(self, lora_matches: list[tuple[str, str]]) -> str: recipe_name_parts = [f"{name.strip()}-{float(strength):.2f}" for name, strength in lora_matches[:3]] recipe_name = "_".join(recipe_name_parts) diff --git a/py/services/recipes/sharing_service.py b/py/services/recipes/sharing_service.py index 7c365bba..47ab9718 100644 --- a/py/services/recipes/sharing_service.py +++ b/py/services/recipes/sharing_service.py @@ -38,11 +38,7 @@ class RecipeSharingService: async def share_recipe(self, *, recipe_scanner, recipe_id: str) -> SharingResult: """Prepare a temporary downloadable copy of a recipe image.""" - cache = await recipe_scanner.get_cached_data() - recipe = next( - (r for r in getattr(cache, "raw_data", []) if str(r.get("id", "")) == recipe_id), - None, - ) + recipe = await recipe_scanner.get_recipe_by_id(recipe_id) if not recipe: raise RecipeNotFoundError("Recipe not found") @@ -81,11 +77,7 @@ class RecipeSharingService: self._cleanup_entry(recipe_id) raise RecipeNotFoundError("Shared recipe file not found") - cache = await recipe_scanner.get_cached_data() - recipe = next( - (r for r in getattr(cache, "raw_data", []) if str(r.get("id", "")) == recipe_id), - None, - ) + recipe = await recipe_scanner.get_recipe_by_id(recipe_id) filename_base = ( f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}" if recipe else recipe_id ) diff --git a/tests/services/test_recipe_scanner.py b/tests/services/test_recipe_scanner.py new file mode 100644 index 00000000..63c18f25 --- /dev/null +++ b/tests/services/test_recipe_scanner.py @@ -0,0 +1,185 @@ +import asyncio +import json +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from py.config import config +from py.services.recipe_scanner import RecipeScanner +from py.utils.utils import calculate_recipe_fingerprint + + +class StubHashIndex: + def __init__(self) -> None: + self._hash_to_path: dict[str, str] = {} + + def get_path(self, hash_value: str) -> str | None: + return self._hash_to_path.get(hash_value) + + +class StubLoraScanner: + def __init__(self) -> None: + self._hash_index = StubHashIndex() + self._hash_meta: dict[str, dict[str, str]] = {} + self._models_by_name: dict[str, dict] = {} + self._cache = SimpleNamespace(raw_data=[]) + + async def get_cached_data(self): + return self._cache + + def has_hash(self, hash_value: str) -> bool: + return hash_value.lower() in self._hash_meta + + def get_preview_url_by_hash(self, hash_value: str) -> str: + meta = self._hash_meta.get(hash_value.lower()) + return meta.get("preview_url", "") if meta else "" + + def get_path_by_hash(self, hash_value: str) -> str | None: + meta = self._hash_meta.get(hash_value.lower()) + return meta.get("path") if meta else None + + async def get_model_info_by_name(self, name: str): + return self._models_by_name.get(name) + + def register_model(self, name: str, info: dict) -> None: + self._models_by_name[name] = info + hash_value = (info.get("sha256") or "").lower() + if hash_value: + self._hash_meta[hash_value] = { + "path": info.get("file_path", ""), + "preview_url": info.get("preview_url", ""), + } + self._hash_index._hash_to_path[hash_value] = info.get("file_path", "") + self._cache.raw_data.append({ + "sha256": info.get("sha256", ""), + "path": info.get("file_path", ""), + "civitai": info.get("civitai", {}), + }) + + +@pytest.fixture +def recipe_scanner(tmp_path: Path, monkeypatch): + RecipeScanner._instance = None + monkeypatch.setattr(config, "loras_roots", [str(tmp_path)]) + stub = StubLoraScanner() + scanner = RecipeScanner(lora_scanner=stub) + asyncio.run(scanner.refresh_cache(force=True)) + yield scanner, stub + RecipeScanner._instance = None + + +async def test_add_recipe_during_concurrent_reads(recipe_scanner): + scanner, _ = recipe_scanner + + initial_recipe = { + "id": "one", + "file_path": "path/a.png", + "title": "First", + "modified": 1.0, + "created_date": 1.0, + "loras": [], + } + await scanner.add_recipe(initial_recipe) + + new_recipe = { + "id": "two", + "file_path": "path/b.png", + "title": "Second", + "modified": 2.0, + "created_date": 2.0, + "loras": [], + } + + async def reader_task(): + for _ in range(5): + cache = await scanner.get_cached_data() + _ = [item["id"] for item in cache.raw_data] + await asyncio.sleep(0) + + await asyncio.gather(reader_task(), reader_task(), scanner.add_recipe(new_recipe)) + await asyncio.sleep(0) + cache = await scanner.get_cached_data() + + assert {item["id"] for item in cache.raw_data} == {"one", "two"} + assert len(cache.sorted_by_name) == len(cache.raw_data) + + +async def test_remove_recipe_during_reads(recipe_scanner): + scanner, _ = recipe_scanner + + recipe_ids = ["alpha", "beta", "gamma"] + for index, recipe_id in enumerate(recipe_ids): + await scanner.add_recipe({ + "id": recipe_id, + "file_path": f"path/{recipe_id}.png", + "title": recipe_id, + "modified": float(index), + "created_date": float(index), + "loras": [], + }) + + async def reader_task(): + for _ in range(5): + cache = await scanner.get_cached_data() + _ = list(cache.sorted_by_date) + await asyncio.sleep(0) + + await asyncio.gather(reader_task(), scanner.remove_recipe("beta")) + await asyncio.sleep(0) + cache = await scanner.get_cached_data() + + assert {item["id"] for item in cache.raw_data} == {"alpha", "gamma"} + + +async def test_update_lora_entry_updates_cache_and_file(tmp_path: Path, recipe_scanner): + scanner, stub = recipe_scanner + recipes_dir = Path(config.loras_roots[0]) / "recipes" + recipes_dir.mkdir(parents=True, exist_ok=True) + + recipe_id = "recipe-1" + recipe_path = recipes_dir / f"{recipe_id}.recipe.json" + recipe_data = { + "id": recipe_id, + "file_path": str(tmp_path / "image.png"), + "title": "Original", + "modified": 0.0, + "created_date": 0.0, + "loras": [ + {"file_name": "old", "strength": 1.0, "hash": "", "isDeleted": True, "exclude": True}, + ], + } + recipe_path.write_text(json.dumps(recipe_data)) + + await scanner.add_recipe(dict(recipe_data)) + + target_hash = "abc123" + target_info = { + "sha256": target_hash, + "file_path": str(tmp_path / "loras" / "target.safetensors"), + "preview_url": "preview.png", + "civitai": {"id": 42, "name": "v1", "model": {"name": "Target"}}, + } + stub.register_model("target", target_info) + + updated_recipe, updated_lora = await scanner.update_lora_entry( + recipe_id, + 0, + target_name="target", + target_lora=target_info, + ) + + assert updated_lora["inLibrary"] is True + assert updated_lora["localPath"] == target_info["file_path"] + assert updated_lora["hash"] == target_hash + + with recipe_path.open("r", encoding="utf-8") as file_obj: + persisted = json.load(file_obj) + + expected_fingerprint = calculate_recipe_fingerprint(persisted["loras"]) + assert persisted["fingerprint"] == expected_fingerprint + + cache = await scanner.get_cached_data() + cached_recipe = next(item for item in cache.raw_data if item["id"] == recipe_id) + assert cached_recipe["loras"][0]["hash"] == target_hash + assert cached_recipe["fingerprint"] == expected_fingerprint diff --git a/tests/services/test_recipe_services.py b/tests/services/test_recipe_services.py index e57abf2f..81a15424 100644 --- a/tests/services/test_recipe_services.py +++ b/tests/services/test_recipe_services.py @@ -108,6 +108,10 @@ async def test_save_recipe_reports_duplicates(tmp_path): self.last_fingerprint = fingerprint return ["existing"] + async def add_recipe(self, recipe_data): + self._cache.raw_data.append(recipe_data) + await self._cache.resort() + scanner = DummyScanner(tmp_path) service = RecipePersistenceService( exif_utils=exif_utils, From f7cffd2eba724d3b419f01db62fa5c28efe8c642 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Mon, 22 Sep 2025 14:15:24 +0800 Subject: [PATCH 14/24] test(recipes): add route smoke tests and docs --- README.md | 26 +++ docs/architecture/recipe_routes.md | 103 ++++++--- tests/routes/test_recipe_routes.py | 330 +++++++++++++++++++++++++++++ 3 files changed, 427 insertions(+), 32 deletions(-) create mode 100644 tests/routes/test_recipe_routes.py diff --git a/README.md b/README.md index 7e932acc..f9c1806b 100644 --- a/README.md +++ b/README.md @@ -233,6 +233,32 @@ You can now run LoRA Manager independently from ComfyUI: This standalone mode provides a lightweight option for managing your model and recipe collection without needing to run the full ComfyUI environment, making it useful even for users who primarily use other stable diffusion interfaces. +## Developer notes + +The REST layer is split into modular registrars, controllers, and handler sets +to simplify maintenance: + +* `py/routes/recipe_route_registrar.py` holds the declarative endpoint list. +* `py/routes/base_recipe_routes.py` wires shared services/templates and returns + the handler mapping consumed by `RecipeRouteRegistrar`. +* `py/routes/handlers/recipe_handlers.py` groups HTTP adapters by concern (page + rendering, listings, queries, mutations, sharing) and delegates business rules + to services in `py/services/recipes/`. + +To add a new recipe endpoint: + +1. Declare the route in `ROUTE_DEFINITIONS` with a unique handler name. +2. Implement the coroutine on the appropriate handler class or introduce a new + handler when the concern does not fit existing ones. +3. Inject additional collaborators in + `BaseRecipeRoutes._create_handler_set` (for example a new service or factory) + so the handler can access its dependencies. + +The end-to-end wiring is documented in +[`docs/architecture/recipe_routes.md`](docs/architecture/recipe_routes.md), and +the integration suite in `tests/routes/test_recipe_routes.py` smoke-tests the +primary endpoints. + --- ## Contributing diff --git a/docs/architecture/recipe_routes.md b/docs/architecture/recipe_routes.md index 28684fad..0bdb7c90 100644 --- a/docs/architecture/recipe_routes.md +++ b/docs/architecture/recipe_routes.md @@ -1,50 +1,89 @@ -# Recipe route scaffolding +# Recipe route architecture -The recipe HTTP stack is being migrated to mirror the shared model routing -architecture. The first phase extracts the registrar/controller scaffolding so -future handler sets can plug into a stable surface area. The stack now mirrors -the same separation of concerns described in -`docs/architecture/model_routes.md`: +The recipe routing stack now mirrors the modular model route design. HTTP +bindings, controller wiring, handler orchestration, and business rules live in +separate layers so new behaviours can be added without re-threading the entire +feature. The diagram below outlines the flow for a typical request: ```mermaid graph TD subgraph HTTP - A[RecipeRouteRegistrar] -->|binds| B[BaseRecipeRoutes handler owner] + A[RecipeRouteRegistrar] -->|binds| B[RecipeRoutes controller] end subgraph Application - B --> C[Recipe handler set] - C --> D[Async handlers] - D --> E[Services / scanners] + B --> C[RecipeHandlerSet] + C --> D1[Handlers] + D1 --> E1[Use cases] + E1 --> F1[Services / scanners] + end + subgraph Side Effects + F1 --> G1[Cache & fingerprint index] + F1 --> G2[Metadata files] + F1 --> G3[Temporary shares] end ``` -## Responsibilities +## Layer responsibilities | Layer | Module(s) | Responsibility | | --- | --- | --- | -| Registrar | `py/routes/recipe_route_registrar.py` | Declarative list of every recipe endpoint and helper that binds them to an `aiohttp` application. | -| Base controller | `py/routes/base_recipe_routes.py` | Lazily resolves shared services, registers the server-side i18n filter exactly once, pre-warms caches on startup, and exposes a `{handler_name: coroutine}` mapping used by the registrar. | -| Handler set (upcoming) | `py/routes/handlers/recipe_handlers.py` (planned) | Will group HTTP handlers by concern (page rendering, listings, mutations, queries, sharing) and surface them to `BaseRecipeRoutes.get_handler_owner()`. | +| Registrar | `py/routes/recipe_route_registrar.py` | Declarative list of every recipe endpoint and helper methods that bind them to an `aiohttp` application. | +| Controller | `py/routes/base_recipe_routes.py`, `py/routes/recipe_routes.py` | Lazily resolves scanners/clients from the service registry, wires shared templates/i18n, instantiates `RecipeHandlerSet`, and exposes a `{handler_name: coroutine}` mapping for the registrar. | +| Handler set | `py/routes/handlers/recipe_handlers.py` | Thin HTTP adapters grouped by concern (page view, listings, queries, mutations, sharing). They normalise responses and translate service exceptions into HTTP status codes. | +| Services & scanners | `py/services/recipes/*.py`, `py/services/recipe_scanner.py`, `py/services/service_registry.py` | Concrete business logic: metadata parsing, persistence, sharing, fingerprint/index maintenance, and cache refresh. | -`RecipeRoutes` subclasses the base controller to keep compatibility with the -existing monolithic handlers. Once the handler set is extracted the subclass -will simply provide the concrete owner returned by `get_handler_owner()`. +## Handler responsibilities & invariants -## High-level test baseline +`RecipeHandlerSet` flattens purpose-built handler objects into the callables the +registrar binds. Each handler is responsible for a narrow concern and enforces a +set of invariants before returning: -The new smoke suite in `tests/routes/test_recipe_route_scaffolding.py` -guarantees the registrar/controller contract remains intact: +| Handler | Key endpoints | Collaborators | Contracts | +| --- | --- | --- | --- | +| `RecipePageView` | `/loras/recipes` | `SettingsManager`, `server_i18n`, Jinja environment, recipe scanner getter | Template rendered with `is_initializing` flag when caches are still warming; i18n filter registered exactly once per environment instance. | +| `RecipeListingHandler` | `/api/lm/recipes`, `/api/lm/recipe/{id}` | `recipe_scanner.get_paginated_data`, `recipe_scanner.get_recipe_by_id` | Listings respect pagination and search filters; every item receives a `file_url` fallback even when metadata is incomplete; missing recipes become HTTP 404. | +| `RecipeQueryHandler` | Tag/base-model stats, syntax, LoRA lookups | Recipe scanner cache, `format_recipe_file_url` helper | Cache snapshots are reused without forcing refresh; duplicate lookups collapse groups by fingerprint; syntax lookups return helpful errors when LoRAs are absent. | +| `RecipeManagementHandler` | Save, update, reconnect, bulk delete, widget ingest | `RecipePersistenceService`, `RecipeAnalysisService`, recipe scanner | Persistence results propagate HTTP status codes; fingerprint/index updates flow through the scanner before returning; validation errors surface as HTTP 400 without touching disk. | +| `RecipeAnalysisHandler` | Uploaded/local/remote analysis | `RecipeAnalysisService`, `civitai_client`, recipe scanner | Unsupported content types map to HTTP 400; download errors (`RecipeDownloadError`) are not retried; every response includes a `loras` array for client compatibility. | +| `RecipeSharingHandler` | Share + download | `RecipeSharingService`, recipe scanner | Share responses provide a stable download URL and filename; expired shares surface as HTTP 404; downloads stream via `web.FileResponse` with attachment headers. | -* `BaseRecipeRoutes.attach_dependencies` resolves registry services only once - and protects the i18n filter from duplicate registration. -* Startup hooks are appended exactly once so cache pre-warming and dependency - resolution run during application boot. -* `BaseRecipeRoutes.to_route_mapping()` uses the handler owner as the source of - callables, enabling the upcoming handler set without touching the registrar. -* `RecipeRouteRegistrar` binds every declarative route to the aiohttp router. -* `RecipeRoutes.setup_routes` wires the registrar and startup hooks together so - future refactors can swap in the handler set without editing callers. +## Use case boundaries + +The dedicated services encapsulate long-running work so handlers stay thin. + +| Use case | Entry point | Dependencies | Guarantees | +| --- | --- | --- | --- | +| `RecipeAnalysisService` | `analyze_uploaded_image`, `analyze_remote_image`, `analyze_local_image`, `analyze_widget_metadata` | `ExifUtils`, `RecipeParserFactory`, downloader factory, optional metadata collector/processor | Normalises missing/invalid payloads into `RecipeValidationError`; generates consistent fingerprint data to keep duplicate detection stable; temporary files are cleaned up after every analysis path. | +| `RecipePersistenceService` | `save_recipe`, `delete_recipe`, `update_recipe`, `reconnect_lora`, `bulk_delete`, `save_recipe_from_widget` | `ExifUtils`, recipe scanner, card preview sizing constants | Writes images/JSON metadata atomically; updates scanner caches and hash indices before returning; recalculates fingerprints whenever LoRA assignments change. | +| `RecipeSharingService` | `share_recipe`, `prepare_download` | `tempfile`, recipe scanner | Copies originals to TTL-managed temp files; metadata lookups re-use the scanner; expired shares trigger cleanup and `RecipeNotFoundError`. | + +## Maintaining critical invariants + +* **Cache updates** – Mutations (`save`, `delete`, `bulk_delete`, `update`) call + back into the recipe scanner to mutate the in-memory cache and fingerprint + index before returning a response. Tests assert that these methods are invoked + even when stubbing persistence. +* **Fingerprint management** – `RecipePersistenceService` recomputes + fingerprints whenever LoRA metadata changes and duplicate lookups use those + fingerprints to group recipes. Handlers bubble the resulting IDs so clients + can merge duplicates without an extra fetch. +* **Metadata synchronisation** – Saving or reconnecting a recipe updates the + JSON sidecar, refreshes embedded metadata via `ExifUtils`, and instructs the + scanner to resort its cache. Sharing relies on this metadata to generate + filenames and ensure downloads stay in sync with on-disk state. + +## Extending the stack + +1. Declare the new endpoint in `ROUTE_DEFINITIONS` with a unique handler name. +2. Implement the coroutine on an existing handler or introduce a new handler + class inside `py/routes/handlers/recipe_handlers.py` when the concern does + not fit existing ones. +3. Wire additional collaborators inside + `BaseRecipeRoutes._create_handler_set` (inject new services or factories) and + expose helper getters on the handler owner if the handler needs to share + utilities. + +Integration tests in `tests/routes/test_recipe_routes.py` exercise the listing, +mutation, analysis-error, and sharing paths end-to-end, ensuring the controller +and handler wiring remains valid as new capabilities are added. -These guardrails mirror the expectations in the model route architecture and -provide confidence that future refactors can focus on handlers and use cases -without breaking HTTP wiring. diff --git a/tests/routes/test_recipe_routes.py b/tests/routes/test_recipe_routes.py new file mode 100644 index 00000000..467cb5b5 --- /dev/null +++ b/tests/routes/test_recipe_routes.py @@ -0,0 +1,330 @@ +"""Integration smoke tests for the recipe route stack.""" +from __future__ import annotations + +import json +from contextlib import asynccontextmanager +from dataclasses import dataclass +from pathlib import Path +from types import SimpleNamespace +from typing import Any, AsyncIterator, Dict, List, Optional + +from aiohttp import FormData, web +from aiohttp.test_utils import TestClient, TestServer + +from py.config import config +from py.routes import base_recipe_routes +from py.routes.recipe_routes import RecipeRoutes +from py.services.recipes import RecipeValidationError +from py.services.service_registry import ServiceRegistry + + +@dataclass +class RecipeRouteHarness: + """Container exposing the aiohttp client and stubbed collaborators.""" + + client: TestClient + scanner: "StubRecipeScanner" + analysis: "StubAnalysisService" + persistence: "StubPersistenceService" + sharing: "StubSharingService" + tmp_dir: Path + + +class StubRecipeScanner: + """Minimal scanner double with the surface used by the handlers.""" + + def __init__(self, base_dir: Path) -> None: + self.recipes_dir = str(base_dir / "recipes") + self.listing_items: List[Dict[str, Any]] = [] + self.cached_raw: List[Dict[str, Any]] = [] + self.recipes: Dict[str, Dict[str, Any]] = {} + self.removed: List[str] = [] + + async def _noop_get_cached_data(force_refresh: bool = False) -> None: # noqa: ARG001 - signature mirrors real scanner + return None + + self._lora_scanner = SimpleNamespace( # mimic BaseRecipeRoutes expectations + get_cached_data=_noop_get_cached_data, + _hash_index=SimpleNamespace(_hash_to_path={}), + ) + + async def get_cached_data(self, force_refresh: bool = False) -> SimpleNamespace: # noqa: ARG002 - flag unused by stub + return SimpleNamespace(raw_data=list(self.cached_raw)) + + async def get_paginated_data(self, **params: Any) -> Dict[str, Any]: + items = [dict(item) for item in self.listing_items] + page = int(params.get("page", 1)) + page_size = int(params.get("page_size", 20)) + return { + "items": items, + "total": len(items), + "page": page, + "page_size": page_size, + "total_pages": max(1, (len(items) + page_size - 1) // max(page_size, 1)), + } + + async def get_recipe_by_id(self, recipe_id: str) -> Optional[Dict[str, Any]]: + return self.recipes.get(recipe_id) + + async def remove_recipe(self, recipe_id: str) -> None: + self.removed.append(recipe_id) + self.recipes.pop(recipe_id, None) + + +class StubAnalysisService: + """Captures calls made by analysis routes while returning canned responses.""" + + instances: List["StubAnalysisService"] = [] + + def __init__(self, **_: Any) -> None: + self.raise_for_uploaded: Optional[Exception] = None + self.raise_for_remote: Optional[Exception] = None + self.raise_for_local: Optional[Exception] = None + self.upload_calls: List[bytes] = [] + self.remote_calls: List[Optional[str]] = [] + self.local_calls: List[Optional[str]] = [] + self.result = SimpleNamespace(payload={"loras": []}, status=200) + StubAnalysisService.instances.append(self) + + async def analyze_uploaded_image(self, *, image_bytes: bytes | None, recipe_scanner) -> SimpleNamespace: # noqa: D401 - mirrors real signature + if self.raise_for_uploaded: + raise self.raise_for_uploaded + self.upload_calls.append(image_bytes or b"") + return self.result + + async def analyze_remote_image(self, *, url: Optional[str], recipe_scanner, civitai_client) -> SimpleNamespace: # noqa: D401 + if self.raise_for_remote: + raise self.raise_for_remote + self.remote_calls.append(url) + return self.result + + async def analyze_local_image(self, *, file_path: Optional[str], recipe_scanner) -> SimpleNamespace: # noqa: D401 + if self.raise_for_local: + raise self.raise_for_local + self.local_calls.append(file_path) + return self.result + + async def analyze_widget_metadata(self, *, recipe_scanner) -> SimpleNamespace: + return SimpleNamespace(payload={"metadata": {}, "image_bytes": b""}, status=200) + + +class StubPersistenceService: + """Stub for persistence operations to avoid filesystem writes.""" + + instances: List["StubPersistenceService"] = [] + + def __init__(self, **_: Any) -> None: + self.save_calls: List[Dict[str, Any]] = [] + self.delete_calls: List[str] = [] + self.save_result = SimpleNamespace(payload={"success": True, "recipe_id": "stub-id"}, status=200) + self.delete_result = SimpleNamespace(payload={"success": True}, status=200) + StubPersistenceService.instances.append(self) + + async def save_recipe(self, *, recipe_scanner, image_bytes, image_base64, name, tags, metadata) -> SimpleNamespace: # noqa: D401 + self.save_calls.append( + { + "recipe_scanner": recipe_scanner, + "image_bytes": image_bytes, + "image_base64": image_base64, + "name": name, + "tags": list(tags), + "metadata": metadata, + } + ) + return self.save_result + + async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> SimpleNamespace: + self.delete_calls.append(recipe_id) + await recipe_scanner.remove_recipe(recipe_id) + return self.delete_result + + async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: Dict[str, Any]) -> SimpleNamespace: # pragma: no cover - unused by smoke tests + return SimpleNamespace(payload={"success": True, "recipe_id": recipe_id, "updates": updates}, status=200) + + async def reconnect_lora(self, *, recipe_scanner, recipe_id: str, lora_index: int, target_name: str) -> SimpleNamespace: # pragma: no cover + return SimpleNamespace(payload={"success": True}, status=200) + + async def bulk_delete(self, *, recipe_scanner, recipe_ids: List[str]) -> SimpleNamespace: # pragma: no cover + return SimpleNamespace(payload={"success": True, "deleted": recipe_ids}, status=200) + + async def save_recipe_from_widget(self, *, recipe_scanner, metadata: Dict[str, Any], image_bytes: bytes) -> SimpleNamespace: # pragma: no cover + return SimpleNamespace(payload={"success": True}, status=200) + + +class StubSharingService: + """Share service stub recording requests and returning canned responses.""" + + instances: List["StubSharingService"] = [] + + def __init__(self, *, ttl_seconds: int = 300, logger) -> None: # noqa: ARG002 - ttl unused in stub + self.share_calls: List[str] = [] + self.download_calls: List[str] = [] + self.share_result = SimpleNamespace( + payload={"success": True, "download_url": "/share/stub", "filename": "recipe.png"}, + status=200, + ) + self.download_info = SimpleNamespace(file_path="", download_filename="") + StubSharingService.instances.append(self) + + async def share_recipe(self, *, recipe_scanner, recipe_id: str) -> SimpleNamespace: + self.share_calls.append(recipe_id) + return self.share_result + + async def prepare_download(self, *, recipe_scanner, recipe_id: str) -> SimpleNamespace: + self.download_calls.append(recipe_id) + return self.download_info + + +@asynccontextmanager +async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRouteHarness]: + """Context manager that yields a fully wired recipe route harness.""" + + StubAnalysisService.instances.clear() + StubPersistenceService.instances.clear() + StubSharingService.instances.clear() + + scanner = StubRecipeScanner(tmp_path) + + async def fake_get_recipe_scanner(): + return scanner + + async def fake_get_civitai_client(): + return object() + + monkeypatch.setattr(ServiceRegistry, "get_recipe_scanner", fake_get_recipe_scanner) + monkeypatch.setattr(ServiceRegistry, "get_civitai_client", fake_get_civitai_client) + monkeypatch.setattr(base_recipe_routes, "RecipeAnalysisService", StubAnalysisService) + monkeypatch.setattr(base_recipe_routes, "RecipePersistenceService", StubPersistenceService) + monkeypatch.setattr(base_recipe_routes, "RecipeSharingService", StubSharingService) + monkeypatch.setattr(config, "loras_roots", [str(tmp_path)], raising=False) + + app = web.Application() + RecipeRoutes.setup_routes(app) + + server = TestServer(app) + client = TestClient(server) + await client.start_server() + + harness = RecipeRouteHarness( + client=client, + scanner=scanner, + analysis=StubAnalysisService.instances[-1], + persistence=StubPersistenceService.instances[-1], + sharing=StubSharingService.instances[-1], + tmp_dir=tmp_path, + ) + + try: + yield harness + finally: + await client.close() + StubAnalysisService.instances.clear() + StubPersistenceService.instances.clear() + StubSharingService.instances.clear() + + +async def test_list_recipes_provides_file_urls(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + recipe_path = harness.tmp_dir / "recipes" / "demo.png" + harness.scanner.listing_items = [ + { + "id": "recipe-1", + "file_path": str(recipe_path), + "title": "Demo", + "loras": [], + } + ] + harness.scanner.cached_raw = list(harness.scanner.listing_items) + + response = await harness.client.get("/api/lm/recipes") + payload = await response.json() + + assert response.status == 200 + assert payload["items"][0]["file_url"].endswith("demo.png") + assert payload["items"][0]["loras"] == [] + + +async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + form = FormData() + form.add_field("image", b"stub", filename="sample.png", content_type="image/png") + form.add_field("name", "Test Recipe") + form.add_field("tags", json.dumps(["tag-a"])) + form.add_field("metadata", json.dumps({"loras": []})) + form.add_field("image_base64", "aW1hZ2U=") + + harness.persistence.save_result = SimpleNamespace( + payload={"success": True, "recipe_id": "saved-id"}, + status=201, + ) + + save_response = await harness.client.post("/api/lm/recipes/save", data=form) + save_payload = await save_response.json() + + assert save_response.status == 201 + assert save_payload["recipe_id"] == "saved-id" + assert harness.persistence.save_calls[-1]["name"] == "Test Recipe" + + harness.persistence.delete_result = SimpleNamespace(payload={"success": True}, status=200) + + delete_response = await harness.client.delete("/api/lm/recipe/saved-id") + delete_payload = await delete_response.json() + + assert delete_response.status == 200 + assert delete_payload["success"] is True + assert harness.persistence.delete_calls == ["saved-id"] + + +async def test_analyze_uploaded_image_error_path(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + harness.analysis.raise_for_uploaded = RecipeValidationError("No image data provided") + + form = FormData() + form.add_field("image", b"", filename="empty.png", content_type="image/png") + + response = await harness.client.post("/api/lm/recipes/analyze-image", data=form) + payload = await response.json() + + assert response.status == 400 + assert payload["error"] == "No image data provided" + assert payload["loras"] == [] + + +async def test_share_and_download_recipe(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + recipe_id = "share-me" + download_path = harness.tmp_dir / "recipes" / "share.png" + download_path.parent.mkdir(parents=True, exist_ok=True) + download_path.write_bytes(b"stub") + + harness.scanner.recipes[recipe_id] = { + "id": recipe_id, + "title": "Shared", + "file_path": str(download_path), + } + + harness.sharing.share_result = SimpleNamespace( + payload={"success": True, "download_url": "/api/share", "filename": "share.png"}, + status=200, + ) + harness.sharing.download_info = SimpleNamespace( + file_path=str(download_path), + download_filename="share.png", + ) + + share_response = await harness.client.get(f"/api/lm/recipe/{recipe_id}/share") + share_payload = await share_response.json() + + assert share_response.status == 200 + assert share_payload["filename"] == "share.png" + assert harness.sharing.share_calls == [recipe_id] + + download_response = await harness.client.get(f"/api/lm/recipe/{recipe_id}/share/download") + body = await download_response.read() + + assert download_response.status == 200 + assert download_response.headers["Content-Disposition"] == 'attachment; filename="share.png"' + assert body == b"stub" + + download_path.unlink(missing_ok=True) + From e0aba6c49abe973acbf2ead5a5d5a5a8beaeb804 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Tue, 23 Sep 2025 10:41:56 +0800 Subject: [PATCH 15/24] test(example-images): add route regression harness --- tests/routes/test_example_images_routes.py | 220 +++++++++++++++++++++ 1 file changed, 220 insertions(+) create mode 100644 tests/routes/test_example_images_routes.py diff --git a/tests/routes/test_example_images_routes.py b/tests/routes/test_example_images_routes.py new file mode 100644 index 00000000..d64e1d7f --- /dev/null +++ b/tests/routes/test_example_images_routes.py @@ -0,0 +1,220 @@ +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Any, List, Tuple + +from aiohttp import web +from aiohttp.test_utils import TestClient, TestServer +import pytest + +from py.routes import example_images_routes +from py.routes.example_images_routes import ExampleImagesRoutes + + +@dataclass +class ExampleImagesHarness: + """Container exposing the aiohttp client and stubbed collaborators.""" + + client: TestClient + download_manager: Any + processor: Any + file_manager: Any + + +@asynccontextmanager +async def example_images_app(monkeypatch: pytest.MonkeyPatch) -> ExampleImagesHarness: + """Yield an ExampleImagesRoutes app wired with stubbed collaborators.""" + + class StubDownloadManager: + calls: List[Tuple[str, Any]] = [] + + @staticmethod + async def start_download(request): + payload = await request.json() + StubDownloadManager.calls.append(("start_download", payload)) + return web.json_response({"operation": "start_download", "payload": payload}) + + @staticmethod + async def get_status(request): + StubDownloadManager.calls.append(("get_status", dict(request.query))) + return web.json_response({"operation": "get_status"}) + + @staticmethod + async def pause_download(request): + StubDownloadManager.calls.append(("pause_download", None)) + return web.json_response({"operation": "pause_download"}) + + @staticmethod + async def resume_download(request): + StubDownloadManager.calls.append(("resume_download", None)) + return web.json_response({"operation": "resume_download"}) + + @staticmethod + async def start_force_download(request): + payload = await request.json() + StubDownloadManager.calls.append(("start_force_download", payload)) + return web.json_response({"operation": "start_force_download", "payload": payload}) + + class StubExampleImagesProcessor: + calls: List[Tuple[str, Any]] = [] + + @staticmethod + async def import_images(request): + payload = await request.json() + StubExampleImagesProcessor.calls.append(("import_images", payload)) + return web.json_response({"operation": "import_images", "payload": payload}) + + @staticmethod + async def delete_custom_image(request): + payload = await request.json() + StubExampleImagesProcessor.calls.append(("delete_custom_image", payload)) + return web.json_response({"operation": "delete_custom_image", "payload": payload}) + + class StubExampleImagesFileManager: + calls: List[Tuple[str, Any]] = [] + + @staticmethod + async def open_folder(request): + payload = await request.json() + StubExampleImagesFileManager.calls.append(("open_folder", payload)) + return web.json_response({"operation": "open_folder", "payload": payload}) + + @staticmethod + async def get_files(request): + StubExampleImagesFileManager.calls.append(("get_files", dict(request.query))) + return web.json_response({"operation": "get_files", "query": dict(request.query)}) + + @staticmethod + async def has_images(request): + StubExampleImagesFileManager.calls.append(("has_images", dict(request.query))) + return web.json_response({"operation": "has_images", "query": dict(request.query)}) + + monkeypatch.setattr(example_images_routes, "DownloadManager", StubDownloadManager) + monkeypatch.setattr(example_images_routes, "ExampleImagesProcessor", StubExampleImagesProcessor) + monkeypatch.setattr(example_images_routes, "ExampleImagesFileManager", StubExampleImagesFileManager) + + app = web.Application() + ExampleImagesRoutes.setup_routes(app) + + server = TestServer(app) + client = TestClient(server) + await client.start_server() + + try: + yield ExampleImagesHarness( + client=client, + download_manager=StubDownloadManager, + processor=StubExampleImagesProcessor, + file_manager=StubExampleImagesFileManager, + ) + finally: + await client.close() + + +@pytest.mark.parametrize( + "endpoint, payload", + [ + ("/api/lm/download-example-images", {"model_types": ["lora"], "optimize": False}), + ("/api/lm/force-download-example-images", {"model_hashes": ["abc123"]}), + ], +) +async def test_download_routes_delegate_to_manager(endpoint, payload, monkeypatch: pytest.MonkeyPatch): + async with example_images_app(monkeypatch) as harness: + response = await harness.client.post(endpoint, json=payload) + body = await response.json() + + assert response.status == 200 + assert body["payload"] == payload + assert body["operation"].startswith("start") + + expected_call = body["operation"], payload + assert expected_call in harness.download_manager.calls + + +async def test_status_route_returns_manager_payload(monkeypatch: pytest.MonkeyPatch): + async with example_images_app(monkeypatch) as harness: + response = await harness.client.get( + "/api/lm/example-images-status", params={"detail": "true"} + ) + body = await response.json() + + assert response.status == 200 + assert body == {"operation": "get_status"} + assert harness.download_manager.calls == [("get_status", {"detail": "true"})] + + +async def test_pause_and_resume_routes_delegate(monkeypatch: pytest.MonkeyPatch): + async with example_images_app(monkeypatch) as harness: + pause_response = await harness.client.post("/api/lm/pause-example-images") + resume_response = await harness.client.post("/api/lm/resume-example-images") + + assert pause_response.status == 200 + assert await pause_response.json() == {"operation": "pause_download"} + assert resume_response.status == 200 + assert await resume_response.json() == {"operation": "resume_download"} + + assert harness.download_manager.calls[-2:] == [ + ("pause_download", None), + ("resume_download", None), + ] + + +async def test_import_route_delegates_to_processor(monkeypatch: pytest.MonkeyPatch): + payload = {"model_hash": "abc123", "files": ["/path/image.png"]} + async with example_images_app(monkeypatch) as harness: + response = await harness.client.post( + "/api/lm/import-example-images", json=payload + ) + body = await response.json() + + assert response.status == 200 + assert body == {"operation": "import_images", "payload": payload} + assert harness.processor.calls == [("import_images", payload)] + + +async def test_delete_route_delegates_to_processor(monkeypatch: pytest.MonkeyPatch): + payload = {"model_hash": "abc123", "short_id": "xyz"} + async with example_images_app(monkeypatch) as harness: + response = await harness.client.post( + "/api/lm/delete-example-image", json=payload + ) + body = await response.json() + + assert response.status == 200 + assert body == {"operation": "delete_custom_image", "payload": payload} + assert harness.processor.calls == [("delete_custom_image", payload)] + + +async def test_file_routes_delegate_to_file_manager(monkeypatch: pytest.MonkeyPatch): + open_payload = {"model_hash": "abc123"} + files_params = {"model_hash": "def456"} + + async with example_images_app(monkeypatch) as harness: + open_response = await harness.client.post( + "/api/lm/open-example-images-folder", json=open_payload + ) + files_response = await harness.client.get( + "/api/lm/example-image-files", params=files_params + ) + has_response = await harness.client.get( + "/api/lm/has-example-images", params=files_params + ) + + assert open_response.status == 200 + assert files_response.status == 200 + assert has_response.status == 200 + + assert await open_response.json() == {"operation": "open_folder", "payload": open_payload} + assert await files_response.json() == { + "operation": "get_files", + "query": files_params, + } + assert await has_response.json() == { + "operation": "has_images", + "query": files_params, + } + + assert harness.file_manager.calls == [ + ("open_folder", open_payload), + ("get_files", files_params), + ("has_images", files_params), + ] From 613cd81152302bde5bf534f4014291a5bb02e50d Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Tue, 23 Sep 2025 11:12:05 +0800 Subject: [PATCH 16/24] refactor(routes): add registrar for example images --- py/routes/example_images_route_registrar.py | 61 +++++++++++++++++++++ py/routes/example_images_routes.py | 46 ++++++++++------ tests/routes/test_example_images_routes.py | 16 +++++- 3 files changed, 106 insertions(+), 17 deletions(-) create mode 100644 py/routes/example_images_route_registrar.py diff --git a/py/routes/example_images_route_registrar.py b/py/routes/example_images_route_registrar.py new file mode 100644 index 00000000..d0f1fab0 --- /dev/null +++ b/py/routes/example_images_route_registrar.py @@ -0,0 +1,61 @@ +"""Route registrar for example image endpoints.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Iterable, Mapping + +from aiohttp import web + + +@dataclass(frozen=True) +class RouteDefinition: + """Declarative configuration for a HTTP route.""" + + method: str + path: str + handler_name: str + + +ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( + RouteDefinition("POST", "/api/lm/download-example-images", "download_example_images"), + RouteDefinition("POST", "/api/lm/import-example-images", "import_example_images"), + RouteDefinition("GET", "/api/lm/example-images-status", "get_example_images_status"), + RouteDefinition("POST", "/api/lm/pause-example-images", "pause_example_images"), + RouteDefinition("POST", "/api/lm/resume-example-images", "resume_example_images"), + RouteDefinition("POST", "/api/lm/open-example-images-folder", "open_example_images_folder"), + RouteDefinition("GET", "/api/lm/example-image-files", "get_example_image_files"), + RouteDefinition("GET", "/api/lm/has-example-images", "has_example_images"), + RouteDefinition("POST", "/api/lm/delete-example-image", "delete_example_image"), + RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"), +) + + +class ExampleImagesRouteRegistrar: + """Bind declarative example image routes to an aiohttp router.""" + + _METHOD_MAP = { + "GET": "add_get", + "POST": "add_post", + "PUT": "add_put", + "DELETE": "add_delete", + } + + def __init__(self, app: web.Application) -> None: + self._app = app + + def register_routes( + self, + handler_lookup: Mapping[str, Callable[[web.Request], object]], + *, + definitions: Iterable[RouteDefinition] = ROUTE_DEFINITIONS, + ) -> None: + """Register each route definition using the supplied handlers.""" + + for definition in definitions: + handler = handler_lookup[definition.handler_name] + self._bind_route(definition.method, definition.path, handler) + + def _bind_route(self, method: str, path: str, handler: Callable[[web.Request], object]) -> None: + add_method_name = self._METHOD_MAP[method.upper()] + add_method = getattr(self._app.router, add_method_name) + add_method(path, handler) diff --git a/py/routes/example_images_routes.py b/py/routes/example_images_routes.py index 07cb0e71..193cfe1d 100644 --- a/py/routes/example_images_routes.py +++ b/py/routes/example_images_routes.py @@ -1,4 +1,9 @@ import logging +from typing import Callable + +from aiohttp import web + +from .example_images_route_registrar import ExampleImagesRouteRegistrar from ..utils.example_images_download_manager import DownloadManager from ..utils.example_images_processor import ExampleImagesProcessor from ..utils.example_images_file_manager import ExampleImagesFileManager @@ -6,22 +11,31 @@ from ..services.websocket_manager import ws_manager logger = logging.getLogger(__name__) + class ExampleImagesRoutes: """Routes for example images related functionality""" - + @staticmethod - def setup_routes(app): - """Register example images routes""" - app.router.add_post('/api/lm/download-example-images', ExampleImagesRoutes.download_example_images) - app.router.add_post('/api/lm/import-example-images', ExampleImagesRoutes.import_example_images) - app.router.add_get('/api/lm/example-images-status', ExampleImagesRoutes.get_example_images_status) - app.router.add_post('/api/lm/pause-example-images', ExampleImagesRoutes.pause_example_images) - app.router.add_post('/api/lm/resume-example-images', ExampleImagesRoutes.resume_example_images) - app.router.add_post('/api/lm/open-example-images-folder', ExampleImagesRoutes.open_example_images_folder) - app.router.add_get('/api/lm/example-image-files', ExampleImagesRoutes.get_example_image_files) - app.router.add_get('/api/lm/has-example-images', ExampleImagesRoutes.has_example_images) - app.router.add_post('/api/lm/delete-example-image', ExampleImagesRoutes.delete_example_image) - app.router.add_post('/api/lm/force-download-example-images', ExampleImagesRoutes.force_download_example_images) + def setup_routes(app: web.Application) -> None: + """Register example images routes using the registrar.""" + + registrar = ExampleImagesRouteRegistrar(app) + registrar.register_routes(ExampleImagesRoutes._route_mapping()) + + @staticmethod + def _route_mapping() -> dict[str, Callable[[web.Request], object]]: + return { + "download_example_images": ExampleImagesRoutes.download_example_images, + "import_example_images": ExampleImagesRoutes.import_example_images, + "get_example_images_status": ExampleImagesRoutes.get_example_images_status, + "pause_example_images": ExampleImagesRoutes.pause_example_images, + "resume_example_images": ExampleImagesRoutes.resume_example_images, + "open_example_images_folder": ExampleImagesRoutes.open_example_images_folder, + "get_example_image_files": ExampleImagesRoutes.get_example_image_files, + "has_example_images": ExampleImagesRoutes.has_example_images, + "delete_example_image": ExampleImagesRoutes.delete_example_image, + "force_download_example_images": ExampleImagesRoutes.force_download_example_images, + } @staticmethod async def download_example_images(request): @@ -42,7 +56,7 @@ class ExampleImagesRoutes: async def resume_example_images(request): """Resume the example images download""" return await DownloadManager.resume_download(request) - + @staticmethod async def open_example_images_folder(request): """Open the example images folder for a specific model""" @@ -57,7 +71,7 @@ class ExampleImagesRoutes: async def import_example_images(request): """Import local example images for a model""" return await ExampleImagesProcessor.import_images(request) - + @staticmethod async def has_example_images(request): """Check if example images folder exists and is not empty for a model""" @@ -71,4 +85,4 @@ class ExampleImagesRoutes: @staticmethod async def force_download_example_images(request): """Force download example images for specific models""" - return await DownloadManager.start_force_download(request) \ No newline at end of file + return await DownloadManager.start_force_download(request) diff --git a/tests/routes/test_example_images_routes.py b/tests/routes/test_example_images_routes.py index d64e1d7f..dac40ee9 100644 --- a/tests/routes/test_example_images_routes.py +++ b/tests/routes/test_example_images_routes.py @@ -1,6 +1,6 @@ from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import Any, List, Tuple +from typing import Any, List, Set, Tuple from aiohttp import web from aiohttp.test_utils import TestClient, TestServer @@ -8,6 +8,7 @@ import pytest from py.routes import example_images_routes from py.routes.example_images_routes import ExampleImagesRoutes +from py.routes.example_images_route_registrar import ROUTE_DEFINITIONS @dataclass @@ -110,6 +111,19 @@ async def example_images_app(monkeypatch: pytest.MonkeyPatch) -> ExampleImagesHa await client.close() +async def test_setup_routes_registers_all_definitions(monkeypatch: pytest.MonkeyPatch): + async with example_images_app(monkeypatch) as harness: + registered: Set[tuple[str, str]] = { + (route.method, route.resource.canonical) + for route in harness.client.app.router.routes() + if route.resource.canonical + } + + expected = {(definition.method, definition.path) for definition in ROUTE_DEFINITIONS} + + assert expected <= registered + + @pytest.mark.parametrize( "endpoint, payload", [ From 85f79cd8d184028908e3e95bba8ec593be4330ad Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Tue, 23 Sep 2025 11:12:08 +0800 Subject: [PATCH 17/24] refactor(routes): introduce example images controller --- py/routes/example_images_routes.py | 121 +++---- py/routes/handlers/example_images_handlers.py | 83 +++++ tests/routes/test_example_images_routes.py | 329 +++++++++++++----- 3 files changed, 371 insertions(+), 162 deletions(-) create mode 100644 py/routes/handlers/example_images_handlers.py diff --git a/py/routes/example_images_routes.py b/py/routes/example_images_routes.py index 193cfe1d..829760c2 100644 --- a/py/routes/example_images_routes.py +++ b/py/routes/example_images_routes.py @@ -1,88 +1,69 @@ +from __future__ import annotations + import logging -from typing import Callable +from typing import Callable, Mapping from aiohttp import web from .example_images_route_registrar import ExampleImagesRouteRegistrar +from .handlers.example_images_handlers import ( + ExampleImagesDownloadHandler, + ExampleImagesFileHandler, + ExampleImagesHandlerSet, + ExampleImagesManagementHandler, +) from ..utils.example_images_download_manager import DownloadManager -from ..utils.example_images_processor import ExampleImagesProcessor from ..utils.example_images_file_manager import ExampleImagesFileManager -from ..services.websocket_manager import ws_manager +from ..utils.example_images_processor import ExampleImagesProcessor logger = logging.getLogger(__name__) class ExampleImagesRoutes: - """Routes for example images related functionality""" + """Route controller for example image endpoints.""" - @staticmethod - def setup_routes(app: web.Application) -> None: - """Register example images routes using the registrar.""" + def __init__( + self, + *, + download_manager=DownloadManager, + processor=ExampleImagesProcessor, + file_manager=ExampleImagesFileManager, + ) -> None: + self._download_manager = download_manager + self._processor = processor + self._file_manager = file_manager + self._handler_set: ExampleImagesHandlerSet | None = None + self._handler_mapping: Mapping[str, Callable[[web.Request], web.StreamResponse]] | None = None + + @classmethod + def setup_routes(cls, app: web.Application) -> None: + """Register routes on the given aiohttp application using default wiring.""" + + controller = cls() + controller.register(app) + + def register(self, app: web.Application) -> None: + """Bind the controller's handlers to the aiohttp router.""" registrar = ExampleImagesRouteRegistrar(app) - registrar.register_routes(ExampleImagesRoutes._route_mapping()) + registrar.register_routes(self.to_route_mapping()) - @staticmethod - def _route_mapping() -> dict[str, Callable[[web.Request], object]]: - return { - "download_example_images": ExampleImagesRoutes.download_example_images, - "import_example_images": ExampleImagesRoutes.import_example_images, - "get_example_images_status": ExampleImagesRoutes.get_example_images_status, - "pause_example_images": ExampleImagesRoutes.pause_example_images, - "resume_example_images": ExampleImagesRoutes.resume_example_images, - "open_example_images_folder": ExampleImagesRoutes.open_example_images_folder, - "get_example_image_files": ExampleImagesRoutes.get_example_image_files, - "has_example_images": ExampleImagesRoutes.has_example_images, - "delete_example_image": ExampleImagesRoutes.delete_example_image, - "force_download_example_images": ExampleImagesRoutes.force_download_example_images, - } + def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]: + """Return the registrar-compatible mapping of handler names to callables.""" - @staticmethod - async def download_example_images(request): - """Download example images for models from Civitai""" - return await DownloadManager.start_download(request) + if self._handler_mapping is None: + handler_set = self._build_handler_set() + self._handler_set = handler_set + self._handler_mapping = handler_set.to_route_mapping() + return self._handler_mapping - @staticmethod - async def get_example_images_status(request): - """Get the current status of example images download""" - return await DownloadManager.get_status(request) - - @staticmethod - async def pause_example_images(request): - """Pause the example images download""" - return await DownloadManager.pause_download(request) - - @staticmethod - async def resume_example_images(request): - """Resume the example images download""" - return await DownloadManager.resume_download(request) - - @staticmethod - async def open_example_images_folder(request): - """Open the example images folder for a specific model""" - return await ExampleImagesFileManager.open_folder(request) - - @staticmethod - async def get_example_image_files(request): - """Get list of example image files for a specific model""" - return await ExampleImagesFileManager.get_files(request) - - @staticmethod - async def import_example_images(request): - """Import local example images for a model""" - return await ExampleImagesProcessor.import_images(request) - - @staticmethod - async def has_example_images(request): - """Check if example images folder exists and is not empty for a model""" - return await ExampleImagesFileManager.has_images(request) - - @staticmethod - async def delete_example_image(request): - """Delete a custom example image for a model""" - return await ExampleImagesProcessor.delete_custom_image(request) - - @staticmethod - async def force_download_example_images(request): - """Force download example images for specific models""" - return await DownloadManager.start_force_download(request) + def _build_handler_set(self) -> ExampleImagesHandlerSet: + logger.debug("Building ExampleImagesHandlerSet with %s, %s, %s", self._download_manager, self._processor, self._file_manager) + download_handler = ExampleImagesDownloadHandler(self._download_manager) + management_handler = ExampleImagesManagementHandler(self._processor) + file_handler = ExampleImagesFileHandler(self._file_manager) + return ExampleImagesHandlerSet( + download=download_handler, + management=management_handler, + files=file_handler, + ) diff --git a/py/routes/handlers/example_images_handlers.py b/py/routes/handlers/example_images_handlers.py new file mode 100644 index 00000000..3d960338 --- /dev/null +++ b/py/routes/handlers/example_images_handlers.py @@ -0,0 +1,83 @@ +"""Handler set for example image routes.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Mapping + +from aiohttp import web + + +class ExampleImagesDownloadHandler: + """HTTP adapters for download-related example image endpoints.""" + + def __init__(self, download_manager) -> None: + self._download_manager = download_manager + + async def download_example_images(self, request: web.Request) -> web.StreamResponse: + return await self._download_manager.start_download(request) + + async def get_example_images_status(self, request: web.Request) -> web.StreamResponse: + return await self._download_manager.get_status(request) + + async def pause_example_images(self, request: web.Request) -> web.StreamResponse: + return await self._download_manager.pause_download(request) + + async def resume_example_images(self, request: web.Request) -> web.StreamResponse: + return await self._download_manager.resume_download(request) + + async def force_download_example_images(self, request: web.Request) -> web.StreamResponse: + return await self._download_manager.start_force_download(request) + + +class ExampleImagesManagementHandler: + """HTTP adapters for import/delete endpoints.""" + + def __init__(self, processor) -> None: + self._processor = processor + + async def import_example_images(self, request: web.Request) -> web.StreamResponse: + return await self._processor.import_images(request) + + async def delete_example_image(self, request: web.Request) -> web.StreamResponse: + return await self._processor.delete_custom_image(request) + + +class ExampleImagesFileHandler: + """HTTP adapters for filesystem-centric endpoints.""" + + def __init__(self, file_manager) -> None: + self._file_manager = file_manager + + async def open_example_images_folder(self, request: web.Request) -> web.StreamResponse: + return await self._file_manager.open_folder(request) + + async def get_example_image_files(self, request: web.Request) -> web.StreamResponse: + return await self._file_manager.get_files(request) + + async def has_example_images(self, request: web.Request) -> web.StreamResponse: + return await self._file_manager.has_images(request) + + +@dataclass(frozen=True) +class ExampleImagesHandlerSet: + """Aggregate of handlers exposed to the registrar.""" + + download: ExampleImagesDownloadHandler + management: ExampleImagesManagementHandler + files: ExampleImagesFileHandler + + def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]: + """Flatten handler methods into the registrar mapping.""" + + return { + "download_example_images": self.download.download_example_images, + "get_example_images_status": self.download.get_example_images_status, + "pause_example_images": self.download.pause_example_images, + "resume_example_images": self.download.resume_example_images, + "force_download_example_images": self.download.force_download_example_images, + "import_example_images": self.management.import_example_images, + "delete_example_image": self.management.delete_example_image, + "open_example_images_folder": self.files.open_example_images_folder, + "get_example_image_files": self.files.get_example_image_files, + "has_example_images": self.files.has_example_images, + } diff --git a/tests/routes/test_example_images_routes.py b/tests/routes/test_example_images_routes.py index dac40ee9..b9806dae 100644 --- a/tests/routes/test_example_images_routes.py +++ b/tests/routes/test_example_images_routes.py @@ -1,14 +1,21 @@ +from __future__ import annotations + from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import Any, List, Set, Tuple +from typing import Any, List, Tuple from aiohttp import web from aiohttp.test_utils import TestClient, TestServer import pytest -from py.routes import example_images_routes -from py.routes.example_images_routes import ExampleImagesRoutes from py.routes.example_images_route_registrar import ROUTE_DEFINITIONS +from py.routes.example_images_routes import ExampleImagesRoutes +from py.routes.handlers.example_images_handlers import ( + ExampleImagesDownloadHandler, + ExampleImagesFileHandler, + ExampleImagesHandlerSet, + ExampleImagesManagementHandler, +) @dataclass @@ -16,85 +23,88 @@ class ExampleImagesHarness: """Container exposing the aiohttp client and stubbed collaborators.""" client: TestClient - download_manager: Any - processor: Any - file_manager: Any + download_manager: "StubDownloadManager" + processor: "StubExampleImagesProcessor" + file_manager: "StubExampleImagesFileManager" + controller: ExampleImagesRoutes + + +class StubDownloadManager: + def __init__(self) -> None: + self.calls: List[Tuple[str, Any]] = [] + + async def start_download(self, request: web.Request) -> web.StreamResponse: + payload = await request.json() + self.calls.append(("start_download", payload)) + return web.json_response({"operation": "start_download", "payload": payload}) + + async def get_status(self, request: web.Request) -> web.StreamResponse: + self.calls.append(("get_status", dict(request.query))) + return web.json_response({"operation": "get_status"}) + + async def pause_download(self, request: web.Request) -> web.StreamResponse: + self.calls.append(("pause_download", None)) + return web.json_response({"operation": "pause_download"}) + + async def resume_download(self, request: web.Request) -> web.StreamResponse: + self.calls.append(("resume_download", None)) + return web.json_response({"operation": "resume_download"}) + + async def start_force_download(self, request: web.Request) -> web.StreamResponse: + payload = await request.json() + self.calls.append(("start_force_download", payload)) + return web.json_response({"operation": "start_force_download", "payload": payload}) + + +class StubExampleImagesProcessor: + def __init__(self) -> None: + self.calls: List[Tuple[str, Any]] = [] + + async def import_images(self, request: web.Request) -> web.StreamResponse: + payload = await request.json() + self.calls.append(("import_images", payload)) + return web.json_response({"operation": "import_images", "payload": payload}) + + async def delete_custom_image(self, request: web.Request) -> web.StreamResponse: + payload = await request.json() + self.calls.append(("delete_custom_image", payload)) + return web.json_response({"operation": "delete_custom_image", "payload": payload}) + + +class StubExampleImagesFileManager: + def __init__(self) -> None: + self.calls: List[Tuple[str, Any]] = [] + + async def open_folder(self, request: web.Request) -> web.StreamResponse: + payload = await request.json() + self.calls.append(("open_folder", payload)) + return web.json_response({"operation": "open_folder", "payload": payload}) + + async def get_files(self, request: web.Request) -> web.StreamResponse: + self.calls.append(("get_files", dict(request.query))) + return web.json_response({"operation": "get_files", "query": dict(request.query)}) + + async def has_images(self, request: web.Request) -> web.StreamResponse: + self.calls.append(("has_images", dict(request.query))) + return web.json_response({"operation": "has_images", "query": dict(request.query)}) @asynccontextmanager -async def example_images_app(monkeypatch: pytest.MonkeyPatch) -> ExampleImagesHarness: +async def example_images_app() -> ExampleImagesHarness: """Yield an ExampleImagesRoutes app wired with stubbed collaborators.""" - class StubDownloadManager: - calls: List[Tuple[str, Any]] = [] + download_manager = StubDownloadManager() + processor = StubExampleImagesProcessor() + file_manager = StubExampleImagesFileManager() - @staticmethod - async def start_download(request): - payload = await request.json() - StubDownloadManager.calls.append(("start_download", payload)) - return web.json_response({"operation": "start_download", "payload": payload}) - - @staticmethod - async def get_status(request): - StubDownloadManager.calls.append(("get_status", dict(request.query))) - return web.json_response({"operation": "get_status"}) - - @staticmethod - async def pause_download(request): - StubDownloadManager.calls.append(("pause_download", None)) - return web.json_response({"operation": "pause_download"}) - - @staticmethod - async def resume_download(request): - StubDownloadManager.calls.append(("resume_download", None)) - return web.json_response({"operation": "resume_download"}) - - @staticmethod - async def start_force_download(request): - payload = await request.json() - StubDownloadManager.calls.append(("start_force_download", payload)) - return web.json_response({"operation": "start_force_download", "payload": payload}) - - class StubExampleImagesProcessor: - calls: List[Tuple[str, Any]] = [] - - @staticmethod - async def import_images(request): - payload = await request.json() - StubExampleImagesProcessor.calls.append(("import_images", payload)) - return web.json_response({"operation": "import_images", "payload": payload}) - - @staticmethod - async def delete_custom_image(request): - payload = await request.json() - StubExampleImagesProcessor.calls.append(("delete_custom_image", payload)) - return web.json_response({"operation": "delete_custom_image", "payload": payload}) - - class StubExampleImagesFileManager: - calls: List[Tuple[str, Any]] = [] - - @staticmethod - async def open_folder(request): - payload = await request.json() - StubExampleImagesFileManager.calls.append(("open_folder", payload)) - return web.json_response({"operation": "open_folder", "payload": payload}) - - @staticmethod - async def get_files(request): - StubExampleImagesFileManager.calls.append(("get_files", dict(request.query))) - return web.json_response({"operation": "get_files", "query": dict(request.query)}) - - @staticmethod - async def has_images(request): - StubExampleImagesFileManager.calls.append(("has_images", dict(request.query))) - return web.json_response({"operation": "has_images", "query": dict(request.query)}) - - monkeypatch.setattr(example_images_routes, "DownloadManager", StubDownloadManager) - monkeypatch.setattr(example_images_routes, "ExampleImagesProcessor", StubExampleImagesProcessor) - monkeypatch.setattr(example_images_routes, "ExampleImagesFileManager", StubExampleImagesFileManager) + controller = ExampleImagesRoutes( + download_manager=download_manager, + processor=processor, + file_manager=file_manager, + ) app = web.Application() - ExampleImagesRoutes.setup_routes(app) + controller.register(app) server = TestServer(app) client = TestClient(server) @@ -103,17 +113,18 @@ async def example_images_app(monkeypatch: pytest.MonkeyPatch) -> ExampleImagesHa try: yield ExampleImagesHarness( client=client, - download_manager=StubDownloadManager, - processor=StubExampleImagesProcessor, - file_manager=StubExampleImagesFileManager, + download_manager=download_manager, + processor=processor, + file_manager=file_manager, + controller=controller, ) finally: await client.close() -async def test_setup_routes_registers_all_definitions(monkeypatch: pytest.MonkeyPatch): - async with example_images_app(monkeypatch) as harness: - registered: Set[tuple[str, str]] = { +async def test_setup_routes_registers_all_definitions(): + async with example_images_app() as harness: + registered = { (route.method, route.resource.canonical) for route in harness.client.app.router.routes() if route.resource.canonical @@ -131,8 +142,8 @@ async def test_setup_routes_registers_all_definitions(monkeypatch: pytest.Monkey ("/api/lm/force-download-example-images", {"model_hashes": ["abc123"]}), ], ) -async def test_download_routes_delegate_to_manager(endpoint, payload, monkeypatch: pytest.MonkeyPatch): - async with example_images_app(monkeypatch) as harness: +async def test_download_routes_delegate_to_manager(endpoint, payload): + async with example_images_app() as harness: response = await harness.client.post(endpoint, json=payload) body = await response.json() @@ -144,8 +155,8 @@ async def test_download_routes_delegate_to_manager(endpoint, payload, monkeypatc assert expected_call in harness.download_manager.calls -async def test_status_route_returns_manager_payload(monkeypatch: pytest.MonkeyPatch): - async with example_images_app(monkeypatch) as harness: +async def test_status_route_returns_manager_payload(): + async with example_images_app() as harness: response = await harness.client.get( "/api/lm/example-images-status", params={"detail": "true"} ) @@ -156,8 +167,8 @@ async def test_status_route_returns_manager_payload(monkeypatch: pytest.MonkeyPa assert harness.download_manager.calls == [("get_status", {"detail": "true"})] -async def test_pause_and_resume_routes_delegate(monkeypatch: pytest.MonkeyPatch): - async with example_images_app(monkeypatch) as harness: +async def test_pause_and_resume_routes_delegate(): + async with example_images_app() as harness: pause_response = await harness.client.post("/api/lm/pause-example-images") resume_response = await harness.client.post("/api/lm/resume-example-images") @@ -172,9 +183,9 @@ async def test_pause_and_resume_routes_delegate(monkeypatch: pytest.MonkeyPatch) ] -async def test_import_route_delegates_to_processor(monkeypatch: pytest.MonkeyPatch): +async def test_import_route_delegates_to_processor(): payload = {"model_hash": "abc123", "files": ["/path/image.png"]} - async with example_images_app(monkeypatch) as harness: + async with example_images_app() as harness: response = await harness.client.post( "/api/lm/import-example-images", json=payload ) @@ -185,9 +196,9 @@ async def test_import_route_delegates_to_processor(monkeypatch: pytest.MonkeyPat assert harness.processor.calls == [("import_images", payload)] -async def test_delete_route_delegates_to_processor(monkeypatch: pytest.MonkeyPatch): +async def test_delete_route_delegates_to_processor(): payload = {"model_hash": "abc123", "short_id": "xyz"} - async with example_images_app(monkeypatch) as harness: + async with example_images_app() as harness: response = await harness.client.post( "/api/lm/delete-example-image", json=payload ) @@ -198,11 +209,11 @@ async def test_delete_route_delegates_to_processor(monkeypatch: pytest.MonkeyPat assert harness.processor.calls == [("delete_custom_image", payload)] -async def test_file_routes_delegate_to_file_manager(monkeypatch: pytest.MonkeyPatch): +async def test_file_routes_delegate_to_file_manager(): open_payload = {"model_hash": "abc123"} files_params = {"model_hash": "def456"} - async with example_images_app(monkeypatch) as harness: + async with example_images_app() as harness: open_response = await harness.client.post( "/api/lm/open-example-images-folder", json=open_payload ) @@ -232,3 +243,137 @@ async def test_file_routes_delegate_to_file_manager(monkeypatch: pytest.MonkeyPa ("get_files", files_params), ("has_images", files_params), ] + + +@pytest.mark.asyncio +async def test_download_handler_methods_delegate() -> None: + class Recorder: + def __init__(self) -> None: + self.calls: List[Tuple[str, Any]] = [] + + async def start_download(self, request) -> str: + self.calls.append(("start_download", request)) + return "download" + + async def get_status(self, request) -> str: + self.calls.append(("get_status", request)) + return "status" + + async def pause_download(self, request) -> str: + self.calls.append(("pause_download", request)) + return "pause" + + async def resume_download(self, request) -> str: + self.calls.append(("resume_download", request)) + return "resume" + + async def start_force_download(self, request) -> str: + self.calls.append(("start_force_download", request)) + return "force" + + recorder = Recorder() + handler = ExampleImagesDownloadHandler(recorder) + request = object() + + assert await handler.download_example_images(request) == "download" + assert await handler.get_example_images_status(request) == "status" + assert await handler.pause_example_images(request) == "pause" + assert await handler.resume_example_images(request) == "resume" + assert await handler.force_download_example_images(request) == "force" + + expected = [ + ("start_download", request), + ("get_status", request), + ("pause_download", request), + ("resume_download", request), + ("start_force_download", request), + ] + assert recorder.calls == expected + + +@pytest.mark.asyncio +async def test_management_handler_methods_delegate() -> None: + class Recorder: + def __init__(self) -> None: + self.calls: List[Tuple[str, Any]] = [] + + async def import_images(self, request) -> str: + self.calls.append(("import_images", request)) + return "import" + + async def delete_custom_image(self, request) -> str: + self.calls.append(("delete_custom_image", request)) + return "delete" + + recorder = Recorder() + handler = ExampleImagesManagementHandler(recorder) + request = object() + + assert await handler.import_example_images(request) == "import" + assert await handler.delete_example_image(request) == "delete" + assert recorder.calls == [ + ("import_images", request), + ("delete_custom_image", request), + ] + + +@pytest.mark.asyncio +async def test_file_handler_methods_delegate() -> None: + class Recorder: + def __init__(self) -> None: + self.calls: List[Tuple[str, Any]] = [] + + async def open_folder(self, request) -> str: + self.calls.append(("open_folder", request)) + return "open" + + async def get_files(self, request) -> str: + self.calls.append(("get_files", request)) + return "files" + + async def has_images(self, request) -> str: + self.calls.append(("has_images", request)) + return "has" + + recorder = Recorder() + handler = ExampleImagesFileHandler(recorder) + request = object() + + assert await handler.open_example_images_folder(request) == "open" + assert await handler.get_example_image_files(request) == "files" + assert await handler.has_example_images(request) == "has" + assert recorder.calls == [ + ("open_folder", request), + ("get_files", request), + ("has_images", request), + ] + + +def test_handler_set_route_mapping_includes_all_handlers() -> None: + download = ExampleImagesDownloadHandler(object()) + management = ExampleImagesManagementHandler(object()) + files = ExampleImagesFileHandler(object()) + handler_set = ExampleImagesHandlerSet( + download=download, + management=management, + files=files, + ) + + mapping = handler_set.to_route_mapping() + + expected_keys = { + "download_example_images", + "get_example_images_status", + "pause_example_images", + "resume_example_images", + "force_download_example_images", + "import_example_images", + "delete_example_image", + "open_example_images_folder", + "get_example_image_files", + "has_example_images", + } + + assert mapping.keys() == expected_keys + for key in expected_keys: + assert callable(mapping[key]) From aaad270822c16b9acac418d5f0161b15d49a7d29 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Tue, 23 Sep 2025 11:47:12 +0800 Subject: [PATCH 18/24] feat(example-images): add use case orchestration --- py/routes/example_images_routes.py | 10 +- py/routes/handlers/example_images_handlers.py | 80 ++++++++- py/services/use_cases/__init__.py | 12 ++ .../use_cases/example_images/__init__.py | 19 ++ .../download_example_images_use_case.py | 42 +++++ .../import_example_images_use_case.py | 86 +++++++++ py/utils/example_images_download_manager.py | 170 ++++++++---------- py/utils/example_images_processor.py | 151 +++++----------- tests/routes/test_example_images_routes.py | 148 +++++++++------ tests/services/test_use_cases.py | 126 +++++++++++++ 10 files changed, 582 insertions(+), 262 deletions(-) create mode 100644 py/services/use_cases/example_images/__init__.py create mode 100644 py/services/use_cases/example_images/download_example_images_use_case.py create mode 100644 py/services/use_cases/example_images/import_example_images_use_case.py diff --git a/py/routes/example_images_routes.py b/py/routes/example_images_routes.py index 829760c2..44effa3b 100644 --- a/py/routes/example_images_routes.py +++ b/py/routes/example_images_routes.py @@ -12,6 +12,10 @@ from .handlers.example_images_handlers import ( ExampleImagesHandlerSet, ExampleImagesManagementHandler, ) +from ..services.use_cases.example_images import ( + DownloadExampleImagesUseCase, + ImportExampleImagesUseCase, +) from ..utils.example_images_download_manager import DownloadManager from ..utils.example_images_file_manager import ExampleImagesFileManager from ..utils.example_images_processor import ExampleImagesProcessor @@ -59,8 +63,10 @@ class ExampleImagesRoutes: def _build_handler_set(self) -> ExampleImagesHandlerSet: logger.debug("Building ExampleImagesHandlerSet with %s, %s, %s", self._download_manager, self._processor, self._file_manager) - download_handler = ExampleImagesDownloadHandler(self._download_manager) - management_handler = ExampleImagesManagementHandler(self._processor) + download_use_case = DownloadExampleImagesUseCase(download_manager=self._download_manager) + download_handler = ExampleImagesDownloadHandler(download_use_case, self._download_manager) + import_use_case = ImportExampleImagesUseCase(processor=self._processor) + management_handler = ExampleImagesManagementHandler(import_use_case, self._processor) file_handler = ExampleImagesFileHandler(self._file_manager) return ExampleImagesHandlerSet( download=download_handler, diff --git a/py/routes/handlers/example_images_handlers.py b/py/routes/handlers/example_images_handlers.py index 3d960338..fd39de04 100644 --- a/py/routes/handlers/example_images_handlers.py +++ b/py/routes/handlers/example_images_handlers.py @@ -6,37 +6,101 @@ from typing import Callable, Mapping from aiohttp import web +from ...services.use_cases.example_images import ( + DownloadExampleImagesConfigurationError, + DownloadExampleImagesInProgressError, + DownloadExampleImagesUseCase, + ImportExampleImagesUseCase, + ImportExampleImagesValidationError, +) +from ...utils.example_images_download_manager import ( + DownloadConfigurationError, + DownloadInProgressError, + DownloadNotRunningError, + ExampleImagesDownloadError, +) +from ...utils.example_images_processor import ExampleImagesImportError + class ExampleImagesDownloadHandler: """HTTP adapters for download-related example image endpoints.""" - def __init__(self, download_manager) -> None: + def __init__( + self, + download_use_case: DownloadExampleImagesUseCase, + download_manager, + ) -> None: + self._download_use_case = download_use_case self._download_manager = download_manager async def download_example_images(self, request: web.Request) -> web.StreamResponse: - return await self._download_manager.start_download(request) + try: + payload = await request.json() + result = await self._download_use_case.execute(payload) + return web.json_response(result) + except DownloadExampleImagesInProgressError as exc: + response = { + 'success': False, + 'error': str(exc), + 'status': exc.progress, + } + return web.json_response(response, status=400) + except DownloadExampleImagesConfigurationError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=400) + except ExampleImagesDownloadError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=500) async def get_example_images_status(self, request: web.Request) -> web.StreamResponse: - return await self._download_manager.get_status(request) + result = await self._download_manager.get_status(request) + return web.json_response(result) async def pause_example_images(self, request: web.Request) -> web.StreamResponse: - return await self._download_manager.pause_download(request) + try: + result = await self._download_manager.pause_download(request) + return web.json_response(result) + except DownloadNotRunningError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=400) async def resume_example_images(self, request: web.Request) -> web.StreamResponse: - return await self._download_manager.resume_download(request) + try: + result = await self._download_manager.resume_download(request) + return web.json_response(result) + except DownloadNotRunningError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=400) async def force_download_example_images(self, request: web.Request) -> web.StreamResponse: - return await self._download_manager.start_force_download(request) + try: + payload = await request.json() + result = await self._download_manager.start_force_download(payload) + return web.json_response(result) + except DownloadInProgressError as exc: + response = { + 'success': False, + 'error': str(exc), + 'status': exc.progress_snapshot, + } + return web.json_response(response, status=400) + except DownloadConfigurationError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=400) + except ExampleImagesDownloadError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=500) class ExampleImagesManagementHandler: """HTTP adapters for import/delete endpoints.""" - def __init__(self, processor) -> None: + def __init__(self, import_use_case: ImportExampleImagesUseCase, processor) -> None: + self._import_use_case = import_use_case self._processor = processor async def import_example_images(self, request: web.Request) -> web.StreamResponse: - return await self._processor.import_images(request) + try: + result = await self._import_use_case.execute(request) + return web.json_response(result) + except ImportExampleImagesValidationError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=400) + except ExampleImagesImportError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=500) async def delete_example_image(self, request: web.Request) -> web.StreamResponse: return await self._processor.delete_custom_image(request) diff --git a/py/services/use_cases/__init__.py b/py/services/use_cases/__init__.py index 986f0f57..8a43318c 100644 --- a/py/services/use_cases/__init__.py +++ b/py/services/use_cases/__init__.py @@ -13,6 +13,13 @@ from .download_model_use_case import ( DownloadModelUseCase, DownloadModelValidationError, ) +from .example_images import ( + DownloadExampleImagesConfigurationError, + DownloadExampleImagesInProgressError, + DownloadExampleImagesUseCase, + ImportExampleImagesUseCase, + ImportExampleImagesValidationError, +) __all__ = [ "AutoOrganizeInProgressError", @@ -22,4 +29,9 @@ __all__ = [ "DownloadModelEarlyAccessError", "DownloadModelUseCase", "DownloadModelValidationError", + "DownloadExampleImagesConfigurationError", + "DownloadExampleImagesInProgressError", + "DownloadExampleImagesUseCase", + "ImportExampleImagesUseCase", + "ImportExampleImagesValidationError", ] diff --git a/py/services/use_cases/example_images/__init__.py b/py/services/use_cases/example_images/__init__.py new file mode 100644 index 00000000..820de618 --- /dev/null +++ b/py/services/use_cases/example_images/__init__.py @@ -0,0 +1,19 @@ +"""Example image specific use case exports.""" + +from .download_example_images_use_case import ( + DownloadExampleImagesUseCase, + DownloadExampleImagesInProgressError, + DownloadExampleImagesConfigurationError, +) +from .import_example_images_use_case import ( + ImportExampleImagesUseCase, + ImportExampleImagesValidationError, +) + +__all__ = [ + "DownloadExampleImagesUseCase", + "DownloadExampleImagesInProgressError", + "DownloadExampleImagesConfigurationError", + "ImportExampleImagesUseCase", + "ImportExampleImagesValidationError", +] diff --git a/py/services/use_cases/example_images/download_example_images_use_case.py b/py/services/use_cases/example_images/download_example_images_use_case.py new file mode 100644 index 00000000..e9a51e13 --- /dev/null +++ b/py/services/use_cases/example_images/download_example_images_use_case.py @@ -0,0 +1,42 @@ +"""Use case coordinating example image downloads.""" + +from __future__ import annotations + +from typing import Any, Dict + +from ....utils.example_images_download_manager import ( + DownloadConfigurationError, + DownloadInProgressError, + ExampleImagesDownloadError, +) + + +class DownloadExampleImagesInProgressError(RuntimeError): + """Raised when a download is already running.""" + + def __init__(self, progress: Dict[str, Any]) -> None: + super().__init__("Download already in progress") + self.progress = progress + + +class DownloadExampleImagesConfigurationError(ValueError): + """Raised when settings prevent downloads from starting.""" + + +class DownloadExampleImagesUseCase: + """Validate payloads and trigger the download manager.""" + + def __init__(self, *, download_manager) -> None: + self._download_manager = download_manager + + async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """Start a download and translate manager errors.""" + + try: + return await self._download_manager.start_download(payload) + except DownloadInProgressError as exc: + raise DownloadExampleImagesInProgressError(exc.progress_snapshot) from exc + except DownloadConfigurationError as exc: + raise DownloadExampleImagesConfigurationError(str(exc)) from exc + except ExampleImagesDownloadError: + raise diff --git a/py/services/use_cases/example_images/import_example_images_use_case.py b/py/services/use_cases/example_images/import_example_images_use_case.py new file mode 100644 index 00000000..547b2f4e --- /dev/null +++ b/py/services/use_cases/example_images/import_example_images_use_case.py @@ -0,0 +1,86 @@ +"""Use case for importing example images.""" + +from __future__ import annotations + +import os +import tempfile +from contextlib import suppress +from typing import Any, Dict, List + +from aiohttp import web + +from ....utils.example_images_processor import ( + ExampleImagesImportError, + ExampleImagesProcessor, + ExampleImagesValidationError, +) + + +class ImportExampleImagesValidationError(ValueError): + """Raised when request validation fails.""" + + +class ImportExampleImagesUseCase: + """Parse upload payloads and delegate to the processor service.""" + + def __init__(self, *, processor: ExampleImagesProcessor) -> None: + self._processor = processor + + async def execute(self, request: web.Request) -> Dict[str, Any]: + model_hash: str | None = None + files_to_import: List[str] = [] + temp_files: List[str] = [] + + try: + if request.content_type and "multipart/form-data" in request.content_type: + reader = await request.multipart() + + first_field = await reader.next() + if first_field and first_field.name == "model_hash": + model_hash = await first_field.text() + else: + # Support clients that send files first and hash later + if first_field is not None: + await self._collect_upload_file(first_field, files_to_import, temp_files) + + async for field in reader: + if field.name == "model_hash" and not model_hash: + model_hash = await field.text() + elif field.name == "files": + await self._collect_upload_file(field, files_to_import, temp_files) + else: + data = await request.json() + model_hash = data.get("model_hash") + files_to_import = list(data.get("file_paths", [])) + + result = await self._processor.import_images(model_hash, files_to_import) + return result + except ExampleImagesValidationError as exc: + raise ImportExampleImagesValidationError(str(exc)) from exc + except ExampleImagesImportError: + raise + finally: + for path in temp_files: + with suppress(Exception): + os.remove(path) + + async def _collect_upload_file( + self, + field: Any, + files_to_import: List[str], + temp_files: List[str], + ) -> None: + """Persist an uploaded file to disk and add it to the import list.""" + + filename = field.filename or "upload" + file_ext = os.path.splitext(filename)[1].lower() + + with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file: + temp_files.append(tmp_file.name) + while True: + chunk = await field.read_chunk() + if not chunk: + break + tmp_file.write(chunk) + + files_to_import.append(tmp_file.name) diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py index 842192f2..7df0c6fb 100644 --- a/py/utils/example_images_download_manager.py +++ b/py/utils/example_images_download_manager.py @@ -3,7 +3,6 @@ import os import asyncio import json import time -from aiohttp import web from ..services.service_registry import ServiceRegistry from ..utils.metadata_manager import MetadataManager from .example_images_processor import ExampleImagesProcessor @@ -12,6 +11,30 @@ from ..services.websocket_manager import ws_manager # Add this import at the to from ..services.downloader import get_downloader from ..services.settings_manager import settings + +class ExampleImagesDownloadError(RuntimeError): + """Base error for example image download operations.""" + + +class DownloadInProgressError(ExampleImagesDownloadError): + """Raised when a download is already running.""" + + def __init__(self, progress_snapshot: dict) -> None: + super().__init__("Download already in progress") + self.progress_snapshot = progress_snapshot + + +class DownloadNotRunningError(ExampleImagesDownloadError): + """Raised when pause/resume is requested without an active download.""" + + def __init__(self, message: str = "No download in progress") -> None: + super().__init__(message) + + +class DownloadConfigurationError(ExampleImagesDownloadError): + """Raised when configuration prevents starting a download.""" + + logger = logging.getLogger(__name__) # Download status tracking @@ -31,11 +54,21 @@ download_progress = { 'failed_models': set() # Track models that failed to download after metadata refresh } + +def _serialize_progress() -> dict: + """Return a JSON-serialisable snapshot of the current progress.""" + + snapshot = download_progress.copy() + snapshot['processed_models'] = list(download_progress['processed_models']) + snapshot['refreshed_models'] = list(download_progress['refreshed_models']) + snapshot['failed_models'] = list(download_progress['failed_models']) + return snapshot + class DownloadManager: """Manages downloading example images for models""" @staticmethod - async def start_download(request): + async def start_download(options: dict): """ Start downloading example images for models @@ -50,25 +83,14 @@ class DownloadManager: global download_task, is_downloading, download_progress if is_downloading: - # Create a copy for JSON serialization - response_progress = download_progress.copy() - response_progress['processed_models'] = list(download_progress['processed_models']) - response_progress['refreshed_models'] = list(download_progress['refreshed_models']) - response_progress['failed_models'] = list(download_progress['failed_models']) - - return web.json_response({ - 'success': False, - 'error': 'Download already in progress', - 'status': response_progress - }, status=400) - + raise DownloadInProgressError(_serialize_progress()) + try: - # Parse the request body - data = await request.json() + data = options or {} auto_mode = data.get('auto_mode', False) optimize = data.get('optimize', True) model_types = data.get('model_types', ['lora', 'checkpoint']) - delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds + delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds # Get output directory from settings output_dir = settings.get('example_images_path') @@ -78,15 +100,11 @@ class DownloadManager: if auto_mode: # For auto mode, just log and return success to avoid showing error toasts logger.debug(error_msg) - return web.json_response({ + return { 'success': True, 'message': 'Example images path not configured, skipping auto download' - }) - else: - return web.json_response({ - 'success': False, - 'error': error_msg - }, status=400) + } + raise DownloadConfigurationError(error_msg) # Create the output directory os.makedirs(output_dir, exist_ok=True) @@ -129,41 +147,29 @@ class DownloadManager: ) ) - # Create a copy for JSON serialization - response_progress = download_progress.copy() - response_progress['processed_models'] = list(download_progress['processed_models']) - response_progress['refreshed_models'] = list(download_progress['refreshed_models']) - response_progress['failed_models'] = list(download_progress['failed_models']) - - return web.json_response({ + return { 'success': True, 'message': 'Download started', - 'status': response_progress - }) - + 'status': _serialize_progress() + } + except Exception as e: logger.error(f"Failed to start example images download: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) + raise ExampleImagesDownloadError(str(e)) from e @staticmethod async def get_status(request): """Get the current status of example images download""" global download_progress - + # Create a copy of the progress dict with the set converted to a list for JSON serialization - response_progress = download_progress.copy() - response_progress['processed_models'] = list(download_progress['processed_models']) - response_progress['refreshed_models'] = list(download_progress['refreshed_models']) - response_progress['failed_models'] = list(download_progress['failed_models']) - - return web.json_response({ + response_progress = _serialize_progress() + + return { 'success': True, 'is_downloading': is_downloading, 'status': response_progress - }) + } @staticmethod async def pause_download(request): @@ -171,17 +177,14 @@ class DownloadManager: global download_progress if not is_downloading: - return web.json_response({ - 'success': False, - 'error': 'No download in progress' - }, status=400) - + raise DownloadNotRunningError() + download_progress['status'] = 'paused' - - return web.json_response({ + + return { 'success': True, 'message': 'Download paused' - }) + } @staticmethod async def resume_download(request): @@ -189,23 +192,19 @@ class DownloadManager: global download_progress if not is_downloading: - return web.json_response({ - 'success': False, - 'error': 'No download in progress' - }, status=400) - + raise DownloadNotRunningError() + if download_progress['status'] == 'paused': download_progress['status'] = 'running' - - return web.json_response({ + + return { 'success': True, 'message': 'Download resumed' - }) - else: - return web.json_response({ - 'success': False, - 'error': f"Download is in '{download_progress['status']}' state, cannot resume" - }, status=400) + } + + raise DownloadNotRunningError( + f"Download is in '{download_progress['status']}' state, cannot resume" + ) @staticmethod async def _download_all_example_images(output_dir, optimize, model_types, delay): @@ -432,7 +431,7 @@ class DownloadManager: logger.error(f"Failed to save progress file: {e}") @staticmethod - async def start_force_download(request): + async def start_force_download(options: dict): """ Force download example images for specific models @@ -447,33 +446,23 @@ class DownloadManager: global download_task, is_downloading, download_progress if is_downloading: - return web.json_response({ - 'success': False, - 'error': 'Download already in progress' - }, status=400) + raise DownloadInProgressError(_serialize_progress()) try: - # Parse the request body - data = await request.json() + data = options or {} model_hashes = data.get('model_hashes', []) optimize = data.get('optimize', True) model_types = data.get('model_types', ['lora', 'checkpoint']) - delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds - + delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds + if not model_hashes: - return web.json_response({ - 'success': False, - 'error': 'Missing model_hashes parameter' - }, status=400) - + raise DownloadConfigurationError('Missing model_hashes parameter') + # Get output directory from settings output_dir = settings.get('example_images_path') - + if not output_dir: - return web.json_response({ - 'success': False, - 'error': 'Example images path not configured in settings' - }, status=400) + raise DownloadConfigurationError('Example images path not configured in settings') # Create the output directory os.makedirs(output_dir, exist_ok=True) @@ -506,20 +495,17 @@ class DownloadManager: # Set download status to not downloading is_downloading = False - return web.json_response({ + return { 'success': True, 'message': 'Force download completed', 'result': result - }) + } except Exception as e: # Set download status to not downloading is_downloading = False logger.error(f"Failed during forced example images download: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) + raise ExampleImagesDownloadError(str(e)) from e @staticmethod async def _download_specific_models_example_images_sync(model_hashes, output_dir, optimize, model_types, delay): diff --git a/py/utils/example_images_processor.py b/py/utils/example_images_processor.py index f1cfd2bf..7f108ef9 100644 --- a/py/utils/example_images_processor.py +++ b/py/utils/example_images_processor.py @@ -1,7 +1,6 @@ import logging import os import re -import tempfile import random import string from aiohttp import web @@ -13,6 +12,14 @@ from ..utils.metadata_manager import MetadataManager logger = logging.getLogger(__name__) + +class ExampleImagesImportError(RuntimeError): + """Base error for example image import operations.""" + + +class ExampleImagesValidationError(ExampleImagesImportError): + """Raised when input validation fails.""" + class ExampleImagesProcessor: """Processes and manipulates example images""" @@ -299,90 +306,29 @@ class ExampleImagesProcessor: return False @staticmethod - async def import_images(request): - """ - Import local example images - - Accepts: - - multipart/form-data form with model_hash and files fields - or - - JSON request with model_hash and file_paths - - Returns: - - Success status and list of imported files - """ + async def import_images(model_hash: str, files_to_import: list[str]): + """Import local example images for a model.""" + + if not model_hash: + raise ExampleImagesValidationError('Missing model_hash parameter') + + if not files_to_import: + raise ExampleImagesValidationError('No files provided to import') + try: - model_hash = None - files_to_import = [] - temp_files_to_cleanup = [] - - # Check if it's a multipart form-data request (direct file upload) - if request.content_type and 'multipart/form-data' in request.content_type: - reader = await request.multipart() - - # First get model_hash - field = await reader.next() - if field.name == 'model_hash': - model_hash = await field.text() - - # Then process all files - while True: - field = await reader.next() - if field is None: - break - - if field.name == 'files': - # Create a temporary file with appropriate suffix for type detection - file_name = field.filename - file_ext = os.path.splitext(file_name)[1].lower() - - with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file: - temp_path = tmp_file.name - temp_files_to_cleanup.append(temp_path) # Track for cleanup - - # Write chunks to the temporary file - while True: - chunk = await field.read_chunk() - if not chunk: - break - tmp_file.write(chunk) - - # Add to the list of files to process - files_to_import.append(temp_path) - else: - # Parse JSON request (legacy method using file paths) - data = await request.json() - model_hash = data.get('model_hash') - files_to_import = data.get('file_paths', []) - - if not model_hash: - return web.json_response({ - 'success': False, - 'error': 'Missing model_hash parameter' - }, status=400) - - if not files_to_import: - return web.json_response({ - 'success': False, - 'error': 'No files provided to import' - }, status=400) - # Get example images path example_images_path = settings.get('example_images_path') if not example_images_path: - return web.json_response({ - 'success': False, - 'error': 'No example images path configured' - }, status=400) - + raise ExampleImagesValidationError('No example images path configured') + # Find the model and get current metadata lora_scanner = await ServiceRegistry.get_lora_scanner() checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() embedding_scanner = await ServiceRegistry.get_embedding_scanner() - + model_data = None scanner = None - + # Check both scanners to find the model for scan_obj in [lora_scanner, checkpoint_scanner, embedding_scanner]: cache = await scan_obj.get_cached_data() @@ -393,21 +339,20 @@ class ExampleImagesProcessor: break if model_data: break - + if not model_data: - return web.json_response({ - 'success': False, - 'error': f"Model with hash {model_hash} not found in cache" - }, status=404) - + raise ExampleImagesImportError( + f"Model with hash {model_hash} not found in cache" + ) + # Create model folder model_folder = os.path.join(example_images_path, model_hash) os.makedirs(model_folder, exist_ok=True) - + imported_files = [] errors = [] newly_imported_paths = [] - + # Process each file path for file_path in files_to_import: try: @@ -415,26 +360,26 @@ class ExampleImagesProcessor: if not os.path.isfile(file_path): errors.append(f"File not found: {file_path}") continue - + # Check if file type is supported file_ext = os.path.splitext(file_path)[1].lower() - if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or + if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']): errors.append(f"Unsupported file type: {file_path}") continue - + # Generate new filename using short ID instead of UUID short_id = ExampleImagesProcessor.generate_short_id() new_filename = f"custom_{short_id}{file_ext}" - + dest_path = os.path.join(model_folder, new_filename) - + # Copy the file import shutil shutil.copy2(file_path, dest_path) # Store both the dest_path and the short_id newly_imported_paths.append((dest_path, short_id)) - + # Add to imported files list imported_files.append({ 'name': new_filename, @@ -444,39 +389,31 @@ class ExampleImagesProcessor: }) except Exception as e: errors.append(f"Error importing {file_path}: {str(e)}") - + # Update metadata with new example images regular_images, custom_images = await MetadataUpdater.update_metadata_after_import( - model_hash, + model_hash, model_data, scanner, newly_imported_paths ) - - return web.json_response({ + + return { 'success': len(imported_files) > 0, - 'message': f'Successfully imported {len(imported_files)} files' + + 'message': f'Successfully imported {len(imported_files)} files' + (f' with {len(errors)} errors' if errors else ''), 'files': imported_files, 'errors': errors, 'regular_images': regular_images, 'custom_images': custom_images, "model_file_path": model_data.get('file_path', ''), - }) - + } + + except ExampleImagesImportError: + raise except Exception as e: logger.error(f"Failed to import example images: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - finally: - # Clean up temporary files - for temp_file in temp_files_to_cleanup: - try: - os.remove(temp_file) - except Exception as e: - logger.error(f"Failed to remove temporary file {temp_file}: {e}") + raise ExampleImagesImportError(str(e)) from e @staticmethod async def delete_custom_image(request): diff --git a/tests/routes/test_example_images_routes.py b/tests/routes/test_example_images_routes.py index b9806dae..e921e744 100644 --- a/tests/routes/test_example_images_routes.py +++ b/tests/routes/test_example_images_routes.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from contextlib import asynccontextmanager from dataclasses import dataclass from typing import Any, List, Tuple @@ -33,37 +34,35 @@ class StubDownloadManager: def __init__(self) -> None: self.calls: List[Tuple[str, Any]] = [] - async def start_download(self, request: web.Request) -> web.StreamResponse: - payload = await request.json() + async def start_download(self, payload: Any) -> dict: self.calls.append(("start_download", payload)) - return web.json_response({"operation": "start_download", "payload": payload}) + return {"operation": "start_download", "payload": payload} - async def get_status(self, request: web.Request) -> web.StreamResponse: + async def get_status(self, request: web.Request) -> dict: self.calls.append(("get_status", dict(request.query))) - return web.json_response({"operation": "get_status"}) + return {"operation": "get_status"} - async def pause_download(self, request: web.Request) -> web.StreamResponse: + async def pause_download(self, request: web.Request) -> dict: self.calls.append(("pause_download", None)) - return web.json_response({"operation": "pause_download"}) + return {"operation": "pause_download"} - async def resume_download(self, request: web.Request) -> web.StreamResponse: + async def resume_download(self, request: web.Request) -> dict: self.calls.append(("resume_download", None)) - return web.json_response({"operation": "resume_download"}) + return {"operation": "resume_download"} - async def start_force_download(self, request: web.Request) -> web.StreamResponse: - payload = await request.json() + async def start_force_download(self, payload: Any) -> dict: self.calls.append(("start_force_download", payload)) - return web.json_response({"operation": "start_force_download", "payload": payload}) + return {"operation": "start_force_download", "payload": payload} class StubExampleImagesProcessor: def __init__(self) -> None: self.calls: List[Tuple[str, Any]] = [] - async def import_images(self, request: web.Request) -> web.StreamResponse: - payload = await request.json() + async def import_images(self, model_hash: str, files: List[str]) -> dict: + payload = {"model_hash": model_hash, "file_paths": files} self.calls.append(("import_images", payload)) - return web.json_response({"operation": "import_images", "payload": payload}) + return {"operation": "import_images", "payload": payload} async def delete_custom_image(self, request: web.Request) -> web.StreamResponse: payload = await request.json() @@ -184,7 +183,7 @@ async def test_pause_and_resume_routes_delegate(): async def test_import_route_delegates_to_processor(): - payload = {"model_hash": "abc123", "files": ["/path/image.png"]} + payload = {"model_hash": "abc123", "file_paths": ["/path/image.png"]} async with example_images_app() as harness: response = await harness.client.post( "/api/lm/import-example-images", json=payload @@ -193,7 +192,8 @@ async def test_import_route_delegates_to_processor(): assert response.status == 200 assert body == {"operation": "import_images", "payload": payload} - assert harness.processor.calls == [("import_images", payload)] + expected_call = ("import_images", payload) + assert expected_call in harness.processor.calls async def test_delete_route_delegates_to_processor(): @@ -251,70 +251,91 @@ async def test_download_handler_methods_delegate() -> None: def __init__(self) -> None: self.calls: List[Tuple[str, Any]] = [] - async def start_download(self, request) -> str: - self.calls.append(("start_download", request)) - return "download" - - async def get_status(self, request) -> str: + async def get_status(self, request) -> dict: self.calls.append(("get_status", request)) - return "status" + return {"status": "ok"} - async def pause_download(self, request) -> str: + async def pause_download(self, request) -> dict: self.calls.append(("pause_download", request)) - return "pause" + return {"status": "paused"} - async def resume_download(self, request) -> str: + async def resume_download(self, request) -> dict: self.calls.append(("resume_download", request)) - return "resume" + return {"status": "running"} - async def start_force_download(self, request) -> str: - self.calls.append(("start_force_download", request)) - return "force" + async def start_force_download(self, payload) -> dict: + self.calls.append(("start_force_download", payload)) + return {"status": "force", "payload": payload} + + class StubDownloadUseCase: + def __init__(self) -> None: + self.payloads: List[Any] = [] + + async def execute(self, payload: dict) -> dict: + self.payloads.append(payload) + return {"status": "started", "payload": payload} + + class DummyRequest: + def __init__(self, payload: dict) -> None: + self._payload = payload + self.query = {} + + async def json(self) -> dict: + return self._payload recorder = Recorder() - handler = ExampleImagesDownloadHandler(recorder) - request = object() + use_case = StubDownloadUseCase() + handler = ExampleImagesDownloadHandler(use_case, recorder) + request = DummyRequest({"foo": "bar"}) - assert await handler.download_example_images(request) == "download" - assert await handler.get_example_images_status(request) == "status" - assert await handler.pause_example_images(request) == "pause" - assert await handler.resume_example_images(request) == "resume" - assert await handler.force_download_example_images(request) == "force" + download_response = await handler.download_example_images(request) + assert json.loads(download_response.text) == {"status": "started", "payload": {"foo": "bar"}} + status_response = await handler.get_example_images_status(request) + assert json.loads(status_response.text) == {"status": "ok"} + pause_response = await handler.pause_example_images(request) + assert json.loads(pause_response.text) == {"status": "paused"} + resume_response = await handler.resume_example_images(request) + assert json.loads(resume_response.text) == {"status": "running"} + force_response = await handler.force_download_example_images(request) + assert json.loads(force_response.text) == {"status": "force", "payload": {"foo": "bar"}} - expected = [ - ("start_download", request), + assert use_case.payloads == [{"foo": "bar"}] + assert recorder.calls == [ ("get_status", request), ("pause_download", request), ("resume_download", request), - ("start_force_download", request), + ("start_force_download", {"foo": "bar"}), ] - assert recorder.calls == expected @pytest.mark.asyncio async def test_management_handler_methods_delegate() -> None: + class StubImportUseCase: + def __init__(self) -> None: + self.requests: List[Any] = [] + + async def execute(self, request: Any) -> dict: + self.requests.append(request) + return {"status": "imported"} + class Recorder: def __init__(self) -> None: self.calls: List[Tuple[str, Any]] = [] - async def import_images(self, request) -> str: - self.calls.append(("import_images", request)) - return "import" - async def delete_custom_image(self, request) -> str: self.calls.append(("delete_custom_image", request)) return "delete" recorder = Recorder() - handler = ExampleImagesManagementHandler(recorder) + use_case = StubImportUseCase() + handler = ExampleImagesManagementHandler(use_case, recorder) request = object() - assert await handler.import_example_images(request) == "import" + import_response = await handler.import_example_images(request) + assert json.loads(import_response.text) == {"status": "imported"} assert await handler.delete_example_image(request) == "delete" - assert recorder.calls == [ - ("import_images", request), - ("delete_custom_image", request), - ] + assert use_case.requests == [request] + assert recorder.calls == [("delete_custom_image", request)] @pytest.mark.asyncio @@ -350,8 +371,29 @@ async def test_file_handler_methods_delegate() -> None: def test_handler_set_route_mapping_includes_all_handlers() -> None: - download = ExampleImagesDownloadHandler(object()) - management = ExampleImagesManagementHandler(object()) + class DummyUseCase: + async def execute(self, payload): + return payload + + class DummyManager: + async def get_status(self, request): + return {} + + async def pause_download(self, request): + return {} + + async def resume_download(self, request): + return {} + + async def start_force_download(self, payload): + return payload + + class DummyProcessor: + async def delete_custom_image(self, request): + return {} + + download = ExampleImagesDownloadHandler(DummyUseCase(), DummyManager()) + management = ExampleImagesManagementHandler(DummyUseCase(), DummyProcessor()) files = ExampleImagesFileHandler(object()) handler_set = ExampleImagesHandlerSet( download=download, diff --git a/tests/services/test_use_cases.py b/tests/services/test_use_cases.py index 64057fc6..cfd0f10c 100644 --- a/tests/services/test_use_cases.py +++ b/tests/services/test_use_cases.py @@ -10,9 +10,23 @@ from py_local.services.use_cases import ( AutoOrganizeInProgressError, AutoOrganizeUseCase, BulkMetadataRefreshUseCase, + DownloadExampleImagesConfigurationError, + DownloadExampleImagesInProgressError, + DownloadExampleImagesUseCase, DownloadModelEarlyAccessError, DownloadModelUseCase, DownloadModelValidationError, + ImportExampleImagesUseCase, + ImportExampleImagesValidationError, +) +from py_local.utils.example_images_download_manager import ( + DownloadConfigurationError, + DownloadInProgressError, + ExampleImagesDownloadError, +) +from py_local.utils.example_images_processor import ( + ExampleImagesImportError, + ExampleImagesValidationError, ) from tests.conftest import MockModelService, MockScanner @@ -88,6 +102,38 @@ class StubDownloadCoordinator: return {"success": True, "download_id": "abc123"} +class StubExampleImagesDownloadManager: + def __init__(self) -> None: + self.payloads: List[Dict[str, Any]] = [] + self.error: Optional[str] = None + self.progress_snapshot = {"status": "running"} + + async def start_download(self, payload: Dict[str, Any]) -> Dict[str, Any]: + self.payloads.append(payload) + if self.error == "in_progress": + raise DownloadInProgressError(self.progress_snapshot) + if self.error == "configuration": + raise DownloadConfigurationError("path missing") + if self.error == "generic": + raise ExampleImagesDownloadError("boom") + return {"success": True, "message": "ok"} + + +class StubExampleImagesProcessor: + def __init__(self) -> None: + self.calls: List[Dict[str, Any]] = [] + self.error: Optional[str] = None + self.response: Dict[str, Any] = {"success": True} + + async def import_images(self, model_hash: str, files: List[str]) -> Dict[str, Any]: + self.calls.append({"model_hash": model_hash, "files": files}) + if self.error == "validation": + raise ExampleImagesValidationError("missing") + if self.error == "generic": + raise ExampleImagesImportError("boom") + return self.response + + async def test_auto_organize_use_case_executes_with_lock() -> None: file_service = StubFileService() lock_provider = StubLockProvider() @@ -189,3 +235,83 @@ async def test_download_model_use_case_returns_result() -> None: assert result["success"] is True assert result["download_id"] == "abc123" + + +async def test_download_example_images_use_case_triggers_manager() -> None: + manager = StubExampleImagesDownloadManager() + use_case = DownloadExampleImagesUseCase(download_manager=manager) + + payload = {"optimize": True} + result = await use_case.execute(payload) + + assert manager.payloads == [payload] + assert result == {"success": True, "message": "ok"} + + +async def test_download_example_images_use_case_maps_in_progress() -> None: + manager = StubExampleImagesDownloadManager() + manager.error = "in_progress" + use_case = DownloadExampleImagesUseCase(download_manager=manager) + + with pytest.raises(DownloadExampleImagesInProgressError) as exc: + await use_case.execute({}) + + assert exc.value.progress == manager.progress_snapshot + + +async def test_download_example_images_use_case_maps_configuration() -> None: + manager = StubExampleImagesDownloadManager() + manager.error = "configuration" + use_case = DownloadExampleImagesUseCase(download_manager=manager) + + with pytest.raises(DownloadExampleImagesConfigurationError): + await use_case.execute({}) + + +async def test_download_example_images_use_case_propagates_generic_error() -> None: + manager = StubExampleImagesDownloadManager() + manager.error = "generic" + use_case = DownloadExampleImagesUseCase(download_manager=manager) + + with pytest.raises(ExampleImagesDownloadError): + await use_case.execute({}) + + +class DummyJsonRequest: + def __init__(self, payload: Dict[str, Any]) -> None: + self._payload = payload + self.content_type = "application/json" + + async def json(self) -> Dict[str, Any]: + return self._payload + + +async def test_import_example_images_use_case_delegates() -> None: + processor = StubExampleImagesProcessor() + use_case = ImportExampleImagesUseCase(processor=processor) + + request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]}) + result = await use_case.execute(request) + + assert processor.calls == [{"model_hash": "abc", "files": ["/tmp/file"]}] + assert result == {"success": True} + + +async def test_import_example_images_use_case_maps_validation_error() -> None: + processor = StubExampleImagesProcessor() + processor.error = "validation" + use_case = ImportExampleImagesUseCase(processor=processor) + request = DummyJsonRequest({"model_hash": None, "file_paths": []}) + + with pytest.raises(ImportExampleImagesValidationError): + await use_case.execute(request) + + +async def test_import_example_images_use_case_propagates_generic_error() -> None: + processor = StubExampleImagesProcessor() + processor.error = "generic" + use_case = ImportExampleImagesUseCase(processor=processor) + request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]}) + + with pytest.raises(ExampleImagesImportError): + await use_case.execute(request) From 679cfb5c69d6e83e263b8b87fc5a0fe5f294d80a Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Tue, 23 Sep 2025 13:07:11 +0800 Subject: [PATCH 19/24] refactor(example-images): encapsulate download manager state --- py/routes/example_images_routes.py | 9 +- py/utils/example_images_download_manager.py | 465 +++++++++----------- py/utils/example_images_metadata.py | 14 +- 3 files changed, 225 insertions(+), 263 deletions(-) diff --git a/py/routes/example_images_routes.py b/py/routes/example_images_routes.py index 44effa3b..d5d34218 100644 --- a/py/routes/example_images_routes.py +++ b/py/routes/example_images_routes.py @@ -16,7 +16,10 @@ from ..services.use_cases.example_images import ( DownloadExampleImagesUseCase, ImportExampleImagesUseCase, ) -from ..utils.example_images_download_manager import DownloadManager +from ..utils.example_images_download_manager import ( + DownloadManager, + get_default_download_manager, +) from ..utils.example_images_file_manager import ExampleImagesFileManager from ..utils.example_images_processor import ExampleImagesProcessor @@ -29,11 +32,11 @@ class ExampleImagesRoutes: def __init__( self, *, - download_manager=DownloadManager, + download_manager: DownloadManager | None = None, processor=ExampleImagesProcessor, file_manager=ExampleImagesFileManager, ) -> None: - self._download_manager = download_manager + self._download_manager = download_manager or get_default_download_manager() self._processor = processor self._file_manager = file_manager self._handler_set: ExampleImagesHandlerSet | None = None diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py index 7df0c6fb..e538f50a 100644 --- a/py/utils/example_images_download_manager.py +++ b/py/utils/example_images_download_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os import asyncio @@ -37,165 +39,150 @@ class DownloadConfigurationError(ExampleImagesDownloadError): logger = logging.getLogger(__name__) -# Download status tracking -download_task = None -is_downloading = False -download_progress = { - 'total': 0, - 'completed': 0, - 'current_model': '', - 'status': 'idle', # idle, running, paused, completed, error - 'errors': [], - 'last_error': None, - 'start_time': None, - 'end_time': None, - 'processed_models': set(), # Track models that have been processed - 'refreshed_models': set(), # Track models that had metadata refreshed - 'failed_models': set() # Track models that failed to download after metadata refresh -} +class _DownloadProgress(dict): + """Mutable mapping maintaining download progress with set-aware serialisation.""" -def _serialize_progress() -> dict: - """Return a JSON-serialisable snapshot of the current progress.""" + def __init__(self) -> None: + super().__init__() + self.reset() - snapshot = download_progress.copy() - snapshot['processed_models'] = list(download_progress['processed_models']) - snapshot['refreshed_models'] = list(download_progress['refreshed_models']) - snapshot['failed_models'] = list(download_progress['failed_models']) - return snapshot + def reset(self) -> None: + """Reset the progress dictionary to its initial state.""" + + self.update( + total=0, + completed=0, + current_model='', + status='idle', + errors=[], + last_error=None, + start_time=None, + end_time=None, + processed_models=set(), + refreshed_models=set(), + failed_models=set(), + ) + + def snapshot(self) -> dict: + """Return a JSON-serialisable snapshot of the current progress.""" + + snapshot = dict(self) + snapshot['processed_models'] = list(self['processed_models']) + snapshot['refreshed_models'] = list(self['refreshed_models']) + snapshot['failed_models'] = list(self['failed_models']) + return snapshot class DownloadManager: - """Manages downloading example images for models""" - - @staticmethod - async def start_download(options: dict): - """ - Start downloading example images for models - - Expects a JSON body with: - { - "optimize": true, # Whether to optimize images (default: true) - "model_types": ["lora", "checkpoint"], # Model types to process (default: both) - "delay": 1.0, # Delay between downloads to avoid rate limiting (default: 1.0) - "auto_mode": false # Flag to indicate automatic download (default: false) - } - """ - global download_task, is_downloading, download_progress - - if is_downloading: - raise DownloadInProgressError(_serialize_progress()) + """Manages downloading example images for models.""" + + def __init__(self) -> None: + self._download_task: asyncio.Task | None = None + self._is_downloading = False + self._progress = _DownloadProgress() + + async def start_download(self, options: dict): + """Start downloading example images for models.""" + + if self._is_downloading: + raise DownloadInProgressError(self._progress.snapshot()) try: data = options or {} auto_mode = data.get('auto_mode', False) optimize = data.get('optimize', True) model_types = data.get('model_types', ['lora', 'checkpoint']) - delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds - - # Get output directory from settings + delay = float(data.get('delay', 0.2)) + output_dir = settings.get('example_images_path') if not output_dir: error_msg = 'Example images path not configured in settings' if auto_mode: - # For auto mode, just log and return success to avoid showing error toasts logger.debug(error_msg) return { 'success': True, 'message': 'Example images path not configured, skipping auto download' } raise DownloadConfigurationError(error_msg) - - # Create the output directory + os.makedirs(output_dir, exist_ok=True) - - # Initialize progress tracking - download_progress['total'] = 0 - download_progress['completed'] = 0 - download_progress['current_model'] = '' - download_progress['status'] = 'running' - download_progress['errors'] = [] - download_progress['last_error'] = None - download_progress['start_time'] = time.time() - download_progress['end_time'] = None - - # Get the processed models list from a file if it exists + + self._progress.reset() + self._progress['status'] = 'running' + self._progress['start_time'] = time.time() + self._progress['end_time'] = None + progress_file = os.path.join(output_dir, '.download_progress.json') if os.path.exists(progress_file): try: with open(progress_file, 'r', encoding='utf-8') as f: saved_progress = json.load(f) - download_progress['processed_models'] = set(saved_progress.get('processed_models', [])) - download_progress['failed_models'] = set(saved_progress.get('failed_models', [])) - logger.debug(f"Loaded previous progress, {len(download_progress['processed_models'])} models already processed, {len(download_progress['failed_models'])} models marked as failed") + self._progress['processed_models'] = set(saved_progress.get('processed_models', [])) + self._progress['failed_models'] = set(saved_progress.get('failed_models', [])) + logger.debug( + "Loaded previous progress, %s models already processed, %s models marked as failed", + len(self._progress['processed_models']), + len(self._progress['failed_models']), + ) except Exception as e: logger.error(f"Failed to load progress file: {e}") - download_progress['processed_models'] = set() - download_progress['failed_models'] = set() + self._progress['processed_models'] = set() + self._progress['failed_models'] = set() else: - download_progress['processed_models'] = set() - download_progress['failed_models'] = set() - - # Start the download task - is_downloading = True - download_task = asyncio.create_task( - DownloadManager._download_all_example_images( - output_dir, - optimize, + self._progress['processed_models'] = set() + self._progress['failed_models'] = set() + + self._is_downloading = True + self._download_task = asyncio.create_task( + self._download_all_example_images( + output_dir, + optimize, model_types, delay ) ) - + return { 'success': True, 'message': 'Download started', - 'status': _serialize_progress() + 'status': self._progress.snapshot() } except Exception as e: logger.error(f"Failed to start example images download: {e}", exc_info=True) raise ExampleImagesDownloadError(str(e)) from e - @staticmethod - async def get_status(request): - """Get the current status of example images download""" - global download_progress - - # Create a copy of the progress dict with the set converted to a list for JSON serialization - response_progress = _serialize_progress() + async def get_status(self, request): + """Get the current status of example images download.""" return { 'success': True, - 'is_downloading': is_downloading, - 'status': response_progress + 'is_downloading': self._is_downloading, + 'status': self._progress.snapshot(), } - @staticmethod - async def pause_download(request): - """Pause the example images download""" - global download_progress - - if not is_downloading: + async def pause_download(self, request): + """Pause the example images download.""" + + if not self._is_downloading: raise DownloadNotRunningError() - download_progress['status'] = 'paused' + self._progress['status'] = 'paused' return { 'success': True, 'message': 'Download paused' } - @staticmethod - async def resume_download(request): - """Resume the example images download""" - global download_progress - - if not is_downloading: + async def resume_download(self, request): + """Resume the example images download.""" + + if not self._is_downloading: raise DownloadNotRunningError() - if download_progress['status'] == 'paused': - download_progress['status'] = 'running' + if self._progress['status'] == 'paused': + self._progress['status'] = 'running' return { 'success': True, @@ -203,15 +190,12 @@ class DownloadManager: } raise DownloadNotRunningError( - f"Download is in '{download_progress['status']}' state, cannot resume" + f"Download is in '{self._progress['status']}' state, cannot resume" ) - @staticmethod - async def _download_all_example_images(output_dir, optimize, model_types, delay): - """Download example images for all models""" - global is_downloading, download_progress - - # Get unified downloader + async def _download_all_example_images(self, output_dir, optimize, model_types, delay): + """Download example images for all models.""" + downloader = await get_downloader() try: @@ -239,59 +223,58 @@ class DownloadManager: all_models.append((scanner_type, model, scanner)) # Update total count - download_progress['total'] = len(all_models) - logger.debug(f"Found {download_progress['total']} models to process") + self._progress['total'] = len(all_models) + logger.debug(f"Found {self._progress['total']} models to process") # Process each model for i, (scanner_type, model, scanner) in enumerate(all_models): # Main logic for processing model is here, but actual operations are delegated to other classes - was_remote_download = await DownloadManager._process_model( - scanner_type, model, scanner, + was_remote_download = await self._process_model( + scanner_type, model, scanner, output_dir, optimize, downloader ) # Update progress - download_progress['completed'] += 1 + self._progress['completed'] += 1 # Only add delay after remote download of models, and not after processing the last model - if was_remote_download and i < len(all_models) - 1 and download_progress['status'] == 'running': + if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running': await asyncio.sleep(delay) # Mark as completed - download_progress['status'] = 'completed' - download_progress['end_time'] = time.time() - logger.debug(f"Example images download completed: {download_progress['completed']}/{download_progress['total']} models processed") + self._progress['status'] = 'completed' + self._progress['end_time'] = time.time() + logger.debug(f"Example images download completed: {self._progress['completed']}/{self._progress['total']} models processed") except Exception as e: error_msg = f"Error during example images download: {str(e)}" logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg - download_progress['status'] = 'error' - download_progress['end_time'] = time.time() + self._progress['errors'].append(error_msg) + self._progress['last_error'] = error_msg + self._progress['status'] = 'error' + self._progress['end_time'] = time.time() finally: # Save final progress to file try: - DownloadManager._save_progress(output_dir) + self._save_progress(output_dir) except Exception as e: logger.error(f"Failed to save progress file: {e}") - + # Set download status to not downloading - is_downloading = False + self._is_downloading = False + self._download_task = None - @staticmethod - async def _process_model(scanner_type, model, scanner, output_dir, optimize, downloader): - """Process a single model download""" - global download_progress - + async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader): + """Process a single model download.""" + # Check if download is paused - while download_progress['status'] == 'paused': + while self._progress['status'] == 'paused': await asyncio.sleep(1) - + # Check if download should continue - if download_progress['status'] != 'running': - logger.info(f"Download stopped: {download_progress['status']}") + if self._progress['status'] != 'running': + logger.info(f"Download stopped: {self._progress['status']}") return False # Return False to indicate no remote download happened model_hash = model.get('sha256', '').lower() @@ -301,15 +284,15 @@ class DownloadManager: try: # Update current model info - download_progress['current_model'] = f"{model_name} ({model_hash[:8]})" + self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" # Skip if already in failed models - if model_hash in download_progress['failed_models']: + if model_hash in self._progress['failed_models']: logger.debug(f"Skipping known failed model: {model_name}") return False # Skip if already processed AND directory exists with files - if model_hash in download_progress['processed_models']: + if model_hash in self._progress['processed_models']: model_dir = os.path.join(output_dir, model_hash) has_files = os.path.exists(model_dir) and any(os.listdir(model_dir)) if has_files: @@ -318,7 +301,7 @@ class DownloadManager: else: logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing") # Remove from processed models since we need to reprocess - download_progress['processed_models'].discard(model_hash) + self._progress['processed_models'].discard(model_hash) # Create model directory model_dir = os.path.join(output_dir, model_hash) @@ -334,7 +317,7 @@ class DownloadManager: await MetadataUpdater.update_metadata_from_local_examples( model_hash, model, scanner_type, scanner, model_dir ) - download_progress['processed_models'].add(model_hash) + self._progress['processed_models'].add(model_hash) return False # Return False to indicate no remote download happened # If no local images, try to download from remote @@ -346,57 +329,55 @@ class DownloadManager: ) # If metadata is stale, try to refresh it - if is_stale and model_hash not in download_progress['refreshed_models']: + if is_stale and model_hash not in self._progress['refreshed_models']: await MetadataUpdater.refresh_model_metadata( - model_hash, model_name, scanner_type, scanner + model_hash, model_name, scanner_type, scanner, self._progress ) - + # Get the updated model data updated_model = await MetadataUpdater.get_updated_model( model_hash, scanner ) - + if updated_model and updated_model.get('civitai', {}).get('images'): # Retry download with updated metadata updated_images = updated_model.get('civitai', {}).get('images', []) success, _ = await ExampleImagesProcessor.download_model_images( model_hash, model_name, updated_images, model_dir, optimize, downloader ) - - download_progress['refreshed_models'].add(model_hash) + + self._progress['refreshed_models'].add(model_hash) # Mark as processed if successful, or as failed if unsuccessful after refresh if success: - download_progress['processed_models'].add(model_hash) + self._progress['processed_models'].add(model_hash) else: # If we refreshed metadata and still failed, mark as permanently failed - if model_hash in download_progress['refreshed_models']: - download_progress['failed_models'].add(model_hash) + if model_hash in self._progress['refreshed_models']: + self._progress['failed_models'].add(model_hash) logger.info(f"Marking model {model_name} as failed after metadata refresh") return True # Return True to indicate a remote download happened else: # No civitai data or images available, mark as failed to avoid future attempts - download_progress['failed_models'].add(model_hash) + self._progress['failed_models'].add(model_hash) logger.debug(f"No civitai images available for model {model_name}, marking as failed") # Save progress periodically - if download_progress['completed'] % 10 == 0 or download_progress['completed'] == download_progress['total'] - 1: - DownloadManager._save_progress(output_dir) + if self._progress['completed'] % 10 == 0 or self._progress['completed'] == self._progress['total'] - 1: + self._save_progress(output_dir) return False # Default return if no conditions met except Exception as e: error_msg = f"Error processing model {model.get('model_name')}: {str(e)}" logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg + self._progress['errors'].append(error_msg) + self._progress['last_error'] = error_msg return False # Return False on exception - @staticmethod - def _save_progress(output_dir): - """Save download progress to file""" - global download_progress + def _save_progress(self, output_dir): + """Save download progress to file.""" try: progress_file = os.path.join(output_dir, '.download_progress.json') @@ -411,11 +392,11 @@ class DownloadManager: # Create new progress data progress_data = { - 'processed_models': list(download_progress['processed_models']), - 'refreshed_models': list(download_progress['refreshed_models']), - 'failed_models': list(download_progress['failed_models']), - 'completed': download_progress['completed'], - 'total': download_progress['total'], + 'processed_models': list(self._progress['processed_models']), + 'refreshed_models': list(self._progress['refreshed_models']), + 'failed_models': list(self._progress['failed_models']), + 'completed': self._progress['completed'], + 'total': self._progress['total'], 'last_update': time.time() } @@ -430,70 +411,46 @@ class DownloadManager: except Exception as e: logger.error(f"Failed to save progress file: {e}") - @staticmethod - async def start_force_download(options: dict): - """ - Force download example images for specific models - - Expects a JSON body with: - { - "model_hashes": ["hash1", "hash2", ...], # List of model hashes to download - "optimize": true, # Whether to optimize images (default: true) - "model_types": ["lora", "checkpoint"], # Model types to process (default: both) - "delay": 1.0 # Delay between downloads (default: 1.0) - } - """ - global download_task, is_downloading, download_progress + async def start_force_download(self, options: dict): + """Force download example images for specific models.""" - if is_downloading: - raise DownloadInProgressError(_serialize_progress()) + if self._is_downloading: + raise DownloadInProgressError(self._progress.snapshot()) try: data = options or {} model_hashes = data.get('model_hashes', []) optimize = data.get('optimize', True) model_types = data.get('model_types', ['lora', 'checkpoint']) - delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds + delay = float(data.get('delay', 0.2)) if not model_hashes: raise DownloadConfigurationError('Missing model_hashes parameter') - # Get output directory from settings output_dir = settings.get('example_images_path') if not output_dir: raise DownloadConfigurationError('Example images path not configured in settings') - - # Create the output directory + os.makedirs(output_dir, exist_ok=True) - - # Initialize progress tracking - download_progress['total'] = len(model_hashes) - download_progress['completed'] = 0 - download_progress['current_model'] = '' - download_progress['status'] = 'running' - download_progress['errors'] = [] - download_progress['last_error'] = None - download_progress['start_time'] = time.time() - download_progress['end_time'] = None - download_progress['processed_models'] = set() - download_progress['refreshed_models'] = set() - download_progress['failed_models'] = set() - # Set download status to downloading - is_downloading = True + self._progress.reset() + self._progress['total'] = len(model_hashes) + self._progress['status'] = 'running' + self._progress['start_time'] = time.time() + self._progress['end_time'] = None - # Execute the download function directly instead of creating a background task - result = await DownloadManager._download_specific_models_example_images_sync( + self._is_downloading = True + + result = await self._download_specific_models_example_images_sync( model_hashes, - output_dir, - optimize, + output_dir, + optimize, model_types, delay ) - # Set download status to not downloading - is_downloading = False + self._is_downloading = False return { 'success': True, @@ -502,17 +459,13 @@ class DownloadManager: } except Exception as e: - # Set download status to not downloading - is_downloading = False + self._is_downloading = False logger.error(f"Failed during forced example images download: {e}", exc_info=True) raise ExampleImagesDownloadError(str(e)) from e - @staticmethod - async def _download_specific_models_example_images_sync(model_hashes, output_dir, optimize, model_types, delay): - """Download example images for specific models only - synchronous version""" - global download_progress - - # Get unified downloader + async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay): + """Download example images for specific models only - synchronous version.""" + downloader = await get_downloader() try: @@ -540,14 +493,14 @@ class DownloadManager: models_to_process.append((scanner_type, model, scanner)) # Update total count based on found models - download_progress['total'] = len(models_to_process) - logger.debug(f"Found {download_progress['total']} models to process") + self._progress['total'] = len(models_to_process) + logger.debug(f"Found {self._progress['total']} models to process") # Send initial progress via WebSocket await ws_manager.broadcast({ 'type': 'example_images_progress', 'processed': 0, - 'total': download_progress['total'], + 'total': self._progress['total'], 'status': 'running', 'current_model': '' }) @@ -556,8 +509,8 @@ class DownloadManager: success_count = 0 for i, (scanner_type, model, scanner) in enumerate(models_to_process): # Force process this model regardless of previous status - was_successful = await DownloadManager._process_specific_model( - scanner_type, model, scanner, + was_successful = await self._process_specific_model( + scanner_type, model, scanner, output_dir, optimize, downloader ) @@ -565,55 +518,55 @@ class DownloadManager: success_count += 1 # Update progress - download_progress['completed'] += 1 + self._progress['completed'] += 1 # Send progress update via WebSocket await ws_manager.broadcast({ 'type': 'example_images_progress', - 'processed': download_progress['completed'], - 'total': download_progress['total'], + 'processed': self._progress['completed'], + 'total': self._progress['total'], 'status': 'running', - 'current_model': download_progress['current_model'] + 'current_model': self._progress['current_model'] }) # Only add delay after remote download, and not after processing the last model - if was_successful and i < len(models_to_process) - 1 and download_progress['status'] == 'running': + if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running': await asyncio.sleep(delay) # Mark as completed - download_progress['status'] = 'completed' - download_progress['end_time'] = time.time() - logger.debug(f"Forced example images download completed: {download_progress['completed']}/{download_progress['total']} models processed") + self._progress['status'] = 'completed' + self._progress['end_time'] = time.time() + logger.debug(f"Forced example images download completed: {self._progress['completed']}/{self._progress['total']} models processed") # Send final progress via WebSocket await ws_manager.broadcast({ 'type': 'example_images_progress', - 'processed': download_progress['completed'], - 'total': download_progress['total'], + 'processed': self._progress['completed'], + 'total': self._progress['total'], 'status': 'completed', 'current_model': '' }) return { - 'total': download_progress['total'], - 'processed': download_progress['completed'], + 'total': self._progress['total'], + 'processed': self._progress['completed'], 'successful': success_count, - 'errors': download_progress['errors'] + 'errors': self._progress['errors'] } except Exception as e: error_msg = f"Error during forced example images download: {str(e)}" logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg - download_progress['status'] = 'error' - download_progress['end_time'] = time.time() + self._progress['errors'].append(error_msg) + self._progress['last_error'] = error_msg + self._progress['status'] = 'error' + self._progress['end_time'] = time.time() # Send error status via WebSocket await ws_manager.broadcast({ 'type': 'example_images_progress', - 'processed': download_progress['completed'], - 'total': download_progress['total'], + 'processed': self._progress['completed'], + 'total': self._progress['total'], 'status': 'error', 'error': error_msg, 'current_model': '' @@ -625,18 +578,16 @@ class DownloadManager: # No need to close any sessions since we use the global downloader pass - @staticmethod - async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, downloader): - """Process a specific model for forced download, ignoring previous download status""" - global download_progress - + async def _process_specific_model(self, scanner_type, model, scanner, output_dir, optimize, downloader): + """Process a specific model for forced download, ignoring previous download status.""" + # Check if download is paused - while download_progress['status'] == 'paused': + while self._progress['status'] == 'paused': await asyncio.sleep(1) # Check if download should continue - if download_progress['status'] != 'running': - logger.info(f"Download stopped: {download_progress['status']}") + if self._progress['status'] != 'running': + logger.info(f"Download stopped: {self._progress['status']}") return False model_hash = model.get('sha256', '').lower() @@ -646,7 +597,7 @@ class DownloadManager: try: # Update current model info - download_progress['current_model'] = f"{model_name} ({model_hash[:8]})" + self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" # Create model directory model_dir = os.path.join(output_dir, model_hash) @@ -662,7 +613,7 @@ class DownloadManager: await MetadataUpdater.update_metadata_from_local_examples( model_hash, model, scanner_type, scanner, model_dir ) - download_progress['processed_models'].add(model_hash) + self._progress['processed_models'].add(model_hash) return False # Return False to indicate no remote download happened # If no local images, try to download from remote @@ -674,9 +625,9 @@ class DownloadManager: ) # If metadata is stale, try to refresh it - if is_stale and model_hash not in download_progress['refreshed_models']: + if is_stale and model_hash not in self._progress['refreshed_models']: await MetadataUpdater.refresh_model_metadata( - model_hash, model_name, scanner_type, scanner + model_hash, model_name, scanner_type, scanner, self._progress ) # Get the updated model data @@ -694,18 +645,18 @@ class DownloadManager: # Combine failed images from both attempts failed_images.extend(additional_failed_images) - download_progress['refreshed_models'].add(model_hash) + self._progress['refreshed_models'].add(model_hash) # For forced downloads, remove failed images from metadata if failed_images: # Create a copy of images excluding failed ones - await DownloadManager._remove_failed_images_from_metadata( + await self._remove_failed_images_from_metadata( model_hash, model_name, failed_images, scanner ) # Mark as processed if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones - download_progress['processed_models'].add(model_hash) + self._progress['processed_models'].add(model_hash) return True # Return True to indicate a remote download happened else: @@ -715,12 +666,11 @@ class DownloadManager: except Exception as e: error_msg = f"Error processing model {model.get('model_name')}: {str(e)}" logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg + self._progress['errors'].append(error_msg) + self._progress['last_error'] = error_msg return False # Return False on exception - @staticmethod - async def _remove_failed_images_from_metadata(model_hash, model_name, failed_images, scanner): + async def _remove_failed_images_from_metadata(self, model_hash, model_name, failed_images, scanner): """Remove failed images from model metadata""" try: # Get current model data @@ -762,4 +712,13 @@ class DownloadManager: await scanner.update_single_model_cache(file_path, file_path, model_data) except Exception as e: - logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True) \ No newline at end of file + logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True) + + +default_download_manager = DownloadManager() + + +def get_default_download_manager() -> DownloadManager: + """Return the singleton download manager used by default routes.""" + + return default_download_manager diff --git a/py/utils/example_images_metadata.py b/py/utils/example_images_metadata.py index 8820b49b..780eb43b 100644 --- a/py/utils/example_images_metadata.py +++ b/py/utils/example_images_metadata.py @@ -33,7 +33,7 @@ class MetadataUpdater: """Handles updating model metadata related to example images""" @staticmethod - async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner): + async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner, progress: dict | None = None): """Refresh model metadata from CivitAI Args: @@ -45,8 +45,6 @@ class MetadataUpdater: Returns: bool: True if metadata was successfully refreshed, False otherwise """ - from ..utils.example_images_download_manager import download_progress - try: # Find the model in the scanner cache cache = await scanner.get_cached_data() @@ -67,7 +65,8 @@ class MetadataUpdater: return False # Track that we're refreshing this model - download_progress['refreshed_models'].add(model_hash) + if progress is not None: + progress['refreshed_models'].add(model_hash) async def update_cache_func(old_path, new_path, metadata): return await scanner.update_single_model_cache(old_path, new_path, metadata) @@ -85,12 +84,13 @@ class MetadataUpdater: else: logger.warning(f"Failed to refresh metadata for {model_name}, {error}") return False - + except Exception as e: error_msg = f"Error refreshing metadata for {model_name}: {str(e)}" logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg + if progress is not None: + progress['errors'].append(error_msg) + progress['last_error'] = error_msg return False @staticmethod From 43fcce63613560b5877067b2c11ba54551b66515 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Tue, 23 Sep 2025 14:40:43 +0800 Subject: [PATCH 20/24] refactor(example-images): inject websocket manager --- py/lora_manager.py | 2 +- py/routes/example_images_routes.py | 9 +- py/utils/example_images_download_manager.py | 291 ++++++++++++-------- standalone.py | 2 +- tests/routes/test_example_images_routes.py | 12 +- 5 files changed, 191 insertions(+), 125 deletions(-) diff --git a/py/lora_manager.py b/py/lora_manager.py index 1a99d508..ed37f27d 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -166,7 +166,7 @@ class LoraManager: RecipeRoutes.setup_routes(app) UpdateRoutes.setup_routes(app) MiscRoutes.setup_routes(app) - ExampleImagesRoutes.setup_routes(app) + ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager) # Setup WebSocket routes that are shared across all model types app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) diff --git a/py/routes/example_images_routes.py b/py/routes/example_images_routes.py index d5d34218..5073410d 100644 --- a/py/routes/example_images_routes.py +++ b/py/routes/example_images_routes.py @@ -32,21 +32,24 @@ class ExampleImagesRoutes: def __init__( self, *, + ws_manager, download_manager: DownloadManager | None = None, processor=ExampleImagesProcessor, file_manager=ExampleImagesFileManager, ) -> None: - self._download_manager = download_manager or get_default_download_manager() + if ws_manager is None: + raise ValueError("ws_manager is required") + self._download_manager = download_manager or get_default_download_manager(ws_manager) self._processor = processor self._file_manager = file_manager self._handler_set: ExampleImagesHandlerSet | None = None self._handler_mapping: Mapping[str, Callable[[web.Request], web.StreamResponse]] | None = None @classmethod - def setup_routes(cls, app: web.Application) -> None: + def setup_routes(cls, app: web.Application, *, ws_manager) -> None: """Register routes on the given aiohttp application using default wiring.""" - controller = cls() + controller = cls(ws_manager=ws_manager) controller.register(app) def register(self, app: web.Application) -> None: diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py index e538f50a..9ddf03a4 100644 --- a/py/utils/example_images_download_manager.py +++ b/py/utils/example_images_download_manager.py @@ -5,11 +5,12 @@ import os import asyncio import json import time +from typing import Any, Dict + from ..services.service_registry import ServiceRegistry from ..utils.metadata_manager import MetadataManager from .example_images_processor import ExampleImagesProcessor from .example_images_metadata import MetadataUpdater -from ..services.websocket_manager import ws_manager # Add this import at the top from ..services.downloader import get_downloader from ..services.settings_manager import settings @@ -76,82 +77,90 @@ class _DownloadProgress(dict): class DownloadManager: """Manages downloading example images for models.""" - def __init__(self) -> None: + def __init__(self, *, ws_manager, state_lock: asyncio.Lock | None = None) -> None: self._download_task: asyncio.Task | None = None self._is_downloading = False self._progress = _DownloadProgress() + self._ws_manager = ws_manager + self._state_lock = state_lock or asyncio.Lock() async def start_download(self, options: dict): """Start downloading example images for models.""" - if self._is_downloading: - raise DownloadInProgressError(self._progress.snapshot()) + async with self._state_lock: + if self._is_downloading: + raise DownloadInProgressError(self._progress.snapshot()) - try: - data = options or {} - auto_mode = data.get('auto_mode', False) - optimize = data.get('optimize', True) - model_types = data.get('model_types', ['lora', 'checkpoint']) - delay = float(data.get('delay', 0.2)) + try: + data = options or {} + auto_mode = data.get('auto_mode', False) + optimize = data.get('optimize', True) + model_types = data.get('model_types', ['lora', 'checkpoint']) + delay = float(data.get('delay', 0.2)) - output_dir = settings.get('example_images_path') + output_dir = settings.get('example_images_path') - if not output_dir: - error_msg = 'Example images path not configured in settings' - if auto_mode: - logger.debug(error_msg) - return { - 'success': True, - 'message': 'Example images path not configured, skipping auto download' - } - raise DownloadConfigurationError(error_msg) + if not output_dir: + error_msg = 'Example images path not configured in settings' + if auto_mode: + logger.debug(error_msg) + return { + 'success': True, + 'message': 'Example images path not configured, skipping auto download' + } + raise DownloadConfigurationError(error_msg) - os.makedirs(output_dir, exist_ok=True) + os.makedirs(output_dir, exist_ok=True) - self._progress.reset() - self._progress['status'] = 'running' - self._progress['start_time'] = time.time() - self._progress['end_time'] = None + self._progress.reset() + self._progress['status'] = 'running' + self._progress['start_time'] = time.time() + self._progress['end_time'] = None - progress_file = os.path.join(output_dir, '.download_progress.json') - if os.path.exists(progress_file): - try: - with open(progress_file, 'r', encoding='utf-8') as f: - saved_progress = json.load(f) - self._progress['processed_models'] = set(saved_progress.get('processed_models', [])) - self._progress['failed_models'] = set(saved_progress.get('failed_models', [])) - logger.debug( - "Loaded previous progress, %s models already processed, %s models marked as failed", - len(self._progress['processed_models']), - len(self._progress['failed_models']), - ) - except Exception as e: - logger.error(f"Failed to load progress file: {e}") + progress_file = os.path.join(output_dir, '.download_progress.json') + if os.path.exists(progress_file): + try: + with open(progress_file, 'r', encoding='utf-8') as f: + saved_progress = json.load(f) + self._progress['processed_models'] = set(saved_progress.get('processed_models', [])) + self._progress['failed_models'] = set(saved_progress.get('failed_models', [])) + logger.debug( + "Loaded previous progress, %s models already processed, %s models marked as failed", + len(self._progress['processed_models']), + len(self._progress['failed_models']), + ) + except Exception as e: + logger.error(f"Failed to load progress file: {e}") + self._progress['processed_models'] = set() + self._progress['failed_models'] = set() + else: self._progress['processed_models'] = set() self._progress['failed_models'] = set() - else: - self._progress['processed_models'] = set() - self._progress['failed_models'] = set() - self._is_downloading = True - self._download_task = asyncio.create_task( - self._download_all_example_images( - output_dir, - optimize, - model_types, - delay + self._is_downloading = True + self._download_task = asyncio.create_task( + self._download_all_example_images( + output_dir, + optimize, + model_types, + delay + ) ) - ) - return { - 'success': True, - 'message': 'Download started', - 'status': self._progress.snapshot() - } + snapshot = self._progress.snapshot() + except Exception as e: + self._is_downloading = False + self._download_task = None + logger.error(f"Failed to start example images download: {e}", exc_info=True) + raise ExampleImagesDownloadError(str(e)) from e - except Exception as e: - logger.error(f"Failed to start example images download: {e}", exc_info=True) - raise ExampleImagesDownloadError(str(e)) from e + await self._broadcast_progress(status='running') + + return { + 'success': True, + 'message': 'Download started', + 'status': snapshot + } async def get_status(self, request): """Get the current status of example images download.""" @@ -165,10 +174,13 @@ class DownloadManager: async def pause_download(self, request): """Pause the example images download.""" - if not self._is_downloading: - raise DownloadNotRunningError() + async with self._state_lock: + if not self._is_downloading: + raise DownloadNotRunningError() - self._progress['status'] = 'paused' + self._progress['status'] = 'paused' + + await self._broadcast_progress(status='paused') return { 'success': True, @@ -178,20 +190,23 @@ class DownloadManager: async def resume_download(self, request): """Resume the example images download.""" - if not self._is_downloading: - raise DownloadNotRunningError() + async with self._state_lock: + if not self._is_downloading: + raise DownloadNotRunningError() - if self._progress['status'] == 'paused': - self._progress['status'] = 'running' + if self._progress['status'] == 'paused': + self._progress['status'] = 'running' + else: + raise DownloadNotRunningError( + f"Download is in '{self._progress['status']}' state, cannot resume" + ) - return { - 'success': True, - 'message': 'Download resumed' - } + await self._broadcast_progress(status='running') - raise DownloadNotRunningError( - f"Download is in '{self._progress['status']}' state, cannot resume" - ) + return { + 'success': True, + 'message': 'Download resumed' + } async def _download_all_example_images(self, output_dir, optimize, model_types, delay): """Download example images for all models.""" @@ -225,6 +240,7 @@ class DownloadManager: # Update total count self._progress['total'] = len(all_models) logger.debug(f"Found {self._progress['total']} models to process") + await self._broadcast_progress(status='running') # Process each model for i, (scanner_type, model, scanner) in enumerate(all_models): @@ -236,6 +252,7 @@ class DownloadManager: # Update progress self._progress['completed'] += 1 + await self._broadcast_progress(status='running') # Only add delay after remote download of models, and not after processing the last model if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running': @@ -244,8 +261,13 @@ class DownloadManager: # Mark as completed self._progress['status'] = 'completed' self._progress['end_time'] = time.time() - logger.debug(f"Example images download completed: {self._progress['completed']}/{self._progress['total']} models processed") - + logger.debug( + "Example images download completed: %s/%s models processed", + self._progress['completed'], + self._progress['total'], + ) + await self._broadcast_progress(status='completed') + except Exception as e: error_msg = f"Error during example images download: {str(e)}" logger.error(error_msg, exc_info=True) @@ -253,7 +275,8 @@ class DownloadManager: self._progress['last_error'] = error_msg self._progress['status'] = 'error' self._progress['end_time'] = time.time() - + await self._broadcast_progress(status='error', extra={'error': error_msg}) + finally: # Save final progress to file try: @@ -262,8 +285,9 @@ class DownloadManager: logger.error(f"Failed to save progress file: {e}") # Set download status to not downloading - self._is_downloading = False - self._download_task = None + async with self._state_lock: + self._is_downloading = False + self._download_task = None async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader): """Process a single model download.""" @@ -285,6 +309,7 @@ class DownloadManager: try: # Update current model info self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" + await self._broadcast_progress(status='running') # Skip if already in failed models if model_hash in self._progress['failed_models']: @@ -414,10 +439,10 @@ class DownloadManager: async def start_force_download(self, options: dict): """Force download example images for specific models.""" - if self._is_downloading: - raise DownloadInProgressError(self._progress.snapshot()) + async with self._state_lock: + if self._is_downloading: + raise DownloadInProgressError(self._progress.snapshot()) - try: data = options or {} model_hashes = data.get('model_hashes', []) optimize = data.get('optimize', True) @@ -442,6 +467,9 @@ class DownloadManager: self._is_downloading = True + await self._broadcast_progress(status='running') + + try: result = await self._download_specific_models_example_images_sync( model_hashes, output_dir, @@ -450,7 +478,8 @@ class DownloadManager: delay ) - self._is_downloading = False + async with self._state_lock: + self._is_downloading = False return { 'success': True, @@ -459,8 +488,10 @@ class DownloadManager: } except Exception as e: - self._is_downloading = False + async with self._state_lock: + self._is_downloading = False logger.error(f"Failed during forced example images download: {e}", exc_info=True) + await self._broadcast_progress(status='error', extra={'error': str(e)}) raise ExampleImagesDownloadError(str(e)) from e async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay): @@ -495,15 +526,9 @@ class DownloadManager: # Update total count based on found models self._progress['total'] = len(models_to_process) logger.debug(f"Found {self._progress['total']} models to process") - + # Send initial progress via WebSocket - await ws_manager.broadcast({ - 'type': 'example_images_progress', - 'processed': 0, - 'total': self._progress['total'], - 'status': 'running', - 'current_model': '' - }) + await self._broadcast_progress(status='running') # Process each model success_count = 0 @@ -519,15 +544,9 @@ class DownloadManager: # Update progress self._progress['completed'] += 1 - + # Send progress update via WebSocket - await ws_manager.broadcast({ - 'type': 'example_images_progress', - 'processed': self._progress['completed'], - 'total': self._progress['total'], - 'status': 'running', - 'current_model': self._progress['current_model'] - }) + await self._broadcast_progress(status='running') # Only add delay after remote download, and not after processing the last model if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running': @@ -536,16 +555,14 @@ class DownloadManager: # Mark as completed self._progress['status'] = 'completed' self._progress['end_time'] = time.time() - logger.debug(f"Forced example images download completed: {self._progress['completed']}/{self._progress['total']} models processed") - + logger.debug( + "Forced example images download completed: %s/%s models processed", + self._progress['completed'], + self._progress['total'], + ) + # Send final progress via WebSocket - await ws_manager.broadcast({ - 'type': 'example_images_progress', - 'processed': self._progress['completed'], - 'total': self._progress['total'], - 'status': 'completed', - 'current_model': '' - }) + await self._broadcast_progress(status='completed') return { 'total': self._progress['total'], @@ -561,16 +578,9 @@ class DownloadManager: self._progress['last_error'] = error_msg self._progress['status'] = 'error' self._progress['end_time'] = time.time() - + # Send error status via WebSocket - await ws_manager.broadcast({ - 'type': 'example_images_progress', - 'processed': self._progress['completed'], - 'total': self._progress['total'], - 'status': 'error', - 'error': error_msg, - 'current_model': '' - }) + await self._broadcast_progress(status='error', extra={'error': error_msg}) raise @@ -598,6 +608,7 @@ class DownloadManager: try: # Update current model info self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" + await self._broadcast_progress(status='running') # Create model directory model_dir = os.path.join(output_dir, model_hash) @@ -714,11 +725,53 @@ class DownloadManager: except Exception as e: logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True) + async def _broadcast_progress( + self, + *, + status: str | None = None, + extra: Dict[str, Any] | None = None, + ) -> None: + payload = self._build_progress_payload(status=status, extra=extra) + try: + await self._ws_manager.broadcast(payload) + except Exception as exc: # pragma: no cover - defensive logging + logger.warning("Failed to broadcast example image progress: %s", exc) -default_download_manager = DownloadManager() + def _build_progress_payload( + self, + *, + status: str | None = None, + extra: Dict[str, Any] | None = None, + ) -> Dict[str, Any]: + payload: Dict[str, Any] = { + 'type': 'example_images_progress', + 'processed': self._progress['completed'], + 'total': self._progress['total'], + 'status': status or self._progress['status'], + 'current_model': self._progress['current_model'], + } + + if self._progress['errors']: + payload['errors'] = list(self._progress['errors']) + if self._progress['last_error']: + payload['last_error'] = self._progress['last_error'] + + if extra: + payload.update(extra) + + return payload -def get_default_download_manager() -> DownloadManager: +_default_download_manager: DownloadManager | None = None + + +def get_default_download_manager(ws_manager) -> DownloadManager: """Return the singleton download manager used by default routes.""" - return default_download_manager + global _default_download_manager + if ( + _default_download_manager is None + or getattr(_default_download_manager, "_ws_manager", None) is not ws_manager + ): + _default_download_manager = DownloadManager(ws_manager=ws_manager) + return _default_download_manager diff --git a/standalone.py b/standalone.py index a6259851..95c45ca7 100644 --- a/standalone.py +++ b/standalone.py @@ -421,7 +421,7 @@ class StandaloneLoraManager(LoraManager): RecipeRoutes.setup_routes(app) UpdateRoutes.setup_routes(app) MiscRoutes.setup_routes(app) - ExampleImagesRoutes.setup_routes(app) + ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager) # Setup WebSocket routes that are shared across all model types app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) diff --git a/tests/routes/test_example_images_routes.py b/tests/routes/test_example_images_routes.py index e921e744..9a316499 100644 --- a/tests/routes/test_example_images_routes.py +++ b/tests/routes/test_example_images_routes.py @@ -3,7 +3,7 @@ from __future__ import annotations import json from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import Any, List, Tuple +from typing import Any, Dict, List, Tuple from aiohttp import web from aiohttp.test_utils import TestClient, TestServer @@ -88,6 +88,14 @@ class StubExampleImagesFileManager: return web.json_response({"operation": "has_images", "query": dict(request.query)}) +class StubWebSocketManager: + def __init__(self) -> None: + self.broadcast_calls: List[Dict[str, Any]] = [] + + async def broadcast(self, payload: Dict[str, Any]) -> None: + self.broadcast_calls.append(payload) + + @asynccontextmanager async def example_images_app() -> ExampleImagesHarness: """Yield an ExampleImagesRoutes app wired with stubbed collaborators.""" @@ -95,8 +103,10 @@ async def example_images_app() -> ExampleImagesHarness: download_manager = StubDownloadManager() processor = StubExampleImagesProcessor() file_manager = StubExampleImagesFileManager() + ws_manager = StubWebSocketManager() controller = ExampleImagesRoutes( + ws_manager=ws_manager, download_manager=download_manager, processor=processor, file_manager=file_manager, From e128c80eb12b976a6bb272828cccc84338163e68 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Tue, 23 Sep 2025 14:58:35 +0800 Subject: [PATCH 21/24] test(services): add async example image download tests --- ...t_example_images_download_manager_async.py | 228 ++++++++++++++++++ 1 file changed, 228 insertions(+) create mode 100644 tests/services/test_example_images_download_manager_async.py diff --git a/tests/services/test_example_images_download_manager_async.py b/tests/services/test_example_images_download_manager_async.py new file mode 100644 index 00000000..7eef56fb --- /dev/null +++ b/tests/services/test_example_images_download_manager_async.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import asyncio +from types import SimpleNamespace + +import pytest + +from py.services.settings_manager import settings +from py.utils import example_images_download_manager as download_module + + +class RecordingWebSocketManager: + """Collects broadcast payloads for assertions.""" + + def __init__(self) -> None: + self.payloads: list[dict] = [] + + async def broadcast(self, payload: dict) -> None: + self.payloads.append(payload) + + +class StubScanner: + """Scanner double returning predetermined cache contents.""" + + def __init__(self, models: list[dict]) -> None: + self._cache = SimpleNamespace(raw_data=models) + + async def get_cached_data(self): + return self._cache + + +def _patch_scanner(monkeypatch: pytest.MonkeyPatch, scanner: StubScanner) -> None: + async def _get_lora_scanner(cls): + return scanner + + monkeypatch.setattr( + download_module.ServiceRegistry, + "get_lora_scanner", + classmethod(_get_lora_scanner), + ) + + +@pytest.mark.usefixtures("tmp_path") +async def test_start_download_rejects_parallel_runs(monkeypatch: pytest.MonkeyPatch, tmp_path): + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path)) + + model = { + "sha256": "abc123", + "model_name": "Example", + "file_path": str(tmp_path / "example.safetensors"), + "file_name": "example.safetensors", + } + _patch_scanner(monkeypatch, StubScanner([model])) + + started = asyncio.Event() + release = asyncio.Event() + + async def fake_process_local_examples(*_args, **_kwargs): + started.set() + await release.wait() + return True + + async def fake_update_metadata(*_args, **_kwargs): + return True + + async def fake_get_downloader(): + return object() + + monkeypatch.setattr( + download_module.ExampleImagesProcessor, + "process_local_examples", + staticmethod(fake_process_local_examples), + ) + monkeypatch.setattr( + download_module.MetadataUpdater, + "update_metadata_from_local_examples", + staticmethod(fake_update_metadata), + ) + monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader) + + try: + result = await manager.start_download({"model_types": ["lora"], "delay": 0}) + assert result["success"] is True + + await asyncio.wait_for(started.wait(), timeout=1) + + with pytest.raises(download_module.DownloadInProgressError) as exc: + await manager.start_download({"model_types": ["lora"], "delay": 0}) + + snapshot = exc.value.progress_snapshot + assert snapshot["status"] == "running" + assert snapshot["current_model"] == "Example (abc123)" + + statuses = [payload["status"] for payload in ws_manager.payloads] + assert "running" in statuses + + finally: + release.set() + if manager._download_task is not None: + await asyncio.wait_for(manager._download_task, timeout=1) + + +@pytest.mark.usefixtures("tmp_path") +async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, tmp_path): + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path)) + + models = [ + { + "sha256": "hash-one", + "model_name": "Model One", + "file_path": str(tmp_path / "model-one.safetensors"), + "file_name": "model-one.safetensors", + "civitai": {"images": [{"url": "https://example.com/one.png"}]}, + }, + { + "sha256": "hash-two", + "model_name": "Model Two", + "file_path": str(tmp_path / "model-two.safetensors"), + "file_name": "model-two.safetensors", + "civitai": {"images": [{"url": "https://example.com/two.png"}]}, + }, + ] + _patch_scanner(monkeypatch, StubScanner(models)) + + async def fake_process_local_examples(*_args, **_kwargs): + return False + + async def fake_update_metadata(*_args, **_kwargs): + return True + + first_call_started = asyncio.Event() + first_release = asyncio.Event() + second_call_started = asyncio.Event() + call_order: list[str] = [] + + async def fake_download_model_images(model_hash, *_args, **_kwargs): + call_order.append(model_hash) + if len(call_order) == 1: + first_call_started.set() + await first_release.wait() + else: + second_call_started.set() + return True, False + + async def fake_get_downloader(): + class _Downloader: + async def download_to_memory(self, *_a, **_kw): + return True, b"", {} + + return _Downloader() + + monkeypatch.setattr( + download_module.ExampleImagesProcessor, + "process_local_examples", + staticmethod(fake_process_local_examples), + ) + monkeypatch.setattr( + download_module.MetadataUpdater, + "update_metadata_from_local_examples", + staticmethod(fake_update_metadata), + ) + monkeypatch.setattr( + download_module.ExampleImagesProcessor, + "download_model_images", + staticmethod(fake_download_model_images), + ) + monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader) + + original_sleep = download_module.asyncio.sleep + pause_gate = asyncio.Event() + resume_gate = asyncio.Event() + + async def fake_sleep(delay: float): + if delay == 1: + pause_gate.set() + await resume_gate.wait() + else: + await original_sleep(delay) + + monkeypatch.setattr(download_module.asyncio, "sleep", fake_sleep) + + try: + await manager.start_download({"model_types": ["lora"], "delay": 0}) + + await asyncio.wait_for(first_call_started.wait(), timeout=1) + + await manager.pause_download({}) + + first_release.set() + + await asyncio.wait_for(pause_gate.wait(), timeout=1) + assert manager._progress["status"] == "paused" + assert not second_call_started.is_set() + + statuses = [payload["status"] for payload in ws_manager.payloads] + paused_index = statuses.index("paused") + + await asyncio.sleep(0) + assert not second_call_started.is_set() + + await manager.resume_download({}) + resume_gate.set() + + await asyncio.wait_for(second_call_started.wait(), timeout=1) + + if manager._download_task is not None: + await asyncio.wait_for(manager._download_task, timeout=1) + + statuses_after = [payload["status"] for payload in ws_manager.payloads] + running_after = next( + i for i, status in enumerate(statuses_after[paused_index + 1 :], start=paused_index + 1) if status == "running" + ) + assert running_after > paused_index + assert "completed" in statuses_after[running_after:] + assert call_order == ["hash-one", "hash-two"] + + finally: + first_release.set() + resume_gate.set() + if manager._download_task is not None: + await asyncio.wait_for(manager._download_task, timeout=1) + monkeypatch.setattr(download_module.asyncio, "sleep", original_sleep) From 3eacf9558a21c6be5feda5152d27cd5b0c437c82 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Tue, 23 Sep 2025 15:39:56 +0800 Subject: [PATCH 22/24] docs: remove outdated developer notes and add example image route architecture documentation --- README.md | 26 ------ docs/architecture/example_images_routes.md | 93 ++++++++++++++++++++++ 2 files changed, 93 insertions(+), 26 deletions(-) create mode 100644 docs/architecture/example_images_routes.md diff --git a/README.md b/README.md index f9c1806b..7e932acc 100644 --- a/README.md +++ b/README.md @@ -233,32 +233,6 @@ You can now run LoRA Manager independently from ComfyUI: This standalone mode provides a lightweight option for managing your model and recipe collection without needing to run the full ComfyUI environment, making it useful even for users who primarily use other stable diffusion interfaces. -## Developer notes - -The REST layer is split into modular registrars, controllers, and handler sets -to simplify maintenance: - -* `py/routes/recipe_route_registrar.py` holds the declarative endpoint list. -* `py/routes/base_recipe_routes.py` wires shared services/templates and returns - the handler mapping consumed by `RecipeRouteRegistrar`. -* `py/routes/handlers/recipe_handlers.py` groups HTTP adapters by concern (page - rendering, listings, queries, mutations, sharing) and delegates business rules - to services in `py/services/recipes/`. - -To add a new recipe endpoint: - -1. Declare the route in `ROUTE_DEFINITIONS` with a unique handler name. -2. Implement the coroutine on the appropriate handler class or introduce a new - handler when the concern does not fit existing ones. -3. Inject additional collaborators in - `BaseRecipeRoutes._create_handler_set` (for example a new service or factory) - so the handler can access its dependencies. - -The end-to-end wiring is documented in -[`docs/architecture/recipe_routes.md`](docs/architecture/recipe_routes.md), and -the integration suite in `tests/routes/test_recipe_routes.py` smoke-tests the -primary endpoints. - --- ## Contributing diff --git a/docs/architecture/example_images_routes.md b/docs/architecture/example_images_routes.md new file mode 100644 index 00000000..128530f6 --- /dev/null +++ b/docs/architecture/example_images_routes.md @@ -0,0 +1,93 @@ +# Example image route architecture + +The example image routing stack mirrors the layered model route stack described in +[`docs/architecture/model_routes.md`](model_routes.md). HTTP wiring, controller setup, +handler orchestration, and long-running workflows now live in clearly separated modules so +we can extend download/import behaviour without touching the entire feature surface. + +```mermaid +graph TD + subgraph HTTP + A[ExampleImagesRouteRegistrar] -->|binds| B[ExampleImagesRoutes controller] + end + subgraph Application + B --> C[ExampleImagesHandlerSet] + C --> D1[Handlers] + D1 --> E1[Use cases] + E1 --> F1[Download manager / processor / file manager] + end + subgraph Side Effects + F1 --> G1[Filesystem] + F1 --> G2[Model metadata] + F1 --> G3[WebSocket progress] + end +``` + +## Layer responsibilities + +| Layer | Module(s) | Responsibility | +| --- | --- | --- | +| Registrar | `py/routes/example_images_route_registrar.py` | Declarative catalogue of every example image endpoint plus helpers that bind them to an `aiohttp` router. Keeps HTTP concerns symmetrical with the model registrar. | +| Controller | `py/routes/example_images_routes.py` | Lazily constructs `ExampleImagesHandlerSet`, injects defaults for the download manager, processor, and file manager, and exposes the registrar-ready mapping just like `BaseModelRoutes`. | +| Handler set | `py/routes/handlers/example_images_handlers.py` | Groups HTTP adapters by concern (downloads, imports/deletes, filesystem access). Each handler translates domain errors into HTTP responses and defers to a use case or utility service. | +| Use cases | `py/services/use_cases/example_images/*.py` | Encapsulate orchestration for downloads and imports. They validate input, translate concurrency/configuration errors, and keep handler logic declarative. | +| Supporting services | `py/utils/example_images_download_manager.py`, `py/utils/example_images_processor.py`, `py/utils/example_images_file_manager.py` | Execute long-running work: pull assets from Civitai, persist uploads, clean metadata, expose filesystem actions with guardrails, and broadcast progress snapshots. | + +## Handler responsibilities & invariants + +`ExampleImagesHandlerSet` flattens the handler objects into the `{"handler_name": coroutine}` +mapping consumed by the registrar. The table below outlines how each handler collaborates +with the use cases and utilities. + +| Handler | Key endpoints | Collaborators | Contracts | +| --- | --- | --- | --- | +| `ExampleImagesDownloadHandler` | `/api/lm/download-example-images`, `/api/lm/example-images-status`, `/api/lm/pause-example-images`, `/api/lm/resume-example-images`, `/api/lm/force-download-example-images` | `DownloadExampleImagesUseCase`, `DownloadManager` | Delegates payload validation and concurrency checks to the use case; progress/status endpoints expose the same snapshot used for WebSocket broadcasts; pause/resume surface `DownloadNotRunningError` as HTTP 400 instead of 500. | +| `ExampleImagesManagementHandler` | `/api/lm/import-example-images`, `/api/lm/delete-example-image` | `ImportExampleImagesUseCase`, `ExampleImagesProcessor` | Multipart uploads are streamed to disk via the use case; validation failures return HTTP 400 with no filesystem side effects; deletion funnels through the processor to prune metadata and cached images consistently. | +| `ExampleImagesFileHandler` | `/api/lm/open-example-images-folder`, `/api/lm/example-image-files`, `/api/lm/has-example-images` | `ExampleImagesFileManager` | Centralises filesystem access, enforcing settings-based root paths and returning HTTP 400/404 for missing configuration or folders; responses always include `success`/`has_images` booleans for UI consumption. | + +## Use case boundaries + +| Use case | Entry point | Dependencies | Guarantees | +| --- | --- | --- | --- | +| `DownloadExampleImagesUseCase` | `execute(payload)` | `DownloadManager.start_download`, download configuration errors | Raises `DownloadExampleImagesInProgressError` when the manager reports an active job, rewraps configuration errors into `DownloadExampleImagesConfigurationError`, and lets `ExampleImagesDownloadError` bubble as 500s so handlers do not duplicate logging. | +| `ImportExampleImagesUseCase` | `execute(request)` | `ExampleImagesProcessor.import_images`, temporary file helpers | Supports multipart or JSON payloads, normalises file paths into a single list, cleans up temp files even on failure, and maps validation issues to `ImportExampleImagesValidationError` for HTTP 400 responses. | + +## Maintaining critical invariants + +* **Shared progress snapshots** - The download handler returns the same snapshot built by + `DownloadManager`, guaranteeing parity between HTTP polling endpoints and WebSocket + progress events. +* **Safe filesystem access** - All folder/file actions flow through + `ExampleImagesFileManager`, which validates the configured example image root and ensures + responses never leak absolute paths outside the allowed directory. +* **Metadata hygiene** - Import/delete operations run through `ExampleImagesProcessor`, + which updates model metadata via `MetadataManager` and notifies the relevant scanners so + cache state stays in sync. + +## Migration notes + +The refactor brings the example image stack in line with the model/recipe stacks: + +1. `ExampleImagesRouteRegistrar` now owns the declarative route list. Downstream projects + should rely on `ExampleImagesRoutes.to_route_mapping()` instead of manually wiring + handler callables. +2. `ExampleImagesRoutes` caches its `ExampleImagesHandlerSet` just like + `BaseModelRoutes`. If you previously instantiated handlers directly, inject custom + collaborators via the controller constructor (`download_manager`, `processor`, + `file_manager`) to keep test seams predictable. +3. Tests that mocked `ExampleImagesRoutes.setup_routes` should switch to patching + `DownloadExampleImagesUseCase`/`ImportExampleImagesUseCase` at import time. The handlers + expect those abstractions to surface validation/concurrency errors, and bypassing them + will skip the HTTP-friendly error mapping. + +## Extending the stack + +1. Add the endpoint to `ROUTE_DEFINITIONS` with a unique `handler_name`. +2. Expose the coroutine on an existing handler class (or create a new handler and extend + `ExampleImagesHandlerSet`). +3. Wire additional services or factories inside `_build_handler_set` on + `ExampleImagesRoutes`, mirroring how the model stack introduces new use cases. + +`tests/routes/test_example_images_routes.py` exercises registrar binding, download pause +flows, and import validations. Use it as a template when introducing new handler +collaborators or error mappings. From 8c9bb358247635fe4dd94bbad96c7f10a7e3f8c4 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Tue, 23 Sep 2025 19:25:02 +0800 Subject: [PATCH 23/24] Update tests/services/test_base_model_service.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tests/services/test_base_model_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/services/test_base_model_service.py b/tests/services/test_base_model_service.py index c3fdc884..fc28a54e 100644 --- a/tests/services/test_base_model_service.py +++ b/tests/services/test_base_model_service.py @@ -1,4 +1,4 @@ -import pytest +import pytest import importlib import importlib.util From 6054d95e8518bb6b7a1f701f4992e660b162c055 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Tue, 23 Sep 2025 19:25:12 +0800 Subject: [PATCH 24/24] Update py/services/model_query.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- py/services/model_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/services/model_query.py b/py/services/model_query.py index 08ca652f..df7bb67a 100644 --- a/py/services/model_query.py +++ b/py/services/model_query.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from __future__ import annotations from dataclasses import dataclass from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable