diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index 458a5e87..65103ece 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -8,12 +8,20 @@ import jinja2 from aiohttp import web from ..config import config -from ..services.metadata_service import get_default_metadata_provider +from ..services.download_coordinator import DownloadCoordinator +from ..services.downloader import get_downloader +from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider +from ..services.metadata_sync_service import MetadataSyncService from ..services.model_file_service import ModelFileService, ModelMoveService +from ..services.preview_asset_service import PreviewAssetService +from ..services.server_i18n import server_i18n as default_server_i18n +from ..services.service_registry import ServiceRegistry from ..services.settings_manager import settings as default_settings +from ..services.tag_update_service import TagUpdateService from ..services.websocket_manager import ws_manager as default_ws_manager from ..services.websocket_progress_callback import WebSocketProgressCallback -from ..services.server_i18n import server_i18n as default_server_i18n +from ..utils.exif_utils import ExifUtils +from ..utils.metadata_manager import MetadataManager from ..utils.routes_common import ModelRouteUtils from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar from .handlers.model_handlers import ( @@ -64,6 +72,24 @@ class BaseModelRoutes(ABC): self._handler_set: ModelHandlerSet | None = None self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None + self._preview_service = PreviewAssetService( + metadata_manager=MetadataManager, + downloader_factory=get_downloader, + exif_utils=ExifUtils, + ) + self._metadata_sync_service = MetadataSyncService( + metadata_manager=MetadataManager, + preview_service=self._preview_service, + settings=settings_service, + 
default_metadata_provider_factory=metadata_provider_factory, + metadata_provider_selector=get_metadata_provider, + ) + self._tag_update_service = TagUpdateService(metadata_manager=MetadataManager) + self._download_coordinator = DownloadCoordinator( + ws_manager=self._ws_manager, + download_manager_factory=ServiceRegistry.get_download_manager, + ) + if service is not None: self.attach_service(service) @@ -98,9 +124,19 @@ class BaseModelRoutes(ABC): parse_specific_params=self._parse_specific_params, logger=logger, ) - management = ModelManagementHandler(service=service, logger=logger) + management = ModelManagementHandler( + service=service, + logger=logger, + metadata_sync=self._metadata_sync_service, + preview_service=self._preview_service, + tag_update_service=self._tag_update_service, + ) query = ModelQueryHandler(service=service, logger=logger) - download = ModelDownloadHandler(ws_manager=self._ws_manager, logger=logger) + download = ModelDownloadHandler( + ws_manager=self._ws_manager, + logger=logger, + download_coordinator=self._download_coordinator, + ) civitai = ModelCivitaiHandler( service=service, settings_service=self._settings, @@ -110,6 +146,7 @@ class BaseModelRoutes(ABC): validate_model_type=self._validate_civitai_model_type, expected_model_types=self._get_expected_model_types, find_model_file=self._find_model_file, + metadata_sync=self._metadata_sync_service, ) move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger) auto_organize = ModelAutoOrganizeHandler( diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 66a7123a..5f9eaf3b 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -4,16 +4,23 @@ from __future__ import annotations import asyncio import json import logging +import os from dataclasses import dataclass from typing import Awaitable, Callable, Dict, Iterable, Mapping, Optional from aiohttp import web import jinja2 +from ...config 
import config +from ...services.download_coordinator import DownloadCoordinator +from ...services.metadata_sync_service import MetadataSyncService from ...services.model_file_service import ModelFileService, ModelMoveService -from ...services.websocket_progress_callback import WebSocketProgressCallback -from ...services.websocket_manager import WebSocketManager +from ...services.preview_asset_service import PreviewAssetService from ...services.settings_manager import SettingsManager +from ...services.tag_update_service import TagUpdateService +from ...services.websocket_manager import WebSocketManager +from ...services.websocket_progress_callback import WebSocketProgressCallback +from ...utils.file_utils import calculate_sha256 from ...utils.routes_common import ModelRouteUtils @@ -168,9 +175,20 @@ class ModelListingHandler: class ModelManagementHandler: """Handle mutation operations on models.""" - def __init__(self, *, service, logger: logging.Logger) -> None: + def __init__( + self, + *, + service, + logger: logging.Logger, + metadata_sync: MetadataSyncService, + preview_service: PreviewAssetService, + tag_update_service: TagUpdateService, + ) -> None: self._service = service self._logger = logger + self._metadata_sync = metadata_sync + self._preview_service = preview_service + self._tag_update_service = tag_update_service async def delete_model(self, request: web.Request) -> web.Response: return await ModelRouteUtils.handle_delete_model(request, self._service.scanner) @@ -192,7 +210,7 @@ class ModelManagementHandler: if not model_data.get("sha256"): return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400) - success, error = await ModelRouteUtils.fetch_and_update_model( + success, error = await self._metadata_sync.fetch_and_update_model( sha256=model_data["sha256"], file_path=file_path, model_data=model_data, @@ -208,16 +226,144 @@ class ModelManagementHandler: return web.json_response({"success": False, "error": str(exc)}, 
status=500) async def relink_civitai(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_relink_civitai(request, self._service.scanner) + try: + data = await request.json() + file_path = data.get("file_path") + model_id = data.get("model_id") + model_version_id = data.get("model_version_id") + + if not file_path or model_id is None: + return web.json_response( + {"success": False, "error": "Both file_path and model_id are required"}, + status=400, + ) + + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" + local_metadata = await self._metadata_sync.load_local_metadata(metadata_path) + + updated_metadata = await self._metadata_sync.relink_metadata( + file_path=file_path, + metadata=local_metadata, + model_id=int(model_id), + model_version_id=int(model_version_id) if model_version_id else None, + ) + + await self._service.scanner.update_single_model_cache( + file_path, file_path, updated_metadata + ) + + message = ( + f"Model successfully re-linked to Civitai model {model_id}" + + (f" version {model_version_id}" if model_version_id else "") + ) + return web.json_response( + {"success": True, "message": message, "hash": updated_metadata.get("sha256", "")} + ) + except Exception as exc: + self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) async def replace_preview(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_replace_preview(request, self._service.scanner) + try: + reader = await request.multipart() + + field = await reader.next() + if field is None or field.name != "preview_file": + raise ValueError("Expected 'preview_file' field") + content_type = field.headers.get("Content-Type", "image/png") + content_disposition = field.headers.get("Content-Disposition", "") + + original_filename = None + import re + + match = re.search(r'filename="(.*?)"', content_disposition) + if match: + 
original_filename = match.group(1) + + preview_data = await field.read() + + field = await reader.next() + if field is None or field.name != "model_path": + raise ValueError("Expected 'model_path' field") + model_path = (await field.read()).decode() + + nsfw_level = 0 + field = await reader.next() + if field and field.name == "nsfw_level": + try: + nsfw_level = int((await field.read()).decode()) + except (ValueError, TypeError): + self._logger.warning("Invalid NSFW level format, using default 0") + + result = await self._preview_service.replace_preview( + model_path=model_path, + preview_data=preview_data, + content_type=content_type, + original_filename=original_filename, + nsfw_level=nsfw_level, + update_preview_in_cache=self._service.scanner.update_preview_in_cache, + metadata_loader=self._metadata_sync.load_local_metadata, + ) + + return web.json_response( + { + "success": True, + "preview_url": config.get_preview_static_url(result["preview_path"]), + "preview_nsfw_level": result["preview_nsfw_level"], + } + ) + except Exception as exc: + self._logger.error("Error replacing preview: %s", exc, exc_info=True) + return web.Response(text=str(exc), status=500) async def save_metadata(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_save_metadata(request, self._service.scanner) + try: + data = await request.json() + file_path = data.get("file_path") + if not file_path: + return web.Response(text="File path is required", status=400) + + metadata_updates = {k: v for k, v in data.items() if k != "file_path"} + + await self._metadata_sync.save_metadata_updates( + file_path=file_path, + updates=metadata_updates, + metadata_loader=self._metadata_sync.load_local_metadata, + update_cache=self._service.scanner.update_single_model_cache, + ) + + if "model_name" in metadata_updates: + cache = await self._service.scanner.get_cached_data() + await cache.resort() + + return web.json_response({"success": True}) + except Exception as exc: + 
self._logger.error("Error saving metadata: %s", exc, exc_info=True) + return web.Response(text=str(exc), status=500) async def add_tags(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_add_tags(request, self._service.scanner) + try: + data = await request.json() + file_path = data.get("file_path") + new_tags = data.get("tags", []) + + if not file_path: + return web.Response(text="File path is required", status=400) + + if not isinstance(new_tags, list): + return web.Response(text="Tags must be a list", status=400) + + tags = await self._tag_update_service.add_tags( + file_path=file_path, + new_tags=new_tags, + metadata_loader=self._metadata_sync.load_local_metadata, + update_cache=self._service.scanner.update_single_model_cache, + ) + + return web.json_response({"success": True, "tags": tags}) + except Exception as exc: + self._logger.error("Error adding tags: %s", exc, exc_info=True) + return web.Response(text=str(exc), status=500) async def rename_model(self, request: web.Request) -> web.Response: return await ModelRouteUtils.handle_rename_model(request, self._service.scanner) @@ -226,7 +372,27 @@ class ModelManagementHandler: return await ModelRouteUtils.handle_bulk_delete_models(request, self._service.scanner) async def verify_duplicates(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_verify_duplicates(request, self._service.scanner) + try: + data = await request.json() + file_paths = data.get("file_paths", []) + + if not file_paths: + return web.json_response( + {"success": False, "error": "No file paths provided for verification"}, + status=400, + ) + + results = await self._metadata_sync.verify_duplicate_hashes( + file_paths=file_paths, + metadata_loader=self._metadata_sync.load_local_metadata, + hash_calculator=calculate_sha256, + update_cache=self._service.scanner.update_single_model_cache, + ) + + return web.json_response({"success": True, **results}) + except Exception as exc: + 
self._logger.error("Error verifying duplicate models: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) class ModelQueryHandler: @@ -429,12 +595,39 @@ class ModelQueryHandler: class ModelDownloadHandler: """Coordinate downloads and progress reporting.""" - def __init__(self, *, ws_manager: WebSocketManager, logger: logging.Logger) -> None: + def __init__( + self, + *, + ws_manager: WebSocketManager, + logger: logging.Logger, + download_coordinator: DownloadCoordinator, + ) -> None: self._ws_manager = ws_manager self._logger = logger + self._download_coordinator = download_coordinator async def download_model(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_download_model(request) + try: + payload = await request.json() + result = await self._download_coordinator.schedule_download(payload) + if not result.get("success", False): + return web.json_response(result, status=500) + return web.json_response(result) + except ValueError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except Exception as exc: + error_message = str(exc) + if "401" in error_message: + self._logger.warning("Early access error (401): %s", error_message) + return web.json_response( + { + "success": False, + "error": "Early Access Restriction: This model requires purchase. 
Please buy early access on Civitai.com.", + }, + status=401, + ) + self._logger.error("Error downloading model: %s", error_message) + return web.json_response({"success": False, "error": error_message}, status=500) async def download_model_get(self, request: web.Request) -> web.Response: try: @@ -460,7 +653,12 @@ class ModelDownloadHandler: future.set_result(data) mock_request = type("MockRequest", (), {"json": lambda self=None: future})() - return await ModelRouteUtils.handle_download_model(mock_request) + result = await self._download_coordinator.schedule_download(data) + if not result.get("success", False): + return web.json_response(result, status=500) + return web.json_response(result) + except ValueError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) except Exception as exc: self._logger.error("Error downloading model via GET: %s", exc, exc_info=True) return web.Response(status=500, text=str(exc)) @@ -470,8 +668,8 @@ class ModelDownloadHandler: download_id = request.query.get("download_id") if not download_id: return web.json_response({"success": False, "error": "Download ID is required"}, status=400) - mock_request = type("MockRequest", (), {"match_info": {"download_id": download_id}})() - return await ModelRouteUtils.handle_cancel_download(mock_request) + result = await self._download_coordinator.cancel_download(download_id) + return web.json_response(result) except Exception as exc: self._logger.error("Error cancelling download via GET: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) @@ -504,6 +702,7 @@ class ModelCivitaiHandler: validate_model_type: Callable[[str], bool], expected_model_types: Callable[[], str], find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]], + metadata_sync: MetadataSyncService, ) -> None: self._service = service self._settings = settings_service @@ -513,6 +712,7 @@ class ModelCivitaiHandler: 
self._validate_model_type = validate_model_type self._expected_model_types = expected_model_types self._find_model_file = find_model_file + self._metadata_sync = metadata_sync async def fetch_all_civitai(self, request: web.Request) -> web.Response: try: @@ -545,7 +745,7 @@ class ModelCivitaiHandler: for model in to_process: try: original_name = model.get("model_name") - result, error = await ModelRouteUtils.fetch_and_update_model( + result, error = await self._metadata_sync.fetch_and_update_model( sha256=model["sha256"], file_path=model["file_path"], model_data=model, diff --git a/py/services/download_coordinator.py b/py/services/download_coordinator.py new file mode 100644 index 00000000..4cf866e5 --- /dev/null +++ b/py/services/download_coordinator.py @@ -0,0 +1,100 @@ +"""Service wrapper for coordinating download lifecycle events.""" + +from __future__ import annotations + +import logging +from typing import Any, Awaitable, Callable, Dict, Optional + + +logger = logging.getLogger(__name__) + + +class DownloadCoordinator: + """Manage download scheduling, cancellation and introspection.""" + + def __init__( + self, + *, + ws_manager, + download_manager_factory: Callable[[], Awaitable], + ) -> None: + self._ws_manager = ws_manager + self._download_manager_factory = download_manager_factory + + async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """Schedule a download using the provided payload.""" + + download_manager = await self._download_manager_factory() + + download_id = payload.get("download_id") or self._ws_manager.generate_download_id() + payload.setdefault("download_id", download_id) + + async def progress_callback(progress: Any) -> None: + await self._ws_manager.broadcast_download_progress( + download_id, + { + "status": "progress", + "progress": progress, + "download_id": download_id, + }, + ) + + model_id = self._parse_optional_int(payload.get("model_id"), "model_id") + model_version_id = self._parse_optional_int( + 
payload.get("model_version_id"), "model_version_id" + ) + + if model_id is None and model_version_id is None: + raise ValueError( + "Missing required parameter: Please provide either 'model_id' or 'model_version_id'" + ) + + result = await download_manager.download_from_civitai( + model_id=model_id, + model_version_id=model_version_id, + save_dir=payload.get("model_root"), + relative_path=payload.get("relative_path", ""), + use_default_paths=payload.get("use_default_paths", False), + progress_callback=progress_callback, + download_id=download_id, + source=payload.get("source"), + ) + + result["download_id"] = download_id + return result + + async def cancel_download(self, download_id: str) -> Dict[str, Any]: + """Cancel an active download and emit a broadcast event.""" + + download_manager = await self._download_manager_factory() + result = await download_manager.cancel_download(download_id) + + await self._ws_manager.broadcast_download_progress( + download_id, + { + "status": "cancelled", + "progress": 0, + "download_id": download_id, + "message": "Download cancelled by user", + }, + ) + + return result + + async def list_active_downloads(self) -> Dict[str, Any]: + """Return the active download map from the underlying manager.""" + + download_manager = await self._download_manager_factory() + return await download_manager.get_active_downloads() + + def _parse_optional_int(self, value: Any, field: str) -> Optional[int]: + """Parse an optional integer from user input.""" + + if value is None or value == "": + return None + + try: + return int(value) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid {field}: Must be an integer") from exc + diff --git a/py/services/metadata_sync_service.py b/py/services/metadata_sync_service.py new file mode 100644 index 00000000..aaf2f248 --- /dev/null +++ b/py/services/metadata_sync_service.py @@ -0,0 +1,355 @@ +"""Services for synchronising metadata with remote providers.""" + +from __future__ import annotations 
+ +import json +import logging +import os +from datetime import datetime +from typing import Any, Awaitable, Callable, Dict, Iterable, Optional + +from ..services.settings_manager import SettingsManager +from ..utils.model_utils import determine_base_model + +logger = logging.getLogger(__name__) + + +class MetadataProviderProtocol: + """Subset of metadata provider interface consumed by the sync service.""" + + async def get_model_by_hash(self, sha256: str) -> tuple[Optional[Dict[str, Any]], Optional[str]]: + ... + + async def get_model_version( + self, model_id: int, model_version_id: Optional[int] + ) -> Optional[Dict[str, Any]]: + ... + + +class MetadataSyncService: + """High level orchestration for metadata synchronisation flows.""" + + def __init__( + self, + *, + metadata_manager, + preview_service, + settings: SettingsManager, + default_metadata_provider_factory: Callable[[], Awaitable[MetadataProviderProtocol]], + metadata_provider_selector: Callable[[str], Awaitable[MetadataProviderProtocol]], + ) -> None: + self._metadata_manager = metadata_manager + self._preview_service = preview_service + self._settings = settings + self._get_default_provider = default_metadata_provider_factory + self._get_provider = metadata_provider_selector + + async def load_local_metadata(self, metadata_path: str) -> Dict[str, Any]: + """Load metadata JSON from disk, returning an empty structure when missing.""" + + if not os.path.exists(metadata_path): + return {} + + try: + with open(metadata_path, "r", encoding="utf-8") as handle: + return json.load(handle) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Error loading metadata from %s: %s", metadata_path, exc) + return {} + + async def mark_not_found_on_civitai( + self, metadata_path: str, local_metadata: Dict[str, Any] + ) -> None: + """Persist the not-found flag for a metadata payload.""" + + local_metadata["from_civitai"] = False + await self._metadata_manager.save_metadata(metadata_path, 
local_metadata) + + @staticmethod + def is_civitai_api_metadata(meta: Dict[str, Any]) -> bool: + """Determine if the metadata originated from the CivitAI public API.""" + + if not isinstance(meta, dict): + return False + files = meta.get("files") + images = meta.get("images") + source = meta.get("source") + return bool(files) and bool(images) and source != "archive_db" + + async def update_model_metadata( + self, + metadata_path: str, + local_metadata: Dict[str, Any], + civitai_metadata: Dict[str, Any], + metadata_provider: Optional[MetadataProviderProtocol] = None, + ) -> Dict[str, Any]: + """Merge remote metadata into the local record and persist the result.""" + + existing_civitai = local_metadata.get("civitai") or {} + + if ( + civitai_metadata.get("source") == "archive_db" + and self.is_civitai_api_metadata(existing_civitai) + ): + logger.info( + "Skip civitai update for %s (%s)", + local_metadata.get("model_name", ""), + existing_civitai.get("name", ""), + ) + else: + merged_civitai = existing_civitai.copy() + merged_civitai.update(civitai_metadata) + + if civitai_metadata.get("source") == "archive_db": + model_name = civitai_metadata.get("model", {}).get("name", "") + version_name = civitai_metadata.get("name", "") + logger.info( + "Recovered metadata from archive_db for deleted model: %s (%s)", + model_name, + version_name, + ) + + if "trainedWords" in existing_civitai: + existing_trained = existing_civitai.get("trainedWords", []) + new_trained = civitai_metadata.get("trainedWords", []) + merged_trained = list(set(existing_trained + new_trained)) + merged_civitai["trainedWords"] = merged_trained + + local_metadata["civitai"] = merged_civitai + + if "model" in civitai_metadata and civitai_metadata["model"]: + model_data = civitai_metadata["model"] + + if model_data.get("name"): + local_metadata["model_name"] = model_data["name"] + + if not local_metadata.get("modelDescription") and model_data.get("description"): + local_metadata["modelDescription"] = 
model_data["description"] + + if not local_metadata.get("tags") and model_data.get("tags"): + local_metadata["tags"] = model_data["tags"] + + if model_data.get("creator") and not local_metadata.get("civitai", {}).get( + "creator" + ): + local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"] + + local_metadata["base_model"] = determine_base_model( + civitai_metadata.get("baseModel") + ) + + await self._preview_service.ensure_preview_for_metadata( + metadata_path, local_metadata, civitai_metadata.get("images", []) + ) + + await self._metadata_manager.save_metadata(metadata_path, local_metadata) + return local_metadata + + async def fetch_and_update_model( + self, + *, + sha256: str, + file_path: str, + model_data: Dict[str, Any], + update_cache_func: Callable[[str, str, Dict[str, Any]], Awaitable[bool]], + ) -> tuple[bool, Optional[str]]: + """Fetch metadata for a model and update both disk and cache state.""" + + if not isinstance(model_data, dict): + error = f"Invalid model_data type: {type(model_data)}" + logger.error(error) + return False, error + + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" + enable_archive = self._settings.get("enable_metadata_archive_db", False) + + try: + if model_data.get("civitai_deleted") is True: + if not enable_archive or model_data.get("db_checked") is True: + return ( + False, + "CivitAI model is deleted and metadata archive DB is not enabled", + ) + metadata_provider = await self._get_provider("sqlite") + else: + metadata_provider = await self._get_default_provider() + + civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256) + if not civitai_metadata: + if error == "Model not found": + model_data["from_civitai"] = False + model_data["civitai_deleted"] = True + model_data["db_checked"] = enable_archive + model_data["last_checked_at"] = datetime.now().timestamp() + + data_to_save = model_data.copy() + data_to_save.pop("folder", None) + await 
self._metadata_manager.save_metadata(file_path, data_to_save) + + error_msg = ( + f"Error fetching metadata: {error} (model_name={model_data.get('model_name', '')})" + ) + logger.error(error_msg) + return False, error_msg + + model_data["from_civitai"] = True + model_data["civitai_deleted"] = civitai_metadata.get("source") == "archive_db" + model_data["db_checked"] = enable_archive + model_data["last_checked_at"] = datetime.now().timestamp() + + local_metadata = model_data.copy() + local_metadata.pop("folder", None) + + await self.update_model_metadata( + metadata_path, + local_metadata, + civitai_metadata, + metadata_provider, + ) + + update_payload = { + "model_name": local_metadata.get("model_name"), + "preview_url": local_metadata.get("preview_url"), + "civitai": local_metadata.get("civitai"), + } + model_data.update(update_payload) + + await update_cache_func(file_path, file_path, local_metadata) + return True, None + except KeyError as exc: + error_msg = f"Error fetching metadata - Missing key: {exc} in model_data={model_data}" + logger.error(error_msg) + return False, error_msg + except Exception as exc: # pragma: no cover - error path + error_msg = f"Error fetching metadata: {exc}" + logger.error(error_msg, exc_info=True) + return False, error_msg + + async def fetch_metadata_by_sha( + self, sha256: str, metadata_provider: Optional[MetadataProviderProtocol] = None + ) -> tuple[Optional[Dict[str, Any]], Optional[str]]: + """Fetch metadata for a SHA256 hash from the configured provider.""" + + provider = metadata_provider or await self._get_default_provider() + return await provider.get_model_by_hash(sha256) + + async def relink_metadata( + self, + *, + file_path: str, + metadata: Dict[str, Any], + model_id: int, + model_version_id: Optional[int], + ) -> Dict[str, Any]: + """Relink a local metadata record to a specific CivitAI model version.""" + + provider = await self._get_default_provider() + civitai_metadata = await provider.get_model_version(model_id, 
model_version_id) + if not civitai_metadata: + raise ValueError( + f"Model version not found on CivitAI for ID: {model_id}" + + (f" with version: {model_version_id}" if model_version_id else "") + ) + + primary_model_file: Optional[Dict[str, Any]] = None + for file_info in civitai_metadata.get("files", []): + if file_info.get("primary", False) and file_info.get("type") == "Model": + primary_model_file = file_info + break + + if primary_model_file and primary_model_file.get("hashes", {}).get("SHA256"): + metadata["sha256"] = primary_model_file["hashes"]["SHA256"].lower() + + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" + await self.update_model_metadata( + metadata_path, + metadata, + civitai_metadata, + provider, + ) + + return metadata + + async def save_metadata_updates( + self, + *, + file_path: str, + updates: Dict[str, Any], + metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]], + update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]], + ) -> Dict[str, Any]: + """Apply metadata updates and persist to disk and cache.""" + + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" + metadata = await metadata_loader(metadata_path) + + for key, value in updates.items(): + if isinstance(value, dict) and isinstance(metadata.get(key), dict): + metadata[key].update(value) + else: + metadata[key] = value + + await self._metadata_manager.save_metadata(file_path, metadata) + await update_cache(file_path, file_path, metadata) + + if "model_name" in updates: + logger.debug("Metadata update touched model_name; cache resort required") + + return metadata + + async def verify_duplicate_hashes( + self, + *, + file_paths: Iterable[str], + metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]], + hash_calculator: Callable[[str], Awaitable[str]], + update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]], + ) -> Dict[str, Any]: + """Verify a collection of files share the same SHA256 hash.""" + + file_paths = 
logger = logging.getLogger(__name__)


class PreviewAssetService:
    """Manage fetching and persisting preview assets for model files."""

    def __init__(
        self,
        *,
        metadata_manager,
        downloader_factory: Callable[[], Awaitable],
        exif_utils,
    ) -> None:
        # Injected collaborators: metadata persistence, async downloader
        # factory, and the image-optimization utility.
        self._metadata_manager = metadata_manager
        self._downloader_factory = downloader_factory
        self._exif_utils = exif_utils

    async def ensure_preview_for_metadata(
        self,
        metadata_path: str,
        local_metadata: Dict[str, object],
        images: Sequence[Dict[str, object]] | None,
    ) -> None:
        """Ensure preview assets exist for the supplied metadata entry.

        Downloads the first remote image/video when no preview file is
        present on disk; mutates ``local_metadata`` in place on success.
        """
        current = local_metadata.get("preview_url")
        if current and os.path.exists(str(current)):
            return
        if not images:
            return

        candidate = images[0]
        # Strip both the file extension and the trailing ".metadata" part.
        stem = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
        target_dir = os.path.dirname(metadata_path)

        if candidate.get("type") == "video":
            destination = os.path.join(target_dir, stem + ".mp4")
            downloader = await self._downloader_factory()
            ok, _result = await downloader.download_file(
                candidate["url"], destination, use_auth=False
            )
            if not ok:
                return
        else:
            destination = os.path.join(target_dir, stem + ".webp")
            downloader = await self._downloader_factory()
            ok, payload, _headers = await downloader.download_to_memory(
                candidate["url"], use_auth=False
            )
            if not ok:
                return
            try:
                optimized, _ = self._exif_utils.optimize_image(
                    image_data=payload,
                    target_width=CARD_PREVIEW_WIDTH,
                    format="webp",
                    quality=85,
                    preserve_metadata=False,
                )
                with open(destination, "wb") as handle:
                    handle.write(optimized)
            except Exception as exc:  # pragma: no cover - defensive path
                logger.error("Error optimizing preview image: %s", exc)
                # Fall back to persisting the raw bytes unmodified.
                try:
                    with open(destination, "wb") as handle:
                        handle.write(payload)
                except Exception as save_exc:
                    logger.error("Error saving preview image: %s", save_exc)
                    return

        local_metadata["preview_url"] = destination.replace(os.sep, "/")
        local_metadata["preview_nsfw_level"] = candidate.get("nsfwLevel", 0)

    async def replace_preview(
        self,
        *,
        model_path: str,
        preview_data: bytes,
        content_type: str,
        original_filename: Optional[str],
        nsfw_level: int,
        update_preview_in_cache: Callable[[str, str, int], Awaitable[bool]],
        metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
    ) -> Dict[str, object]:
        """Replace an existing preview asset for a model.

        Converts the uploaded bytes, deletes stale preview variants, writes
        the new asset, persists the sidecar metadata, and refreshes the cache.
        """
        stem = os.path.splitext(os.path.basename(model_path))[0]
        folder = os.path.dirname(model_path)

        extension, payload = await self._convert_preview(
            preview_data, content_type, original_filename
        )

        # Remove every stale preview variant before writing the new one.
        for candidate_ext in PREVIEW_EXTENSIONS:
            stale = os.path.join(folder, stem + candidate_ext)
            if not os.path.exists(stale):
                continue
            try:
                os.remove(stale)
            except Exception as exc:  # pragma: no cover - defensive path
                logger.warning(
                    "Failed to delete existing preview %s: %s", stale, exc
                )

        preview_path = os.path.join(folder, stem + extension).replace(os.sep, "/")
        with open(preview_path, "wb") as handle:
            handle.write(payload)

        metadata_path = os.path.splitext(model_path)[0] + ".metadata.json"
        metadata = await metadata_loader(metadata_path)
        metadata["preview_url"] = preview_path
        metadata["preview_nsfw_level"] = nsfw_level
        await self._metadata_manager.save_metadata(model_path, metadata)

        await update_preview_in_cache(model_path, preview_path, nsfw_level)

        return {"preview_path": preview_path, "preview_nsfw_level": nsfw_level}

    async def _convert_preview(
        self, data: bytes, content_type: str, original_filename: Optional[str]
    ) -> tuple[str, bytes]:
        """Convert preview bytes to the representation that gets persisted."""
        if content_type.startswith("video/"):
            return self._resolve_video_extension(content_type, original_filename), data

        lowered_name = (original_filename or "").lower()
        if lowered_name.endswith(".gif") or content_type.lower() == "image/gif":
            # GIFs are stored verbatim so animation is preserved.
            return ".gif", data

        optimized, _ = self._exif_utils.optimize_image(
            image_data=data,
            target_width=CARD_PREVIEW_WIDTH,
            format="webp",
            quality=85,
            preserve_metadata=False,
        )
        return ".webp", optimized

    def _resolve_video_extension(
        self, content_type: str, original_filename: Optional[str]
    ) -> str:
        """Infer the best extension for a video preview."""
        known = {".mp4", ".webm", ".mov", ".avi"}
        if original_filename:
            suffix = os.path.splitext(original_filename)[1].lower()
            if suffix in known:
                return suffix

        return ".webm" if "webm" in content_type else ".mp4"
class TagUpdateService:
    """Encapsulate tag manipulation for models."""

    def __init__(self, *, metadata_manager) -> None:
        # Persistence collaborator exposing async ``save_metadata``.
        self._metadata_manager = metadata_manager

    async def add_tags(
        self,
        *,
        file_path: str,
        new_tags: Sequence[str],
        metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
        update_cache: Callable[[str, str, Dict[str, object]], Awaitable[bool]],
    ) -> List[str]:
        """Add tags to a metadata entry while keeping case-insensitive uniqueness.

        Args:
            file_path: Path of the model file whose sidecar metadata is updated.
            new_tags: Candidate tags; non-strings and blank strings are ignored.
            metadata_loader: Async callable loading the metadata dict for a path.
            update_cache: Async callable invoked as ``(old_path, new_path, metadata)``.

        Returns:
            The full tag list after the update (existing tags plus additions),
            preserving original casing and insertion order.
        """
        base, _ = os.path.splitext(file_path)
        metadata_path = f"{base}.metadata.json"
        metadata = await metadata_loader(metadata_path)

        tags: List[str] = list(metadata.get("tags", []))
        # Set-based lookup keeps dedup O(1) per candidate instead of the
        # previous O(n) list scan; the unused ``tags_added`` accumulator
        # from the original implementation is dropped.
        seen = {tag.lower() for tag in tags}

        for candidate in new_tags:
            if not isinstance(candidate, str):
                continue
            normalized = candidate.strip()
            if not normalized or normalized.lower() in seen:
                continue
            tags.append(normalized)
            seen.add(normalized.lower())

        metadata["tags"] = tags
        await self._metadata_manager.save_metadata(file_path, metadata)
        await update_cache(file_path, file_path, metadata)

        return tags
def pytest_pyfunc_call(pyfuncitem):
    """Execute coroutine test functions on a fresh asyncio event loop.

    Returning ``True`` tells pytest the call was handled here; ``None``
    defers to the default synchronous call protocol.
    """
    test_fn = pyfuncitem.function
    if not inspect.iscoroutinefunction(test_fn):
        return None
    coroutine = pyfuncitem.obj(**pyfuncitem.funcargs)
    asyncio.run(coroutine)
    return True
def import_from(module_name: str):
    """Import *module_name* from the repository's ``py`` package.

    Ensures ``sys.modules["py"]`` points at the repo checkout (not some
    similarly named installed package) before delegating to importlib.
    """
    package_init = ROOT / "py/__init__.py"
    existing = sys.modules.get("py")
    if existing is None or getattr(existing, "__file__", "") != str(package_init):
        sys.modules.pop("py", None)
        spec = importlib.util.spec_from_file_location("py", package_init)
        # Validate the spec *before* using it; the original asserted only
        # after module_from_spec(spec), so a None spec raised a confusing
        # TypeError instead of a clear assertion failure.
        assert spec and spec.loader
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        module.__path__ = [str(ROOT / "py")]
        sys.modules["py"] = module
    return importlib.import_module(module_name)
class DummySettings:
    """Minimal stand-in for the settings service used by the services under test."""

    def __init__(self, values: Dict[str, Any] | None = None) -> None:
        # A falsy mapping (None or empty) collapses to a fresh dict.
        self._values = values if values else {}

    def get(self, key: str, default: Any = None) -> Any:
        """Mirror ``settings.get`` semantics with an optional default."""
        if key in self._values:
            return self._values[key]
        return default
class FakeExifUtils:
    """Stub optimizer: echoes the input bytes untouched plus empty metadata."""

    @staticmethod
    def optimize_image(**kwargs):
        # Return the payload verbatim so tests can assert pass-through
        # behaviour; the second element mimics the real "metadata" dict.
        payload = kwargs["image_data"]
        return payload, {}
def test_preview_asset_service_replace_preview(tmp_path: Path) -> None:
    """replace_preview writes the asset, persists metadata, and updates the cache."""
    model_file = tmp_path / "sample.safetensors"
    model_file.write_bytes(b"model")
    sidecar = tmp_path / "sample.metadata.json"
    sidecar.write_text(json.dumps({}))

    recorder = RecordingMetadataManager()
    service = PreviewAssetService(
        metadata_manager=recorder,
        downloader_factory=lambda: asyncio.sleep(0, result=None),
        exif_utils=FakeExifUtils(),
    )

    async def load_metadata(path: str) -> Dict[str, Any]:
        return json.loads(Path(path).read_text())

    cache_calls: List[Dict[str, Any]] = []

    async def record_cache_update(model_path: str, preview_path: str, nsfw: int) -> bool:
        cache_calls.append(
            {"model_path": model_path, "preview_path": preview_path, "nsfw": nsfw}
        )
        return True

    outcome = asyncio.run(
        service.replace_preview(
            model_path=str(model_file),
            preview_data=b"image-bytes",
            content_type="image/png",
            original_filename="preview.png",
            nsfw_level=2,
            update_preview_in_cache=record_cache_update,
            metadata_loader=load_metadata,
        )
    )

    assert outcome["preview_nsfw_level"] == 2
    assert cache_calls
    assert json.loads(sidecar.read_text())["preview_nsfw_level"] == 2
def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None:
    """add_tags keeps case-insensitive uniqueness and persists the result."""
    sidecar = tmp_path / "model.metadata.json"
    sidecar.write_text(json.dumps({"tags": ["Existing"]}))

    async def load_metadata(path: str) -> Dict[str, Any]:
        return json.loads(Path(path).read_text())

    recorder = RecordingMetadataManager()
    service = TagUpdateService(metadata_manager=recorder)

    cache_updates: List[Dict[str, Any]] = []

    async def record_cache_update(original: str, new: str, metadata: Dict[str, Any]) -> bool:
        cache_updates.append(metadata)
        return True

    result = asyncio.run(
        service.add_tags(
            file_path=str(tmp_path / "model.safetensors"),
            new_tags=["New", "existing"],
            metadata_loader=load_metadata,
            update_cache=record_cache_update,
        )
    )

    assert result == ["Existing", "New"]
    assert recorder.saved
    assert cache_updates