Merge pull request #451 from willmiao/codex/refactor-routes_common.py-into-services

Refactor model route utilities into dedicated services
This commit is contained in:
pixelpaws
2025-09-21 23:38:38 +08:00
committed by GitHub
11 changed files with 1269 additions and 30 deletions

View File

@@ -8,12 +8,20 @@ import jinja2
from aiohttp import web
from ..config import config
from ..services.metadata_service import get_default_metadata_provider
from ..services.download_coordinator import DownloadCoordinator
from ..services.downloader import get_downloader
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
from ..services.metadata_sync_service import MetadataSyncService
from ..services.model_file_service import ModelFileService, ModelMoveService
from ..services.preview_asset_service import PreviewAssetService
from ..services.server_i18n import server_i18n as default_server_i18n
from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import settings as default_settings
from ..services.tag_update_service import TagUpdateService
from ..services.websocket_manager import ws_manager as default_ws_manager
from ..services.websocket_progress_callback import WebSocketProgressCallback
from ..services.server_i18n import server_i18n as default_server_i18n
from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager
from ..utils.routes_common import ModelRouteUtils
from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar
from .handlers.model_handlers import (
@@ -64,6 +72,24 @@ class BaseModelRoutes(ABC):
self._handler_set: ModelHandlerSet | None = None
self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None
self._preview_service = PreviewAssetService(
metadata_manager=MetadataManager,
downloader_factory=get_downloader,
exif_utils=ExifUtils,
)
self._metadata_sync_service = MetadataSyncService(
metadata_manager=MetadataManager,
preview_service=self._preview_service,
settings=settings_service,
default_metadata_provider_factory=metadata_provider_factory,
metadata_provider_selector=get_metadata_provider,
)
self._tag_update_service = TagUpdateService(metadata_manager=MetadataManager)
self._download_coordinator = DownloadCoordinator(
ws_manager=self._ws_manager,
download_manager_factory=ServiceRegistry.get_download_manager,
)
if service is not None:
self.attach_service(service)
@@ -98,9 +124,19 @@ class BaseModelRoutes(ABC):
parse_specific_params=self._parse_specific_params,
logger=logger,
)
management = ModelManagementHandler(service=service, logger=logger)
management = ModelManagementHandler(
service=service,
logger=logger,
metadata_sync=self._metadata_sync_service,
preview_service=self._preview_service,
tag_update_service=self._tag_update_service,
)
query = ModelQueryHandler(service=service, logger=logger)
download = ModelDownloadHandler(ws_manager=self._ws_manager, logger=logger)
download = ModelDownloadHandler(
ws_manager=self._ws_manager,
logger=logger,
download_coordinator=self._download_coordinator,
)
civitai = ModelCivitaiHandler(
service=service,
settings_service=self._settings,
@@ -110,6 +146,7 @@ class BaseModelRoutes(ABC):
validate_model_type=self._validate_civitai_model_type,
expected_model_types=self._get_expected_model_types,
find_model_file=self._find_model_file,
metadata_sync=self._metadata_sync_service,
)
move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger)
auto_organize = ModelAutoOrganizeHandler(

View File

@@ -4,16 +4,23 @@ from __future__ import annotations
import asyncio
import json
import logging
import os
from dataclasses import dataclass
from typing import Awaitable, Callable, Dict, Iterable, Mapping, Optional
from aiohttp import web
import jinja2
from ...config import config
from ...services.download_coordinator import DownloadCoordinator
from ...services.metadata_sync_service import MetadataSyncService
from ...services.model_file_service import ModelFileService, ModelMoveService
from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...services.websocket_manager import WebSocketManager
from ...services.preview_asset_service import PreviewAssetService
from ...services.settings_manager import SettingsManager
from ...services.tag_update_service import TagUpdateService
from ...services.websocket_manager import WebSocketManager
from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...utils.file_utils import calculate_sha256
from ...utils.routes_common import ModelRouteUtils
@@ -168,9 +175,20 @@ class ModelListingHandler:
class ModelManagementHandler:
"""Handle mutation operations on models."""
def __init__(self, *, service, logger: logging.Logger) -> None:
def __init__(
self,
*,
service,
logger: logging.Logger,
metadata_sync: MetadataSyncService,
preview_service: PreviewAssetService,
tag_update_service: TagUpdateService,
) -> None:
self._service = service
self._logger = logger
self._metadata_sync = metadata_sync
self._preview_service = preview_service
self._tag_update_service = tag_update_service
async def delete_model(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_delete_model(request, self._service.scanner)
@@ -192,7 +210,7 @@ class ModelManagementHandler:
if not model_data.get("sha256"):
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
success, error = await ModelRouteUtils.fetch_and_update_model(
success, error = await self._metadata_sync.fetch_and_update_model(
sha256=model_data["sha256"],
file_path=file_path,
model_data=model_data,
@@ -208,16 +226,144 @@ class ModelManagementHandler:
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def relink_civitai(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_relink_civitai(request, self._service.scanner)
try:
data = await request.json()
file_path = data.get("file_path")
model_id = data.get("model_id")
model_version_id = data.get("model_version_id")
if not file_path or model_id is None:
return web.json_response(
{"success": False, "error": "Both file_path and model_id are required"},
status=400,
)
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
local_metadata = await self._metadata_sync.load_local_metadata(metadata_path)
updated_metadata = await self._metadata_sync.relink_metadata(
file_path=file_path,
metadata=local_metadata,
model_id=int(model_id),
model_version_id=int(model_version_id) if model_version_id else None,
)
await self._service.scanner.update_single_model_cache(
file_path, file_path, updated_metadata
)
message = (
f"Model successfully re-linked to Civitai model {model_id}"
+ (f" version {model_version_id}" if model_version_id else "")
)
return web.json_response(
{"success": True, "message": message, "hash": updated_metadata.get("sha256", "")}
)
except Exception as exc:
self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def replace_preview(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_replace_preview(request, self._service.scanner)
try:
reader = await request.multipart()
field = await reader.next()
if field is None or field.name != "preview_file":
raise ValueError("Expected 'preview_file' field")
content_type = field.headers.get("Content-Type", "image/png")
content_disposition = field.headers.get("Content-Disposition", "")
original_filename = None
import re
match = re.search(r'filename="(.*?)"', content_disposition)
if match:
original_filename = match.group(1)
preview_data = await field.read()
field = await reader.next()
if field is None or field.name != "model_path":
raise ValueError("Expected 'model_path' field")
model_path = (await field.read()).decode()
nsfw_level = 0
field = await reader.next()
if field and field.name == "nsfw_level":
try:
nsfw_level = int((await field.read()).decode())
except (ValueError, TypeError):
self._logger.warning("Invalid NSFW level format, using default 0")
result = await self._preview_service.replace_preview(
model_path=model_path,
preview_data=preview_data,
content_type=content_type,
original_filename=original_filename,
nsfw_level=nsfw_level,
update_preview_in_cache=self._service.scanner.update_preview_in_cache,
metadata_loader=self._metadata_sync.load_local_metadata,
)
return web.json_response(
{
"success": True,
"preview_url": config.get_preview_static_url(result["preview_path"]),
"preview_nsfw_level": result["preview_nsfw_level"],
}
)
except Exception as exc:
self._logger.error("Error replacing preview: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
async def save_metadata(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_save_metadata(request, self._service.scanner)
try:
data = await request.json()
file_path = data.get("file_path")
if not file_path:
return web.Response(text="File path is required", status=400)
metadata_updates = {k: v for k, v in data.items() if k != "file_path"}
await self._metadata_sync.save_metadata_updates(
file_path=file_path,
updates=metadata_updates,
metadata_loader=self._metadata_sync.load_local_metadata,
update_cache=self._service.scanner.update_single_model_cache,
)
if "model_name" in metadata_updates:
cache = await self._service.scanner.get_cached_data()
await cache.resort()
return web.json_response({"success": True})
except Exception as exc:
self._logger.error("Error saving metadata: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
async def add_tags(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_add_tags(request, self._service.scanner)
try:
data = await request.json()
file_path = data.get("file_path")
new_tags = data.get("tags", [])
if not file_path:
return web.Response(text="File path is required", status=400)
if not isinstance(new_tags, list):
return web.Response(text="Tags must be a list", status=400)
tags = await self._tag_update_service.add_tags(
file_path=file_path,
new_tags=new_tags,
metadata_loader=self._metadata_sync.load_local_metadata,
update_cache=self._service.scanner.update_single_model_cache,
)
return web.json_response({"success": True, "tags": tags})
except Exception as exc:
self._logger.error("Error adding tags: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
async def rename_model(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_rename_model(request, self._service.scanner)
@@ -226,7 +372,27 @@ class ModelManagementHandler:
return await ModelRouteUtils.handle_bulk_delete_models(request, self._service.scanner)
async def verify_duplicates(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_verify_duplicates(request, self._service.scanner)
try:
data = await request.json()
file_paths = data.get("file_paths", [])
if not file_paths:
return web.json_response(
{"success": False, "error": "No file paths provided for verification"},
status=400,
)
results = await self._metadata_sync.verify_duplicate_hashes(
file_paths=file_paths,
metadata_loader=self._metadata_sync.load_local_metadata,
hash_calculator=calculate_sha256,
update_cache=self._service.scanner.update_single_model_cache,
)
return web.json_response({"success": True, **results})
except Exception as exc:
self._logger.error("Error verifying duplicate models: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class ModelQueryHandler:
@@ -429,12 +595,39 @@ class ModelQueryHandler:
class ModelDownloadHandler:
"""Coordinate downloads and progress reporting."""
def __init__(self, *, ws_manager: WebSocketManager, logger: logging.Logger) -> None:
def __init__(
self,
*,
ws_manager: WebSocketManager,
logger: logging.Logger,
download_coordinator: DownloadCoordinator,
) -> None:
self._ws_manager = ws_manager
self._logger = logger
self._download_coordinator = download_coordinator
async def download_model(self, request: web.Request) -> web.Response:
return await ModelRouteUtils.handle_download_model(request)
try:
payload = await request.json()
result = await self._download_coordinator.schedule_download(payload)
if not result.get("success", False):
return web.json_response(result, status=500)
return web.json_response(result)
except ValueError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except Exception as exc:
error_message = str(exc)
if "401" in error_message:
self._logger.warning("Early access error (401): %s", error_message)
return web.json_response(
{
"success": False,
"error": "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com.",
},
status=401,
)
self._logger.error("Error downloading model: %s", error_message)
return web.json_response({"success": False, "error": error_message}, status=500)
async def download_model_get(self, request: web.Request) -> web.Response:
try:
@@ -460,7 +653,12 @@ class ModelDownloadHandler:
future.set_result(data)
mock_request = type("MockRequest", (), {"json": lambda self=None: future})()
return await ModelRouteUtils.handle_download_model(mock_request)
result = await self._download_coordinator.schedule_download(data)
if not result.get("success", False):
return web.json_response(result, status=500)
return web.json_response(result)
except ValueError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except Exception as exc:
self._logger.error("Error downloading model via GET: %s", exc, exc_info=True)
return web.Response(status=500, text=str(exc))
@@ -470,8 +668,8 @@ class ModelDownloadHandler:
download_id = request.query.get("download_id")
if not download_id:
return web.json_response({"success": False, "error": "Download ID is required"}, status=400)
mock_request = type("MockRequest", (), {"match_info": {"download_id": download_id}})()
return await ModelRouteUtils.handle_cancel_download(mock_request)
result = await self._download_coordinator.cancel_download(download_id)
return web.json_response(result)
except Exception as exc:
self._logger.error("Error cancelling download via GET: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
@@ -504,6 +702,7 @@ class ModelCivitaiHandler:
validate_model_type: Callable[[str], bool],
expected_model_types: Callable[[], str],
find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]],
metadata_sync: MetadataSyncService,
) -> None:
self._service = service
self._settings = settings_service
@@ -513,6 +712,7 @@ class ModelCivitaiHandler:
self._validate_model_type = validate_model_type
self._expected_model_types = expected_model_types
self._find_model_file = find_model_file
self._metadata_sync = metadata_sync
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
try:
@@ -545,7 +745,7 @@ class ModelCivitaiHandler:
for model in to_process:
try:
original_name = model.get("model_name")
result, error = await ModelRouteUtils.fetch_and_update_model(
result, error = await self._metadata_sync.fetch_and_update_model(
sha256=model["sha256"],
file_path=model["file_path"],
model_data=model,

View File

@@ -0,0 +1,100 @@
"""Service wrapper for coordinating download lifecycle events."""
from __future__ import annotations
import logging
from typing import Any, Awaitable, Callable, Dict, Optional
logger = logging.getLogger(__name__)
class DownloadCoordinator:
"""Manage download scheduling, cancellation and introspection."""
def __init__(
self,
*,
ws_manager,
download_manager_factory: Callable[[], Awaitable],
) -> None:
self._ws_manager = ws_manager
self._download_manager_factory = download_manager_factory
async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""Schedule a download using the provided payload."""
download_manager = await self._download_manager_factory()
download_id = payload.get("download_id") or self._ws_manager.generate_download_id()
payload.setdefault("download_id", download_id)
async def progress_callback(progress: Any) -> None:
await self._ws_manager.broadcast_download_progress(
download_id,
{
"status": "progress",
"progress": progress,
"download_id": download_id,
},
)
model_id = self._parse_optional_int(payload.get("model_id"), "model_id")
model_version_id = self._parse_optional_int(
payload.get("model_version_id"), "model_version_id"
)
if model_id is None and model_version_id is None:
raise ValueError(
"Missing required parameter: Please provide either 'model_id' or 'model_version_id'"
)
result = await download_manager.download_from_civitai(
model_id=model_id,
model_version_id=model_version_id,
save_dir=payload.get("model_root"),
relative_path=payload.get("relative_path", ""),
use_default_paths=payload.get("use_default_paths", False),
progress_callback=progress_callback,
download_id=download_id,
source=payload.get("source"),
)
result["download_id"] = download_id
return result
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
"""Cancel an active download and emit a broadcast event."""
download_manager = await self._download_manager_factory()
result = await download_manager.cancel_download(download_id)
await self._ws_manager.broadcast_download_progress(
download_id,
{
"status": "cancelled",
"progress": 0,
"download_id": download_id,
"message": "Download cancelled by user",
},
)
return result
async def list_active_downloads(self) -> Dict[str, Any]:
"""Return the active download map from the underlying manager."""
download_manager = await self._download_manager_factory()
return await download_manager.get_active_downloads()
def _parse_optional_int(self, value: Any, field: str) -> Optional[int]:
"""Parse an optional integer from user input."""
if value is None or value == "":
return None
try:
return int(value)
except (TypeError, ValueError) as exc:
raise ValueError(f"Invalid {field}: Must be an integer") from exc

View File

@@ -0,0 +1,355 @@
"""Services for synchronising metadata with remote providers."""
from __future__ import annotations
import json
import logging
import os
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
from ..services.settings_manager import SettingsManager
from ..utils.model_utils import determine_base_model
logger = logging.getLogger(__name__)
class MetadataProviderProtocol:
"""Subset of metadata provider interface consumed by the sync service."""
async def get_model_by_hash(self, sha256: str) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
...
async def get_model_version(
self, model_id: int, model_version_id: Optional[int]
) -> Optional[Dict[str, Any]]:
...
class MetadataSyncService:
"""High level orchestration for metadata synchronisation flows."""
def __init__(
self,
*,
metadata_manager,
preview_service,
settings: SettingsManager,
default_metadata_provider_factory: Callable[[], Awaitable[MetadataProviderProtocol]],
metadata_provider_selector: Callable[[str], Awaitable[MetadataProviderProtocol]],
) -> None:
self._metadata_manager = metadata_manager
self._preview_service = preview_service
self._settings = settings
self._get_default_provider = default_metadata_provider_factory
self._get_provider = metadata_provider_selector
async def load_local_metadata(self, metadata_path: str) -> Dict[str, Any]:
"""Load metadata JSON from disk, returning an empty structure when missing."""
if not os.path.exists(metadata_path):
return {}
try:
with open(metadata_path, "r", encoding="utf-8") as handle:
return json.load(handle)
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Error loading metadata from %s: %s", metadata_path, exc)
return {}
async def mark_not_found_on_civitai(
self, metadata_path: str, local_metadata: Dict[str, Any]
) -> None:
"""Persist the not-found flag for a metadata payload."""
local_metadata["from_civitai"] = False
await self._metadata_manager.save_metadata(metadata_path, local_metadata)
@staticmethod
def is_civitai_api_metadata(meta: Dict[str, Any]) -> bool:
"""Determine if the metadata originated from the CivitAI public API."""
if not isinstance(meta, dict):
return False
files = meta.get("files")
images = meta.get("images")
source = meta.get("source")
return bool(files) and bool(images) and source != "archive_db"
async def update_model_metadata(
self,
metadata_path: str,
local_metadata: Dict[str, Any],
civitai_metadata: Dict[str, Any],
metadata_provider: Optional[MetadataProviderProtocol] = None,
) -> Dict[str, Any]:
"""Merge remote metadata into the local record and persist the result."""
existing_civitai = local_metadata.get("civitai") or {}
if (
civitai_metadata.get("source") == "archive_db"
and self.is_civitai_api_metadata(existing_civitai)
):
logger.info(
"Skip civitai update for %s (%s)",
local_metadata.get("model_name", ""),
existing_civitai.get("name", ""),
)
else:
merged_civitai = existing_civitai.copy()
merged_civitai.update(civitai_metadata)
if civitai_metadata.get("source") == "archive_db":
model_name = civitai_metadata.get("model", {}).get("name", "")
version_name = civitai_metadata.get("name", "")
logger.info(
"Recovered metadata from archive_db for deleted model: %s (%s)",
model_name,
version_name,
)
if "trainedWords" in existing_civitai:
existing_trained = existing_civitai.get("trainedWords", [])
new_trained = civitai_metadata.get("trainedWords", [])
merged_trained = list(set(existing_trained + new_trained))
merged_civitai["trainedWords"] = merged_trained
local_metadata["civitai"] = merged_civitai
if "model" in civitai_metadata and civitai_metadata["model"]:
model_data = civitai_metadata["model"]
if model_data.get("name"):
local_metadata["model_name"] = model_data["name"]
if not local_metadata.get("modelDescription") and model_data.get("description"):
local_metadata["modelDescription"] = model_data["description"]
if not local_metadata.get("tags") and model_data.get("tags"):
local_metadata["tags"] = model_data["tags"]
if model_data.get("creator") and not local_metadata.get("civitai", {}).get(
"creator"
):
local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"]
local_metadata["base_model"] = determine_base_model(
civitai_metadata.get("baseModel")
)
await self._preview_service.ensure_preview_for_metadata(
metadata_path, local_metadata, civitai_metadata.get("images", [])
)
await self._metadata_manager.save_metadata(metadata_path, local_metadata)
return local_metadata
async def fetch_and_update_model(
self,
*,
sha256: str,
file_path: str,
model_data: Dict[str, Any],
update_cache_func: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
) -> tuple[bool, Optional[str]]:
"""Fetch metadata for a model and update both disk and cache state."""
if not isinstance(model_data, dict):
error = f"Invalid model_data type: {type(model_data)}"
logger.error(error)
return False, error
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
enable_archive = self._settings.get("enable_metadata_archive_db", False)
try:
if model_data.get("civitai_deleted") is True:
if not enable_archive or model_data.get("db_checked") is True:
return (
False,
"CivitAI model is deleted and metadata archive DB is not enabled",
)
metadata_provider = await self._get_provider("sqlite")
else:
metadata_provider = await self._get_default_provider()
civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256)
if not civitai_metadata:
if error == "Model not found":
model_data["from_civitai"] = False
model_data["civitai_deleted"] = True
model_data["db_checked"] = enable_archive
model_data["last_checked_at"] = datetime.now().timestamp()
data_to_save = model_data.copy()
data_to_save.pop("folder", None)
await self._metadata_manager.save_metadata(file_path, data_to_save)
error_msg = (
f"Error fetching metadata: {error} (model_name={model_data.get('model_name', '')})"
)
logger.error(error_msg)
return False, error_msg
model_data["from_civitai"] = True
model_data["civitai_deleted"] = civitai_metadata.get("source") == "archive_db"
model_data["db_checked"] = enable_archive
model_data["last_checked_at"] = datetime.now().timestamp()
local_metadata = model_data.copy()
local_metadata.pop("folder", None)
await self.update_model_metadata(
metadata_path,
local_metadata,
civitai_metadata,
metadata_provider,
)
update_payload = {
"model_name": local_metadata.get("model_name"),
"preview_url": local_metadata.get("preview_url"),
"civitai": local_metadata.get("civitai"),
}
model_data.update(update_payload)
await update_cache_func(file_path, file_path, local_metadata)
return True, None
except KeyError as exc:
error_msg = f"Error fetching metadata - Missing key: {exc} in model_data={model_data}"
logger.error(error_msg)
return False, error_msg
except Exception as exc: # pragma: no cover - error path
error_msg = f"Error fetching metadata: {exc}"
logger.error(error_msg, exc_info=True)
return False, error_msg
async def fetch_metadata_by_sha(
self, sha256: str, metadata_provider: Optional[MetadataProviderProtocol] = None
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
"""Fetch metadata for a SHA256 hash from the configured provider."""
provider = metadata_provider or await self._get_default_provider()
return await provider.get_model_by_hash(sha256)
async def relink_metadata(
self,
*,
file_path: str,
metadata: Dict[str, Any],
model_id: int,
model_version_id: Optional[int],
) -> Dict[str, Any]:
"""Relink a local metadata record to a specific CivitAI model version."""
provider = await self._get_default_provider()
civitai_metadata = await provider.get_model_version(model_id, model_version_id)
if not civitai_metadata:
raise ValueError(
f"Model version not found on CivitAI for ID: {model_id}"
+ (f" with version: {model_version_id}" if model_version_id else "")
)
primary_model_file: Optional[Dict[str, Any]] = None
for file_info in civitai_metadata.get("files", []):
if file_info.get("primary", False) and file_info.get("type") == "Model":
primary_model_file = file_info
break
if primary_model_file and primary_model_file.get("hashes", {}).get("SHA256"):
metadata["sha256"] = primary_model_file["hashes"]["SHA256"].lower()
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
await self.update_model_metadata(
metadata_path,
metadata,
civitai_metadata,
provider,
)
return metadata
async def save_metadata_updates(
self,
*,
file_path: str,
updates: Dict[str, Any],
metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]],
update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
) -> Dict[str, Any]:
"""Apply metadata updates and persist to disk and cache."""
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
metadata = await metadata_loader(metadata_path)
for key, value in updates.items():
if isinstance(value, dict) and isinstance(metadata.get(key), dict):
metadata[key].update(value)
else:
metadata[key] = value
await self._metadata_manager.save_metadata(file_path, metadata)
await update_cache(file_path, file_path, metadata)
if "model_name" in updates:
logger.debug("Metadata update touched model_name; cache resort required")
return metadata
async def verify_duplicate_hashes(
self,
*,
file_paths: Iterable[str],
metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]],
hash_calculator: Callable[[str], Awaitable[str]],
update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
) -> Dict[str, Any]:
"""Verify a collection of files share the same SHA256 hash."""
file_paths = list(file_paths)
if not file_paths:
raise ValueError("No file paths provided for verification")
results = {
"verified_as_duplicates": True,
"mismatched_files": [],
"new_hash_map": {},
}
expected_hash: Optional[str] = None
first_metadata_path = os.path.splitext(file_paths[0])[0] + ".metadata.json"
first_metadata = await metadata_loader(first_metadata_path)
if first_metadata and "sha256" in first_metadata:
expected_hash = first_metadata["sha256"].lower()
for path in file_paths:
if not os.path.exists(path):
continue
try:
actual_hash = await hash_calculator(path)
metadata_path = os.path.splitext(path)[0] + ".metadata.json"
metadata = await metadata_loader(metadata_path)
stored_hash = metadata.get("sha256", "").lower()
if not expected_hash:
expected_hash = stored_hash
if actual_hash != expected_hash:
results["verified_as_duplicates"] = False
results["mismatched_files"].append(path)
results["new_hash_map"][path] = actual_hash
if actual_hash != stored_hash:
metadata["sha256"] = actual_hash
await self._metadata_manager.save_metadata(path, metadata)
await update_cache(path, path, metadata)
except Exception as exc: # pragma: no cover - defensive path
logger.error("Error verifying hash for %s: %s", path, exc)
results["mismatched_files"].append(path)
results["new_hash_map"][path] = "error_calculating_hash"
results["verified_as_duplicates"] = False
return results

View File

@@ -0,0 +1,168 @@
"""Service for processing preview assets for models."""
from __future__ import annotations
import logging
import os
from typing import Awaitable, Callable, Dict, Optional, Sequence
from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS
logger = logging.getLogger(__name__)
class PreviewAssetService:
"""Manage fetching and persisting preview assets."""
def __init__(
self,
*,
metadata_manager,
downloader_factory: Callable[[], Awaitable],
exif_utils,
) -> None:
self._metadata_manager = metadata_manager
self._downloader_factory = downloader_factory
self._exif_utils = exif_utils
async def ensure_preview_for_metadata(
self,
metadata_path: str,
local_metadata: Dict[str, object],
images: Sequence[Dict[str, object]] | None,
) -> None:
"""Ensure preview assets exist for the supplied metadata entry."""
if local_metadata.get("preview_url") and os.path.exists(
str(local_metadata["preview_url"])
):
return
if not images:
return
first_preview = images[0]
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
preview_dir = os.path.dirname(metadata_path)
is_video = first_preview.get("type") == "video"
if is_video:
extension = ".mp4"
preview_path = os.path.join(preview_dir, base_name + extension)
downloader = await self._downloader_factory()
success, result = await downloader.download_file(
first_preview["url"], preview_path, use_auth=False
)
if success:
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
else:
extension = ".webp"
preview_path = os.path.join(preview_dir, base_name + extension)
downloader = await self._downloader_factory()
success, content, _headers = await downloader.download_to_memory(
first_preview["url"], use_auth=False
)
if not success:
return
try:
optimized_data, _ = self._exif_utils.optimize_image(
image_data=content,
target_width=CARD_PREVIEW_WIDTH,
format="webp",
quality=85,
preserve_metadata=False,
)
with open(preview_path, "wb") as handle:
handle.write(optimized_data)
except Exception as exc: # pragma: no cover - defensive path
logger.error("Error optimizing preview image: %s", exc)
try:
with open(preview_path, "wb") as handle:
handle.write(content)
except Exception as save_exc:
logger.error("Error saving preview image: %s", save_exc)
return
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
async def replace_preview(
self,
*,
model_path: str,
preview_data: bytes,
content_type: str,
original_filename: Optional[str],
nsfw_level: int,
update_preview_in_cache: Callable[[str, str, int], Awaitable[bool]],
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
) -> Dict[str, object]:
"""Replace an existing preview asset for a model."""
base_name = os.path.splitext(os.path.basename(model_path))[0]
folder = os.path.dirname(model_path)
extension, optimized_data = await self._convert_preview(
preview_data, content_type, original_filename
)
for ext in PREVIEW_EXTENSIONS:
existing_preview = os.path.join(folder, base_name + ext)
if os.path.exists(existing_preview):
try:
os.remove(existing_preview)
except Exception as exc: # pragma: no cover - defensive path
logger.warning(
"Failed to delete existing preview %s: %s", existing_preview, exc
)
preview_path = os.path.join(folder, base_name + extension).replace(os.sep, "/")
with open(preview_path, "wb") as handle:
handle.write(optimized_data)
metadata_path = os.path.splitext(model_path)[0] + ".metadata.json"
metadata = await metadata_loader(metadata_path)
metadata["preview_url"] = preview_path
metadata["preview_nsfw_level"] = nsfw_level
await self._metadata_manager.save_metadata(model_path, metadata)
await update_preview_in_cache(model_path, preview_path, nsfw_level)
return {"preview_path": preview_path, "preview_nsfw_level": nsfw_level}
async def _convert_preview(
self, data: bytes, content_type: str, original_filename: Optional[str]
) -> tuple[str, bytes]:
"""Convert preview bytes to the persisted representation."""
if content_type.startswith("video/"):
extension = self._resolve_video_extension(content_type, original_filename)
return extension, data
original_ext = (original_filename or "").lower()
if original_ext.endswith(".gif") or content_type.lower() == "image/gif":
return ".gif", data
optimized_data, _ = self._exif_utils.optimize_image(
image_data=data,
target_width=CARD_PREVIEW_WIDTH,
format="webp",
quality=85,
preserve_metadata=False,
)
return ".webp", optimized_data
def _resolve_video_extension(self, content_type: str, original_filename: Optional[str]) -> str:
"""Infer the best extension for a video preview."""
if original_filename:
extension = os.path.splitext(original_filename)[1].lower()
if extension in {".mp4", ".webm", ".mov", ".avi"}:
return extension
if "webm" in content_type:
return ".webm"
return ".mp4"

View File

@@ -0,0 +1,47 @@
"""Service for updating tag collections on metadata records."""
from __future__ import annotations
import os
from typing import Awaitable, Callable, Dict, List, Sequence
class TagUpdateService:
"""Encapsulate tag manipulation for models."""
def __init__(self, *, metadata_manager) -> None:
self._metadata_manager = metadata_manager
async def add_tags(
self,
*,
file_path: str,
new_tags: Sequence[str],
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
update_cache: Callable[[str, str, Dict[str, object]], Awaitable[bool]],
) -> List[str]:
"""Add tags to a metadata entry while keeping case-insensitive uniqueness."""
base, _ = os.path.splitext(file_path)
metadata_path = f"{base}.metadata.json"
metadata = await metadata_loader(metadata_path)
existing_tags = list(metadata.get("tags", []))
existing_lower = [tag.lower() for tag in existing_tags]
tags_added: List[str] = []
for tag in new_tags:
if isinstance(tag, str) and tag.strip():
normalized = tag.strip()
if normalized.lower() not in existing_lower:
existing_tags.append(normalized)
existing_lower.append(normalized.lower())
tags_added.append(normalized)
metadata["tags"] = existing_tags
await self._metadata_manager.save_metadata(file_path, metadata)
await update_cache(file_path, file_path, metadata)
return existing_tags

View File

@@ -1,14 +1,34 @@
import logging
import os
import re
from ..utils.metadata_manager import MetadataManager
from ..utils.routes_common import ModelRouteUtils
from ..recipes.constants import GEN_PARAM_KEYS
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
from ..services.metadata_sync_service import MetadataSyncService
from ..services.preview_asset_service import PreviewAssetService
from ..services.settings_manager import settings
from ..services.downloader import get_downloader
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
from ..utils.exif_utils import ExifUtils
from ..recipes.constants import GEN_PARAM_KEYS
from ..utils.metadata_manager import MetadataManager
logger = logging.getLogger(__name__)
_preview_service = PreviewAssetService(
metadata_manager=MetadataManager,
downloader_factory=get_downloader,
exif_utils=ExifUtils,
)
_metadata_sync_service = MetadataSyncService(
metadata_manager=MetadataManager,
preview_service=_preview_service,
settings=settings,
default_metadata_provider_factory=get_default_metadata_provider,
metadata_provider_selector=get_metadata_provider,
)
class MetadataUpdater:
"""Handles updating model metadata related to example images"""
@@ -53,11 +73,11 @@ class MetadataUpdater:
async def update_cache_func(old_path, new_path, metadata):
return await scanner.update_single_model_cache(old_path, new_path, metadata)
success, error = await ModelRouteUtils.fetch_and_update_model(
model_hash,
file_path,
model_data,
update_cache_func
success, error = await _metadata_sync_service.fetch_and_update_model(
sha256=model_hash,
file_path=file_path,
model_data=model_data,
update_cache_func=update_cache_func,
)
if success:

View File

@@ -4,5 +4,8 @@ testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
# Register async marker for coroutine-style tests
markers =
asyncio: execute test within asyncio event loop
# Skip problematic directories to avoid import conflicts
norecursedirs = .git .tox dist build *.egg __pycache__ py

View File

@@ -1,6 +1,8 @@
import types
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence
import asyncio
import inspect
from unittest import mock
import sys
@@ -39,6 +41,13 @@ nodes_mock.NODE_CLASS_MAPPINGS = {}
sys.modules['nodes'] = nodes_mock
def pytest_pyfunc_call(pyfuncitem):
if inspect.iscoroutinefunction(pyfuncitem.function):
asyncio.run(pyfuncitem.obj(**pyfuncitem.funcargs))
return True
return None
@dataclass
class MockHashIndex:
"""Minimal hash index stub mirroring the scanner contract."""

View File

@@ -1,8 +1,35 @@
import pytest
from py.services.base_model_service import BaseModelService
from py.services.model_query import ModelCacheRepository, ModelFilterSet, SearchStrategy, SortParams
from py.utils.models import BaseModelMetadata
import importlib
import importlib.util
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
def import_from(module_name: str):
existing = sys.modules.get("py")
if existing is None or getattr(existing, "__file__", "") != str(ROOT / "py/__init__.py"):
sys.modules.pop("py", None)
spec = importlib.util.spec_from_file_location("py", ROOT / "py/__init__.py")
module = importlib.util.module_from_spec(spec)
assert spec and spec.loader
spec.loader.exec_module(module) # type: ignore[union-attr]
module.__path__ = [str(ROOT / "py")]
sys.modules["py"] = module
return importlib.import_module(module_name)
BaseModelService = import_from("py.services.base_model_service").BaseModelService
model_query_module = import_from("py.services.model_query")
ModelCacheRepository = model_query_module.ModelCacheRepository
ModelFilterSet = model_query_module.ModelFilterSet
SearchStrategy = model_query_module.SearchStrategy
SortParams = model_query_module.SortParams
BaseModelMetadata = import_from("py.utils.models").BaseModelMetadata
class StubSettings:

View File

@@ -0,0 +1,273 @@
import asyncio
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List
ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
import importlib
import importlib.util
import pytest
def import_from(module_name: str):
existing = sys.modules.get("py")
if existing is None or getattr(existing, "__file__", "") != str(ROOT / "py/__init__.py"):
sys.modules.pop("py", None)
spec = importlib.util.spec_from_file_location("py", ROOT / "py/__init__.py")
module = importlib.util.module_from_spec(spec)
assert spec and spec.loader
spec.loader.exec_module(module) # type: ignore[union-attr]
module.__path__ = [str(ROOT / "py")]
sys.modules["py"] = module
return importlib.import_module(module_name)
DownloadCoordinator = import_from("py.services.download_coordinator").DownloadCoordinator
MetadataSyncService = import_from("py.services.metadata_sync_service").MetadataSyncService
PreviewAssetService = import_from("py.services.preview_asset_service").PreviewAssetService
TagUpdateService = import_from("py.services.tag_update_service").TagUpdateService
class DummySettings:
def __init__(self, values: Dict[str, Any] | None = None) -> None:
self._values = values or {}
def get(self, key: str, default: Any = None) -> Any:
return self._values.get(key, default)
class RecordingMetadataManager:
def __init__(self) -> None:
self.saved: List[tuple[str, Dict[str, Any]]] = []
async def save_metadata(self, path: str, metadata: Dict[str, Any]) -> bool:
self.saved.append((path, json.loads(json.dumps(metadata))))
metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json"
Path(metadata_path).write_text(json.dumps(metadata))
return True
class RecordingPreviewService:
def __init__(self) -> None:
self.calls: List[tuple[str, List[Dict[str, Any]]]] = []
async def ensure_preview_for_metadata(
self, metadata_path: str, local_metadata: Dict[str, Any], images
) -> None:
self.calls.append((metadata_path, list(images or [])))
local_metadata["preview_url"] = "preview.webp"
local_metadata["preview_nsfw_level"] = 1
class DummyProvider:
def __init__(self, payload: Dict[str, Any]) -> None:
self.payload = payload
async def get_model_by_hash(self, sha256: str):
return self.payload, None
async def get_model_version(self, model_id: int, model_version_id: int | None):
return self.payload
class FakeExifUtils:
@staticmethod
def optimize_image(**kwargs):
return kwargs["image_data"], {}
def test_metadata_sync_merges_remote_fields(tmp_path: Path) -> None:
manager = RecordingMetadataManager()
preview = RecordingPreviewService()
provider = DummyProvider({
"baseModel": "SD15",
"model": {"name": "Merged", "description": "desc", "tags": ["tag"], "creator": {"username": "user"}},
"trainedWords": ["word"],
"images": [{"url": "http://example", "nsfwLevel": 2, "type": "image"}],
})
service = MetadataSyncService(
metadata_manager=manager,
preview_service=preview,
settings=DummySettings(),
default_metadata_provider_factory=lambda: asyncio.sleep(0, result=provider),
metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=provider),
)
metadata_path = str(tmp_path / "model.metadata.json")
local_metadata = {"civitai": {"trainedWords": ["existing"]}}
updated = asyncio.run(service.update_model_metadata(metadata_path, local_metadata, provider.payload))
assert updated["model_name"] == "Merged"
assert updated["modelDescription"] == "desc"
assert set(updated["civitai"]["trainedWords"]) == {"existing", "word"}
assert manager.saved
assert preview.calls
def test_metadata_sync_fetch_and_update_updates_cache(tmp_path: Path) -> None:
manager = RecordingMetadataManager()
preview = RecordingPreviewService()
provider = DummyProvider({
"baseModel": "SDXL",
"model": {"name": "Updated"},
"images": [],
})
update_cache_calls: List[Dict[str, Any]] = []
async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool:
update_cache_calls.append({"original": original, "metadata": metadata})
return True
service = MetadataSyncService(
metadata_manager=manager,
preview_service=preview,
settings=DummySettings(),
default_metadata_provider_factory=lambda: asyncio.sleep(0, result=provider),
metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=provider),
)
model_data = {"sha256": "abc", "file_path": str(tmp_path / "model.safetensors")}
success, error = asyncio.run(
service.fetch_and_update_model(
sha256="abc",
file_path=str(tmp_path / "model.safetensors"),
model_data=model_data,
update_cache_func=update_cache,
)
)
assert success is True
assert error is None
assert update_cache_calls
assert manager.saved
def test_preview_asset_service_replace_preview(tmp_path: Path) -> None:
metadata_path = tmp_path / "sample.metadata.json"
metadata_path.write_text(json.dumps({}))
async def metadata_loader(path: str) -> Dict[str, Any]:
return json.loads(Path(path).read_text())
manager = RecordingMetadataManager()
service = PreviewAssetService(
metadata_manager=manager,
downloader_factory=lambda: asyncio.sleep(0, result=None),
exif_utils=FakeExifUtils(),
)
preview_calls: List[Dict[str, Any]] = []
async def update_preview(model_path: str, preview_path: str, nsfw: int) -> bool:
preview_calls.append({"model_path": model_path, "preview_path": preview_path, "nsfw": nsfw})
return True
model_path = str(tmp_path / "sample.safetensors")
Path(model_path).write_bytes(b"model")
result = asyncio.run(
service.replace_preview(
model_path=model_path,
preview_data=b"image-bytes",
content_type="image/png",
original_filename="preview.png",
nsfw_level=2,
update_preview_in_cache=update_preview,
metadata_loader=metadata_loader,
)
)
assert result["preview_nsfw_level"] == 2
assert preview_calls
saved_metadata = json.loads(metadata_path.read_text())
assert saved_metadata["preview_nsfw_level"] == 2
def test_download_coordinator_emits_progress() -> None:
class WSStub:
def __init__(self) -> None:
self.progress_events: List[Dict[str, Any]] = []
self.counter = 0
def generate_download_id(self) -> str:
self.counter += 1
return f"dl-{self.counter}"
async def broadcast_download_progress(self, download_id: str, payload: Dict[str, Any]) -> None:
self.progress_events.append({"id": download_id, **payload})
class DownloadManagerStub:
def __init__(self) -> None:
self.calls: List[Dict[str, Any]] = []
async def download_from_civitai(self, **kwargs) -> Dict[str, Any]:
self.calls.append(kwargs)
await kwargs["progress_callback"](10)
return {"success": True}
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
return {"success": True, "download_id": download_id}
async def get_active_downloads(self) -> Dict[str, Any]:
return {"active": []}
ws_stub = WSStub()
manager_stub = DownloadManagerStub()
coordinator = DownloadCoordinator(
ws_manager=ws_stub,
download_manager_factory=lambda: asyncio.sleep(0, result=manager_stub),
)
result = asyncio.run(coordinator.schedule_download({"model_id": 1}))
assert result["success"] is True
assert manager_stub.calls
assert ws_stub.progress_events
cancel_result = asyncio.run(coordinator.cancel_download(result["download_id"]))
assert cancel_result["success"] is True
active = asyncio.run(coordinator.list_active_downloads())
assert active == {"active": []}
def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None:
metadata_path = tmp_path / "model.metadata.json"
metadata_path.write_text(json.dumps({"tags": ["Existing"]}))
async def loader(path: str) -> Dict[str, Any]:
return json.loads(Path(path).read_text())
manager = RecordingMetadataManager()
service = TagUpdateService(metadata_manager=manager)
cache_updates: List[Dict[str, Any]] = []
async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool:
cache_updates.append(metadata)
return True
tags = asyncio.run(
service.add_tags(
file_path=str(tmp_path / "model.safetensors"),
new_tags=["New", "existing"],
metadata_loader=loader,
update_cache=update_cache,
)
)
assert tags == ["Existing", "New"]
assert manager.saved
assert cache_updates