refactor(routes): extract route utilities into services

Mirror of https://github.com/willmiao/ComfyUI-Lora-Manager.git
@@ -8,12 +8,20 @@ import jinja2

from aiohttp import web

from ..config import config
from ..services.metadata_service import get_default_metadata_provider
from ..services.download_coordinator import DownloadCoordinator
from ..services.downloader import get_downloader
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
from ..services.metadata_sync_service import MetadataSyncService
from ..services.model_file_service import ModelFileService, ModelMoveService
from ..services.preview_asset_service import PreviewAssetService
from ..services.server_i18n import server_i18n as default_server_i18n
from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import settings as default_settings
from ..services.tag_update_service import TagUpdateService
from ..services.websocket_manager import ws_manager as default_ws_manager
from ..services.websocket_progress_callback import WebSocketProgressCallback
from ..services.server_i18n import server_i18n as default_server_i18n
from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager
from ..utils.routes_common import ModelRouteUtils
from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar
from .handlers.model_handlers import (
@@ -64,6 +72,24 @@ class BaseModelRoutes(ABC):
        self._handler_set: ModelHandlerSet | None = None
        self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None

        self._preview_service = PreviewAssetService(
            metadata_manager=MetadataManager,
            downloader_factory=get_downloader,
            exif_utils=ExifUtils,
        )
        self._metadata_sync_service = MetadataSyncService(
            metadata_manager=MetadataManager,
            preview_service=self._preview_service,
            settings=settings_service,
            default_metadata_provider_factory=metadata_provider_factory,
            metadata_provider_selector=get_metadata_provider,
        )
        self._tag_update_service = TagUpdateService(metadata_manager=MetadataManager)
        self._download_coordinator = DownloadCoordinator(
            ws_manager=self._ws_manager,
            download_manager_factory=ServiceRegistry.get_download_manager,
        )

        if service is not None:
            self.attach_service(service)
@@ -98,9 +124,19 @@ class BaseModelRoutes(ABC):
            parse_specific_params=self._parse_specific_params,
            logger=logger,
        )
        management = ModelManagementHandler(service=service, logger=logger)
        management = ModelManagementHandler(
            service=service,
            logger=logger,
            metadata_sync=self._metadata_sync_service,
            preview_service=self._preview_service,
            tag_update_service=self._tag_update_service,
        )
        query = ModelQueryHandler(service=service, logger=logger)
        download = ModelDownloadHandler(ws_manager=self._ws_manager, logger=logger)
        download = ModelDownloadHandler(
            ws_manager=self._ws_manager,
            logger=logger,
            download_coordinator=self._download_coordinator,
        )
        civitai = ModelCivitaiHandler(
            service=service,
            settings_service=self._settings,
@@ -110,6 +146,7 @@ class BaseModelRoutes(ABC):
            validate_model_type=self._validate_civitai_model_type,
            expected_model_types=self._get_expected_model_types,
            find_model_file=self._find_model_file,
            metadata_sync=self._metadata_sync_service,
        )
        move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger)
        auto_organize = ModelAutoOrganizeHandler(
@@ -4,16 +4,23 @@ from __future__ import annotations

import asyncio
import json
import logging
import os
from dataclasses import dataclass
from typing import Awaitable, Callable, Dict, Iterable, Mapping, Optional

from aiohttp import web
import jinja2

from ...config import config
from ...services.download_coordinator import DownloadCoordinator
from ...services.metadata_sync_service import MetadataSyncService
from ...services.model_file_service import ModelFileService, ModelMoveService
from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...services.websocket_manager import WebSocketManager
from ...services.preview_asset_service import PreviewAssetService
from ...services.settings_manager import SettingsManager
from ...services.tag_update_service import TagUpdateService
from ...services.websocket_manager import WebSocketManager
from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...utils.file_utils import calculate_sha256
from ...utils.routes_common import ModelRouteUtils
@@ -168,9 +175,20 @@ class ModelListingHandler:
class ModelManagementHandler:
    """Handle mutation operations on models."""

    def __init__(self, *, service, logger: logging.Logger) -> None:
    def __init__(
        self,
        *,
        service,
        logger: logging.Logger,
        metadata_sync: MetadataSyncService,
        preview_service: PreviewAssetService,
        tag_update_service: TagUpdateService,
    ) -> None:
        self._service = service
        self._logger = logger
        self._metadata_sync = metadata_sync
        self._preview_service = preview_service
        self._tag_update_service = tag_update_service

    async def delete_model(self, request: web.Request) -> web.Response:
        return await ModelRouteUtils.handle_delete_model(request, self._service.scanner)
@@ -192,7 +210,7 @@ class ModelManagementHandler:
        if not model_data.get("sha256"):
            return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)

        success, error = await ModelRouteUtils.fetch_and_update_model(
        success, error = await self._metadata_sync.fetch_and_update_model(
            sha256=model_data["sha256"],
            file_path=file_path,
            model_data=model_data,
@@ -208,16 +226,144 @@ class ModelManagementHandler:
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def relink_civitai(self, request: web.Request) -> web.Response:
        return await ModelRouteUtils.handle_relink_civitai(request, self._service.scanner)
        try:
            data = await request.json()
            file_path = data.get("file_path")
            model_id = data.get("model_id")
            model_version_id = data.get("model_version_id")

            if not file_path or model_id is None:
                return web.json_response(
                    {"success": False, "error": "Both file_path and model_id are required"},
                    status=400,
                )

            metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
            local_metadata = await self._metadata_sync.load_local_metadata(metadata_path)

            updated_metadata = await self._metadata_sync.relink_metadata(
                file_path=file_path,
                metadata=local_metadata,
                model_id=int(model_id),
                model_version_id=int(model_version_id) if model_version_id else None,
            )

            await self._service.scanner.update_single_model_cache(
                file_path, file_path, updated_metadata
            )

            message = (
                f"Model successfully re-linked to Civitai model {model_id}"
                + (f" version {model_version_id}" if model_version_id else "")
            )
            return web.json_response(
                {"success": True, "message": message, "hash": updated_metadata.get("sha256", "")}
            )
        except Exception as exc:
            self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def replace_preview(self, request: web.Request) -> web.Response:
        return await ModelRouteUtils.handle_replace_preview(request, self._service.scanner)
        try:
            reader = await request.multipart()

            field = await reader.next()
            if field is None or field.name != "preview_file":
                raise ValueError("Expected 'preview_file' field")
            content_type = field.headers.get("Content-Type", "image/png")
            content_disposition = field.headers.get("Content-Disposition", "")

            original_filename = None
            import re

            match = re.search(r'filename="(.*?)"', content_disposition)
            if match:
                original_filename = match.group(1)

            preview_data = await field.read()

            field = await reader.next()
            if field is None or field.name != "model_path":
                raise ValueError("Expected 'model_path' field")
            model_path = (await field.read()).decode()

            nsfw_level = 0
            field = await reader.next()
            if field and field.name == "nsfw_level":
                try:
                    nsfw_level = int((await field.read()).decode())
                except (ValueError, TypeError):
                    self._logger.warning("Invalid NSFW level format, using default 0")

            result = await self._preview_service.replace_preview(
                model_path=model_path,
                preview_data=preview_data,
                content_type=content_type,
                original_filename=original_filename,
                nsfw_level=nsfw_level,
                update_preview_in_cache=self._service.scanner.update_preview_in_cache,
                metadata_loader=self._metadata_sync.load_local_metadata,
            )

            return web.json_response(
                {
                    "success": True,
                    "preview_url": config.get_preview_static_url(result["preview_path"]),
                    "preview_nsfw_level": result["preview_nsfw_level"],
                }
            )
        except Exception as exc:
            self._logger.error("Error replacing preview: %s", exc, exc_info=True)
            return web.Response(text=str(exc), status=500)

    async def save_metadata(self, request: web.Request) -> web.Response:
        return await ModelRouteUtils.handle_save_metadata(request, self._service.scanner)
        try:
            data = await request.json()
            file_path = data.get("file_path")
            if not file_path:
                return web.Response(text="File path is required", status=400)

            metadata_updates = {k: v for k, v in data.items() if k != "file_path"}

            await self._metadata_sync.save_metadata_updates(
                file_path=file_path,
                updates=metadata_updates,
                metadata_loader=self._metadata_sync.load_local_metadata,
                update_cache=self._service.scanner.update_single_model_cache,
            )

            if "model_name" in metadata_updates:
                cache = await self._service.scanner.get_cached_data()
                await cache.resort()

            return web.json_response({"success": True})
        except Exception as exc:
            self._logger.error("Error saving metadata: %s", exc, exc_info=True)
            return web.Response(text=str(exc), status=500)

    async def add_tags(self, request: web.Request) -> web.Response:
        return await ModelRouteUtils.handle_add_tags(request, self._service.scanner)
        try:
            data = await request.json()
            file_path = data.get("file_path")
            new_tags = data.get("tags", [])

            if not file_path:
                return web.Response(text="File path is required", status=400)

            if not isinstance(new_tags, list):
                return web.Response(text="Tags must be a list", status=400)

            tags = await self._tag_update_service.add_tags(
                file_path=file_path,
                new_tags=new_tags,
                metadata_loader=self._metadata_sync.load_local_metadata,
                update_cache=self._service.scanner.update_single_model_cache,
            )

            return web.json_response({"success": True, "tags": tags})
        except Exception as exc:
            self._logger.error("Error adding tags: %s", exc, exc_info=True)
            return web.Response(text=str(exc), status=500)

    async def rename_model(self, request: web.Request) -> web.Response:
        return await ModelRouteUtils.handle_rename_model(request, self._service.scanner)
@@ -226,7 +372,27 @@ class ModelManagementHandler:
        return await ModelRouteUtils.handle_bulk_delete_models(request, self._service.scanner)

    async def verify_duplicates(self, request: web.Request) -> web.Response:
        return await ModelRouteUtils.handle_verify_duplicates(request, self._service.scanner)
        try:
            data = await request.json()
            file_paths = data.get("file_paths", [])

            if not file_paths:
                return web.json_response(
                    {"success": False, "error": "No file paths provided for verification"},
                    status=400,
                )

            results = await self._metadata_sync.verify_duplicate_hashes(
                file_paths=file_paths,
                metadata_loader=self._metadata_sync.load_local_metadata,
                hash_calculator=calculate_sha256,
                update_cache=self._service.scanner.update_single_model_cache,
            )

            return web.json_response({"success": True, **results})
        except Exception as exc:
            self._logger.error("Error verifying duplicate models: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

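A hedged client-side sketch of what replace_preview expects on the wire. The handler reads the multipart fields in a fixed order: preview_file first, then model_path, then an optional nsfw_level. The endpoint URL below is hypothetical (the route table lives in COMMON_ROUTE_DEFINITIONS, not shown in this diff):

import aiohttp
import asyncio

async def send_replacement_preview(base_url: str, model_path: str, image_bytes: bytes) -> dict:
    # Field order matters: the server pulls fields with reader.next() in sequence.
    form = aiohttp.FormData()
    form.add_field("preview_file", image_bytes, filename="preview.png", content_type="image/png")
    form.add_field("model_path", model_path)
    form.add_field("nsfw_level", "2")  # optional; non-integer values fall back to 0
    async with aiohttp.ClientSession() as session:
        # "/replace-preview" is an assumed path for illustration only.
        async with session.post(f"{base_url}/replace-preview", data=form) as resp:
            return await resp.json()
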
class ModelQueryHandler:
@@ -429,12 +595,39 @@ class ModelQueryHandler:
class ModelDownloadHandler:
    """Coordinate downloads and progress reporting."""

    def __init__(self, *, ws_manager: WebSocketManager, logger: logging.Logger) -> None:
    def __init__(
        self,
        *,
        ws_manager: WebSocketManager,
        logger: logging.Logger,
        download_coordinator: DownloadCoordinator,
    ) -> None:
        self._ws_manager = ws_manager
        self._logger = logger
        self._download_coordinator = download_coordinator

    async def download_model(self, request: web.Request) -> web.Response:
        return await ModelRouteUtils.handle_download_model(request)
        try:
            payload = await request.json()
            result = await self._download_coordinator.schedule_download(payload)
            if not result.get("success", False):
                return web.json_response(result, status=500)
            return web.json_response(result)
        except ValueError as exc:
            return web.json_response({"success": False, "error": str(exc)}, status=400)
        except Exception as exc:
            error_message = str(exc)
            if "401" in error_message:
                self._logger.warning("Early access error (401): %s", error_message)
                return web.json_response(
                    {
                        "success": False,
                        "error": "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com.",
                    },
                    status=401,
                )
            self._logger.error("Error downloading model: %s", error_message)
            return web.json_response({"success": False, "error": error_message}, status=500)

    async def download_model_get(self, request: web.Request) -> web.Response:
        try:
@@ -460,7 +653,12 @@ class ModelDownloadHandler:
            future.set_result(data)

            mock_request = type("MockRequest", (), {"json": lambda self=None: future})()
            return await ModelRouteUtils.handle_download_model(mock_request)
            result = await self._download_coordinator.schedule_download(data)
            if not result.get("success", False):
                return web.json_response(result, status=500)
            return web.json_response(result)
        except ValueError as exc:
            return web.json_response({"success": False, "error": str(exc)}, status=400)
        except Exception as exc:
            self._logger.error("Error downloading model via GET: %s", exc, exc_info=True)
            return web.Response(status=500, text=str(exc))
@@ -470,8 +668,8 @@ class ModelDownloadHandler:
            download_id = request.query.get("download_id")
            if not download_id:
                return web.json_response({"success": False, "error": "Download ID is required"}, status=400)
            mock_request = type("MockRequest", (), {"match_info": {"download_id": download_id}})()
            return await ModelRouteUtils.handle_cancel_download(mock_request)
            result = await self._download_coordinator.cancel_download(download_id)
            return web.json_response(result)
        except Exception as exc:
            self._logger.error("Error cancelling download via GET: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)
@@ -504,6 +702,7 @@ class ModelCivitaiHandler:
        validate_model_type: Callable[[str], bool],
        expected_model_types: Callable[[], str],
        find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]],
        metadata_sync: MetadataSyncService,
    ) -> None:
        self._service = service
        self._settings = settings_service
@@ -513,6 +712,7 @@ class ModelCivitaiHandler:
        self._validate_model_type = validate_model_type
        self._expected_model_types = expected_model_types
        self._find_model_file = find_model_file
        self._metadata_sync = metadata_sync

    async def fetch_all_civitai(self, request: web.Request) -> web.Response:
        try:
@@ -545,7 +745,7 @@ class ModelCivitaiHandler:
            for model in to_process:
                try:
                    original_name = model.get("model_name")
                    result, error = await ModelRouteUtils.fetch_and_update_model(
                    result, error = await self._metadata_sync.fetch_and_update_model(
                        sha256=model["sha256"],
                        file_path=model["file_path"],
                        model_data=model,
py/services/download_coordinator.py (new file, 100 lines)
@@ -0,0 +1,100 @@
"""Service wrapper for coordinating download lifecycle events."""

from __future__ import annotations

import logging
from typing import Any, Awaitable, Callable, Dict, Optional


logger = logging.getLogger(__name__)


class DownloadCoordinator:
    """Manage download scheduling, cancellation and introspection."""

    def __init__(
        self,
        *,
        ws_manager,
        download_manager_factory: Callable[[], Awaitable],
    ) -> None:
        self._ws_manager = ws_manager
        self._download_manager_factory = download_manager_factory

    async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Schedule a download using the provided payload."""

        download_manager = await self._download_manager_factory()

        download_id = payload.get("download_id") or self._ws_manager.generate_download_id()
        payload.setdefault("download_id", download_id)

        async def progress_callback(progress: Any) -> None:
            await self._ws_manager.broadcast_download_progress(
                download_id,
                {
                    "status": "progress",
                    "progress": progress,
                    "download_id": download_id,
                },
            )

        model_id = self._parse_optional_int(payload.get("model_id"), "model_id")
        model_version_id = self._parse_optional_int(
            payload.get("model_version_id"), "model_version_id"
        )

        if model_id is None and model_version_id is None:
            raise ValueError(
                "Missing required parameter: Please provide either 'model_id' or 'model_version_id'"
            )

        result = await download_manager.download_from_civitai(
            model_id=model_id,
            model_version_id=model_version_id,
            save_dir=payload.get("model_root"),
            relative_path=payload.get("relative_path", ""),
            use_default_paths=payload.get("use_default_paths", False),
            progress_callback=progress_callback,
            download_id=download_id,
            source=payload.get("source"),
        )

        result["download_id"] = download_id
        return result

    async def cancel_download(self, download_id: str) -> Dict[str, Any]:
        """Cancel an active download and emit a broadcast event."""

        download_manager = await self._download_manager_factory()
        result = await download_manager.cancel_download(download_id)

        await self._ws_manager.broadcast_download_progress(
            download_id,
            {
                "status": "cancelled",
                "progress": 0,
                "download_id": download_id,
                "message": "Download cancelled by user",
            },
        )

        return result

    async def list_active_downloads(self) -> Dict[str, Any]:
        """Return the active download map from the underlying manager."""

        download_manager = await self._download_manager_factory()
        return await download_manager.get_active_downloads()

    def _parse_optional_int(self, value: Any, field: str) -> Optional[int]:
        """Parse an optional integer from user input."""

        if value is None or value == "":
            return None

        try:
            return int(value)
        except (TypeError, ValueError) as exc:
            raise ValueError(f"Invalid {field}: Must be an integer") from exc
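A minimal sketch of driving the coordinator with stubs, mirroring the unit tests added later in this commit. The WSStub and ManagerStub classes below are illustrative stand-ins, not part of the repository:

import asyncio

from py.services.download_coordinator import DownloadCoordinator

class WSStub:
    def generate_download_id(self) -> str:
        return "dl-1"

    async def broadcast_download_progress(self, download_id, payload) -> None:
        # The real websocket manager fans this out to connected clients.
        print(download_id, payload)

class ManagerStub:
    async def download_from_civitai(self, **kwargs):
        await kwargs["progress_callback"](100)  # report completion
        return {"success": True}

coordinator = DownloadCoordinator(
    ws_manager=WSStub(),
    download_manager_factory=lambda: asyncio.sleep(0, result=ManagerStub()),
)
# Requires at least one of model_id / model_version_id, else ValueError.
print(asyncio.run(coordinator.schedule_download({"model_id": 1})))
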
py/services/metadata_sync_service.py (new file, 355 lines)
@@ -0,0 +1,355 @@
"""Services for synchronising metadata with remote providers."""

from __future__ import annotations

import json
import logging
import os
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, Iterable, Optional

from ..services.settings_manager import SettingsManager
from ..utils.model_utils import determine_base_model

logger = logging.getLogger(__name__)


class MetadataProviderProtocol:
    """Subset of metadata provider interface consumed by the sync service."""

    async def get_model_by_hash(self, sha256: str) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
        ...

    async def get_model_version(
        self, model_id: int, model_version_id: Optional[int]
    ) -> Optional[Dict[str, Any]]:
        ...


class MetadataSyncService:
    """High level orchestration for metadata synchronisation flows."""

    def __init__(
        self,
        *,
        metadata_manager,
        preview_service,
        settings: SettingsManager,
        default_metadata_provider_factory: Callable[[], Awaitable[MetadataProviderProtocol]],
        metadata_provider_selector: Callable[[str], Awaitable[MetadataProviderProtocol]],
    ) -> None:
        self._metadata_manager = metadata_manager
        self._preview_service = preview_service
        self._settings = settings
        self._get_default_provider = default_metadata_provider_factory
        self._get_provider = metadata_provider_selector

    async def load_local_metadata(self, metadata_path: str) -> Dict[str, Any]:
        """Load metadata JSON from disk, returning an empty structure when missing."""

        if not os.path.exists(metadata_path):
            return {}

        try:
            with open(metadata_path, "r", encoding="utf-8") as handle:
                return json.load(handle)
        except Exception as exc:  # pragma: no cover - defensive logging
            logger.error("Error loading metadata from %s: %s", metadata_path, exc)
            return {}

    async def mark_not_found_on_civitai(
        self, metadata_path: str, local_metadata: Dict[str, Any]
    ) -> None:
        """Persist the not-found flag for a metadata payload."""

        local_metadata["from_civitai"] = False
        await self._metadata_manager.save_metadata(metadata_path, local_metadata)

    @staticmethod
    def is_civitai_api_metadata(meta: Dict[str, Any]) -> bool:
        """Determine if the metadata originated from the CivitAI public API."""

        if not isinstance(meta, dict):
            return False
        files = meta.get("files")
        images = meta.get("images")
        source = meta.get("source")
        return bool(files) and bool(images) and source != "archive_db"

    async def update_model_metadata(
        self,
        metadata_path: str,
        local_metadata: Dict[str, Any],
        civitai_metadata: Dict[str, Any],
        metadata_provider: Optional[MetadataProviderProtocol] = None,
    ) -> Dict[str, Any]:
        """Merge remote metadata into the local record and persist the result."""

        existing_civitai = local_metadata.get("civitai") or {}

        if (
            civitai_metadata.get("source") == "archive_db"
            and self.is_civitai_api_metadata(existing_civitai)
        ):
            logger.info(
                "Skip civitai update for %s (%s)",
                local_metadata.get("model_name", ""),
                existing_civitai.get("name", ""),
            )
        else:
            merged_civitai = existing_civitai.copy()
            merged_civitai.update(civitai_metadata)

            if civitai_metadata.get("source") == "archive_db":
                model_name = civitai_metadata.get("model", {}).get("name", "")
                version_name = civitai_metadata.get("name", "")
                logger.info(
                    "Recovered metadata from archive_db for deleted model: %s (%s)",
                    model_name,
                    version_name,
                )

            if "trainedWords" in existing_civitai:
                existing_trained = existing_civitai.get("trainedWords", [])
                new_trained = civitai_metadata.get("trainedWords", [])
                merged_trained = list(set(existing_trained + new_trained))
                merged_civitai["trainedWords"] = merged_trained

            local_metadata["civitai"] = merged_civitai

        if "model" in civitai_metadata and civitai_metadata["model"]:
            model_data = civitai_metadata["model"]

            if model_data.get("name"):
                local_metadata["model_name"] = model_data["name"]

            if not local_metadata.get("modelDescription") and model_data.get("description"):
                local_metadata["modelDescription"] = model_data["description"]

            if not local_metadata.get("tags") and model_data.get("tags"):
                local_metadata["tags"] = model_data["tags"]

            if model_data.get("creator") and not local_metadata.get("civitai", {}).get(
                "creator"
            ):
                local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"]

        local_metadata["base_model"] = determine_base_model(
            civitai_metadata.get("baseModel")
        )

        await self._preview_service.ensure_preview_for_metadata(
            metadata_path, local_metadata, civitai_metadata.get("images", [])
        )

        await self._metadata_manager.save_metadata(metadata_path, local_metadata)
        return local_metadata

    async def fetch_and_update_model(
        self,
        *,
        sha256: str,
        file_path: str,
        model_data: Dict[str, Any],
        update_cache_func: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
    ) -> tuple[bool, Optional[str]]:
        """Fetch metadata for a model and update both disk and cache state."""

        if not isinstance(model_data, dict):
            error = f"Invalid model_data type: {type(model_data)}"
            logger.error(error)
            return False, error

        metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
        enable_archive = self._settings.get("enable_metadata_archive_db", False)

        try:
            if model_data.get("civitai_deleted") is True:
                if not enable_archive or model_data.get("db_checked") is True:
                    return (
                        False,
                        "CivitAI model is deleted and metadata archive DB is not enabled",
                    )
                metadata_provider = await self._get_provider("sqlite")
            else:
                metadata_provider = await self._get_default_provider()

            civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256)
            if not civitai_metadata:
                if error == "Model not found":
                    model_data["from_civitai"] = False
                    model_data["civitai_deleted"] = True
                    model_data["db_checked"] = enable_archive
                    model_data["last_checked_at"] = datetime.now().timestamp()

                    data_to_save = model_data.copy()
                    data_to_save.pop("folder", None)
                    await self._metadata_manager.save_metadata(file_path, data_to_save)

                error_msg = (
                    f"Error fetching metadata: {error} (model_name={model_data.get('model_name', '')})"
                )
                logger.error(error_msg)
                return False, error_msg

            model_data["from_civitai"] = True
            model_data["civitai_deleted"] = civitai_metadata.get("source") == "archive_db"
            model_data["db_checked"] = enable_archive
            model_data["last_checked_at"] = datetime.now().timestamp()

            local_metadata = model_data.copy()
            local_metadata.pop("folder", None)

            await self.update_model_metadata(
                metadata_path,
                local_metadata,
                civitai_metadata,
                metadata_provider,
            )

            update_payload = {
                "model_name": local_metadata.get("model_name"),
                "preview_url": local_metadata.get("preview_url"),
                "civitai": local_metadata.get("civitai"),
            }
            model_data.update(update_payload)

            await update_cache_func(file_path, file_path, local_metadata)
            return True, None
        except KeyError as exc:
            error_msg = f"Error fetching metadata - Missing key: {exc} in model_data={model_data}"
            logger.error(error_msg)
            return False, error_msg
        except Exception as exc:  # pragma: no cover - error path
            error_msg = f"Error fetching metadata: {exc}"
            logger.error(error_msg, exc_info=True)
            return False, error_msg

    async def fetch_metadata_by_sha(
        self, sha256: str, metadata_provider: Optional[MetadataProviderProtocol] = None
    ) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
        """Fetch metadata for a SHA256 hash from the configured provider."""

        provider = metadata_provider or await self._get_default_provider()
        return await provider.get_model_by_hash(sha256)

    async def relink_metadata(
        self,
        *,
        file_path: str,
        metadata: Dict[str, Any],
        model_id: int,
        model_version_id: Optional[int],
    ) -> Dict[str, Any]:
        """Relink a local metadata record to a specific CivitAI model version."""

        provider = await self._get_default_provider()
        civitai_metadata = await provider.get_model_version(model_id, model_version_id)
        if not civitai_metadata:
            raise ValueError(
                f"Model version not found on CivitAI for ID: {model_id}"
                + (f" with version: {model_version_id}" if model_version_id else "")
            )

        primary_model_file: Optional[Dict[str, Any]] = None
        for file_info in civitai_metadata.get("files", []):
            if file_info.get("primary", False) and file_info.get("type") == "Model":
                primary_model_file = file_info
                break

        if primary_model_file and primary_model_file.get("hashes", {}).get("SHA256"):
            metadata["sha256"] = primary_model_file["hashes"]["SHA256"].lower()

        metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
        await self.update_model_metadata(
            metadata_path,
            metadata,
            civitai_metadata,
            provider,
        )

        return metadata

    async def save_metadata_updates(
        self,
        *,
        file_path: str,
        updates: Dict[str, Any],
        metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]],
        update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
    ) -> Dict[str, Any]:
        """Apply metadata updates and persist to disk and cache."""

        metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
        metadata = await metadata_loader(metadata_path)

        for key, value in updates.items():
            if isinstance(value, dict) and isinstance(metadata.get(key), dict):
                metadata[key].update(value)
            else:
                metadata[key] = value

        await self._metadata_manager.save_metadata(file_path, metadata)
        await update_cache(file_path, file_path, metadata)

        if "model_name" in updates:
            logger.debug("Metadata update touched model_name; cache resort required")

        return metadata

    async def verify_duplicate_hashes(
        self,
        *,
        file_paths: Iterable[str],
        metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]],
        hash_calculator: Callable[[str], Awaitable[str]],
        update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
    ) -> Dict[str, Any]:
        """Verify a collection of files share the same SHA256 hash."""

        file_paths = list(file_paths)
        if not file_paths:
            raise ValueError("No file paths provided for verification")

        results = {
            "verified_as_duplicates": True,
            "mismatched_files": [],
            "new_hash_map": {},
        }

        expected_hash: Optional[str] = None
        first_metadata_path = os.path.splitext(file_paths[0])[0] + ".metadata.json"
        first_metadata = await metadata_loader(first_metadata_path)
        if first_metadata and "sha256" in first_metadata:
            expected_hash = first_metadata["sha256"].lower()

        for path in file_paths:
            if not os.path.exists(path):
                continue

            try:
                actual_hash = await hash_calculator(path)
                metadata_path = os.path.splitext(path)[0] + ".metadata.json"
                metadata = await metadata_loader(metadata_path)
                stored_hash = metadata.get("sha256", "").lower()

                if not expected_hash:
                    expected_hash = stored_hash

                if actual_hash != expected_hash:
                    results["verified_as_duplicates"] = False
                    results["mismatched_files"].append(path)
                    results["new_hash_map"][path] = actual_hash

                if actual_hash != stored_hash:
                    metadata["sha256"] = actual_hash
                    await self._metadata_manager.save_metadata(path, metadata)
                    await update_cache(path, path, metadata)
            except Exception as exc:  # pragma: no cover - defensive path
                logger.error("Error verifying hash for %s: %s", path, exc)
                results["mismatched_files"].append(path)
                results["new_hash_map"][path] = "error_calculating_hash"
                results["verified_as_duplicates"] = False

        return results
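One subtlety worth illustrating is the shallow merge rule in save_metadata_updates: dict-valued updates are merged key-by-key into an existing dict, while every other value overwrites. A standalone illustration of that exact loop:

# Shallow merge as implemented in save_metadata_updates above.
metadata = {"civitai": {"name": "v1", "id": 7}, "notes": "old"}
updates = {"civitai": {"name": "v2"}, "notes": "new"}
for key, value in updates.items():
    if isinstance(value, dict) and isinstance(metadata.get(key), dict):
        metadata[key].update(value)  # keeps "id": 7
    else:
        metadata[key] = value  # plain overwrite
assert metadata == {"civitai": {"name": "v2", "id": 7}, "notes": "new"}
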
py/services/preview_asset_service.py (new file, 168 lines)
@@ -0,0 +1,168 @@
"""Service for processing preview assets for models."""

from __future__ import annotations

import logging
import os
from typing import Awaitable, Callable, Dict, Optional, Sequence

from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS

logger = logging.getLogger(__name__)


class PreviewAssetService:
    """Manage fetching and persisting preview assets."""

    def __init__(
        self,
        *,
        metadata_manager,
        downloader_factory: Callable[[], Awaitable],
        exif_utils,
    ) -> None:
        self._metadata_manager = metadata_manager
        self._downloader_factory = downloader_factory
        self._exif_utils = exif_utils

    async def ensure_preview_for_metadata(
        self,
        metadata_path: str,
        local_metadata: Dict[str, object],
        images: Sequence[Dict[str, object]] | None,
    ) -> None:
        """Ensure preview assets exist for the supplied metadata entry."""

        if local_metadata.get("preview_url") and os.path.exists(
            str(local_metadata["preview_url"])
        ):
            return

        if not images:
            return

        first_preview = images[0]
        base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
        preview_dir = os.path.dirname(metadata_path)
        is_video = first_preview.get("type") == "video"

        if is_video:
            extension = ".mp4"
            preview_path = os.path.join(preview_dir, base_name + extension)
            downloader = await self._downloader_factory()
            success, result = await downloader.download_file(
                first_preview["url"], preview_path, use_auth=False
            )
            if success:
                local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
                local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
        else:
            extension = ".webp"
            preview_path = os.path.join(preview_dir, base_name + extension)
            downloader = await self._downloader_factory()
            success, content, _headers = await downloader.download_to_memory(
                first_preview["url"], use_auth=False
            )
            if not success:
                return

            try:
                optimized_data, _ = self._exif_utils.optimize_image(
                    image_data=content,
                    target_width=CARD_PREVIEW_WIDTH,
                    format="webp",
                    quality=85,
                    preserve_metadata=False,
                )
                with open(preview_path, "wb") as handle:
                    handle.write(optimized_data)
            except Exception as exc:  # pragma: no cover - defensive path
                logger.error("Error optimizing preview image: %s", exc)
                try:
                    with open(preview_path, "wb") as handle:
                        handle.write(content)
                except Exception as save_exc:
                    logger.error("Error saving preview image: %s", save_exc)
                    return

            local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
            local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)

    async def replace_preview(
        self,
        *,
        model_path: str,
        preview_data: bytes,
        content_type: str,
        original_filename: Optional[str],
        nsfw_level: int,
        update_preview_in_cache: Callable[[str, str, int], Awaitable[bool]],
        metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
    ) -> Dict[str, object]:
        """Replace an existing preview asset for a model."""

        base_name = os.path.splitext(os.path.basename(model_path))[0]
        folder = os.path.dirname(model_path)

        extension, optimized_data = await self._convert_preview(
            preview_data, content_type, original_filename
        )

        for ext in PREVIEW_EXTENSIONS:
            existing_preview = os.path.join(folder, base_name + ext)
            if os.path.exists(existing_preview):
                try:
                    os.remove(existing_preview)
                except Exception as exc:  # pragma: no cover - defensive path
                    logger.warning(
                        "Failed to delete existing preview %s: %s", existing_preview, exc
                    )

        preview_path = os.path.join(folder, base_name + extension).replace(os.sep, "/")
        with open(preview_path, "wb") as handle:
            handle.write(optimized_data)

        metadata_path = os.path.splitext(model_path)[0] + ".metadata.json"
        metadata = await metadata_loader(metadata_path)
        metadata["preview_url"] = preview_path
        metadata["preview_nsfw_level"] = nsfw_level
        await self._metadata_manager.save_metadata(model_path, metadata)

        await update_preview_in_cache(model_path, preview_path, nsfw_level)

        return {"preview_path": preview_path, "preview_nsfw_level": nsfw_level}

    async def _convert_preview(
        self, data: bytes, content_type: str, original_filename: Optional[str]
    ) -> tuple[str, bytes]:
        """Convert preview bytes to the persisted representation."""

        if content_type.startswith("video/"):
            extension = self._resolve_video_extension(content_type, original_filename)
            return extension, data

        original_ext = (original_filename or "").lower()
        if original_ext.endswith(".gif") or content_type.lower() == "image/gif":
            return ".gif", data

        optimized_data, _ = self._exif_utils.optimize_image(
            image_data=data,
            target_width=CARD_PREVIEW_WIDTH,
            format="webp",
            quality=85,
            preserve_metadata=False,
        )
        return ".webp", optimized_data

    def _resolve_video_extension(self, content_type: str, original_filename: Optional[str]) -> str:
        """Infer the best extension for a video preview."""

        if original_filename:
            extension = os.path.splitext(original_filename)[1].lower()
            if extension in {".mp4", ".webm", ".mov", ".avi"}:
                return extension

        if "webm" in content_type:
            return ".webm"
        return ".mp4"
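A sketch of the extension-resolution rules in _convert_preview, using a pass-through exif stub (the stub and the None metadata_manager are assumptions for illustration; _convert_preview never touches either):

import asyncio

from py.services.preview_asset_service import PreviewAssetService

class PassThroughExif:
    @staticmethod
    def optimize_image(**kwargs):
        return kwargs["image_data"], {}

svc = PreviewAssetService(
    metadata_manager=None,  # unused by _convert_preview
    downloader_factory=lambda: asyncio.sleep(0),
    exif_utils=PassThroughExif(),
)

# Videos keep a known container extension from the filename, GIFs pass
# through untouched, and everything else is re-encoded to webp.
assert asyncio.run(svc._convert_preview(b"v", "video/mp4", "clip.mov"))[0] == ".mov"
assert asyncio.run(svc._convert_preview(b"g", "image/gif", "anim.gif")) == (".gif", b"g")
assert asyncio.run(svc._convert_preview(b"p", "image/png", "pic.png"))[0] == ".webp"
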
py/services/tag_update_service.py (new file, 47 lines)
@@ -0,0 +1,47 @@
"""Service for updating tag collections on metadata records."""

from __future__ import annotations

import os

from typing import Awaitable, Callable, Dict, List, Sequence


class TagUpdateService:
    """Encapsulate tag manipulation for models."""

    def __init__(self, *, metadata_manager) -> None:
        self._metadata_manager = metadata_manager

    async def add_tags(
        self,
        *,
        file_path: str,
        new_tags: Sequence[str],
        metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
        update_cache: Callable[[str, str, Dict[str, object]], Awaitable[bool]],
    ) -> List[str]:
        """Add tags to a metadata entry while keeping case-insensitive uniqueness."""

        base, _ = os.path.splitext(file_path)
        metadata_path = f"{base}.metadata.json"
        metadata = await metadata_loader(metadata_path)

        existing_tags = list(metadata.get("tags", []))
        existing_lower = [tag.lower() for tag in existing_tags]

        tags_added: List[str] = []
        for tag in new_tags:
            if isinstance(tag, str) and tag.strip():
                normalized = tag.strip()
                if normalized.lower() not in existing_lower:
                    existing_tags.append(normalized)
                    existing_lower.append(normalized.lower())
                    tags_added.append(normalized)

        metadata["tags"] = existing_tags
        await self._metadata_manager.save_metadata(file_path, metadata)
        await update_cache(file_path, file_path, metadata)

        return existing_tags
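The case-insensitive dedup above can be seen in isolation: "existing" is dropped because "Existing" is already present, blank tags are skipped, and "New" is appended.

existing_tags = ["Existing"]
existing_lower = [t.lower() for t in existing_tags]
for tag in ["New", "existing", "  "]:
    if isinstance(tag, str) and tag.strip():
        normalized = tag.strip()
        if normalized.lower() not in existing_lower:
            existing_tags.append(normalized)
            existing_lower.append(normalized.lower())
assert existing_tags == ["Existing", "New"]
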
@@ -1,14 +1,34 @@
import logging
import os
import re
from ..utils.metadata_manager import MetadataManager
from ..utils.routes_common import ModelRouteUtils

from ..recipes.constants import GEN_PARAM_KEYS
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
from ..services.metadata_sync_service import MetadataSyncService
from ..services.preview_asset_service import PreviewAssetService
from ..services.settings_manager import settings
from ..services.downloader import get_downloader
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
from ..utils.exif_utils import ExifUtils
from ..recipes.constants import GEN_PARAM_KEYS
from ..utils.metadata_manager import MetadataManager

logger = logging.getLogger(__name__)

_preview_service = PreviewAssetService(
    metadata_manager=MetadataManager,
    downloader_factory=get_downloader,
    exif_utils=ExifUtils,
)

_metadata_sync_service = MetadataSyncService(
    metadata_manager=MetadataManager,
    preview_service=_preview_service,
    settings=settings,
    default_metadata_provider_factory=get_default_metadata_provider,
    metadata_provider_selector=get_metadata_provider,
)


class MetadataUpdater:
    """Handles updating model metadata related to example images"""

@@ -53,11 +73,11 @@ class MetadataUpdater:
        async def update_cache_func(old_path, new_path, metadata):
            return await scanner.update_single_model_cache(old_path, new_path, metadata)

        success, error = await ModelRouteUtils.fetch_and_update_model(
            model_hash,
            file_path,
            model_data,
            update_cache_func
        success, error = await _metadata_sync_service.fetch_and_update_model(
            sha256=model_hash,
            file_path=file_path,
            model_data=model_data,
            update_cache_func=update_cache_func,
        )

        if success:
@@ -4,5 +4,8 @@ testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
# Register async marker for coroutine-style tests
markers =
    asyncio: execute test within asyncio event loop
# Skip problematic directories to avoid import conflicts
norecursedirs = .git .tox dist build *.egg __pycache__ py
@@ -1,6 +1,8 @@
import types
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence
import asyncio
import inspect
from unittest import mock
import sys

@@ -39,6 +41,13 @@ nodes_mock.NODE_CLASS_MAPPINGS = {}
sys.modules['nodes'] = nodes_mock


def pytest_pyfunc_call(pyfuncitem):
    if inspect.iscoroutinefunction(pyfuncitem.function):
        asyncio.run(pyfuncitem.obj(**pyfuncitem.funcargs))
        return True
    return None


@dataclass
class MockHashIndex:
    """Minimal hash index stub mirroring the scanner contract."""

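The pytest_pyfunc_call hook above runs any coroutine test via asyncio.run, so bare async def tests work without pytest-asyncio. A hypothetical test (not part of this commit) that the hook enables:

# With the conftest hook in place, this plain coroutine is collected
# and executed inside its own event loop.
async def test_event_loop_is_available():
    import asyncio
    await asyncio.sleep(0)
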
@@ -1,8 +1,35 @@
import pytest

from py.services.base_model_service import BaseModelService
from py.services.model_query import ModelCacheRepository, ModelFilterSet, SearchStrategy, SortParams
from py.utils.models import BaseModelMetadata
import importlib
import importlib.util
import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))


def import_from(module_name: str):
    existing = sys.modules.get("py")
    if existing is None or getattr(existing, "__file__", "") != str(ROOT / "py/__init__.py"):
        sys.modules.pop("py", None)
        spec = importlib.util.spec_from_file_location("py", ROOT / "py/__init__.py")
        module = importlib.util.module_from_spec(spec)
        assert spec and spec.loader
        spec.loader.exec_module(module)  # type: ignore[union-attr]
        module.__path__ = [str(ROOT / "py")]
        sys.modules["py"] = module
    return importlib.import_module(module_name)


BaseModelService = import_from("py.services.base_model_service").BaseModelService
model_query_module = import_from("py.services.model_query")
ModelCacheRepository = model_query_module.ModelCacheRepository
ModelFilterSet = model_query_module.ModelFilterSet
SearchStrategy = model_query_module.SearchStrategy
SortParams = model_query_module.SortParams
BaseModelMetadata = import_from("py.utils.models").BaseModelMetadata


class StubSettings:
tests/services/test_route_support_services.py (new file, 273 lines)
@@ -0,0 +1,273 @@
import asyncio
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List

ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import importlib
import importlib.util

import pytest


def import_from(module_name: str):
    existing = sys.modules.get("py")
    if existing is None or getattr(existing, "__file__", "") != str(ROOT / "py/__init__.py"):
        sys.modules.pop("py", None)
        spec = importlib.util.spec_from_file_location("py", ROOT / "py/__init__.py")
        module = importlib.util.module_from_spec(spec)
        assert spec and spec.loader
        spec.loader.exec_module(module)  # type: ignore[union-attr]
        module.__path__ = [str(ROOT / "py")]
        sys.modules["py"] = module
    return importlib.import_module(module_name)


DownloadCoordinator = import_from("py.services.download_coordinator").DownloadCoordinator
MetadataSyncService = import_from("py.services.metadata_sync_service").MetadataSyncService
PreviewAssetService = import_from("py.services.preview_asset_service").PreviewAssetService
TagUpdateService = import_from("py.services.tag_update_service").TagUpdateService


class DummySettings:
    def __init__(self, values: Dict[str, Any] | None = None) -> None:
        self._values = values or {}

    def get(self, key: str, default: Any = None) -> Any:
        return self._values.get(key, default)


class RecordingMetadataManager:
    def __init__(self) -> None:
        self.saved: List[tuple[str, Dict[str, Any]]] = []

    async def save_metadata(self, path: str, metadata: Dict[str, Any]) -> bool:
        self.saved.append((path, json.loads(json.dumps(metadata))))
        metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json"
        Path(metadata_path).write_text(json.dumps(metadata))
        return True


class RecordingPreviewService:
    def __init__(self) -> None:
        self.calls: List[tuple[str, List[Dict[str, Any]]]] = []

    async def ensure_preview_for_metadata(
        self, metadata_path: str, local_metadata: Dict[str, Any], images
    ) -> None:
        self.calls.append((metadata_path, list(images or [])))
        local_metadata["preview_url"] = "preview.webp"
        local_metadata["preview_nsfw_level"] = 1


class DummyProvider:
    def __init__(self, payload: Dict[str, Any]) -> None:
        self.payload = payload

    async def get_model_by_hash(self, sha256: str):
        return self.payload, None

    async def get_model_version(self, model_id: int, model_version_id: int | None):
        return self.payload


class FakeExifUtils:
    @staticmethod
    def optimize_image(**kwargs):
        return kwargs["image_data"], {}


def test_metadata_sync_merges_remote_fields(tmp_path: Path) -> None:
    manager = RecordingMetadataManager()
    preview = RecordingPreviewService()
    provider = DummyProvider({
        "baseModel": "SD15",
        "model": {"name": "Merged", "description": "desc", "tags": ["tag"], "creator": {"username": "user"}},
        "trainedWords": ["word"],
        "images": [{"url": "http://example", "nsfwLevel": 2, "type": "image"}],
    })

    service = MetadataSyncService(
        metadata_manager=manager,
        preview_service=preview,
        settings=DummySettings(),
        default_metadata_provider_factory=lambda: asyncio.sleep(0, result=provider),
        metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=provider),
    )

    metadata_path = str(tmp_path / "model.metadata.json")
    local_metadata = {"civitai": {"trainedWords": ["existing"]}}

    updated = asyncio.run(service.update_model_metadata(metadata_path, local_metadata, provider.payload))

    assert updated["model_name"] == "Merged"
    assert updated["modelDescription"] == "desc"
    assert set(updated["civitai"]["trainedWords"]) == {"existing", "word"}
    assert manager.saved
    assert preview.calls


def test_metadata_sync_fetch_and_update_updates_cache(tmp_path: Path) -> None:
    manager = RecordingMetadataManager()
    preview = RecordingPreviewService()
    provider = DummyProvider({
        "baseModel": "SDXL",
        "model": {"name": "Updated"},
        "images": [],
    })

    update_cache_calls: List[Dict[str, Any]] = []

    async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool:
        update_cache_calls.append({"original": original, "metadata": metadata})
        return True

    service = MetadataSyncService(
        metadata_manager=manager,
        preview_service=preview,
        settings=DummySettings(),
        default_metadata_provider_factory=lambda: asyncio.sleep(0, result=provider),
        metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=provider),
    )

    model_data = {"sha256": "abc", "file_path": str(tmp_path / "model.safetensors")}
    success, error = asyncio.run(
        service.fetch_and_update_model(
            sha256="abc",
            file_path=str(tmp_path / "model.safetensors"),
            model_data=model_data,
            update_cache_func=update_cache,
        )
    )

    assert success is True
    assert error is None
    assert update_cache_calls
    assert manager.saved


def test_preview_asset_service_replace_preview(tmp_path: Path) -> None:
    metadata_path = tmp_path / "sample.metadata.json"
    metadata_path.write_text(json.dumps({}))

    async def metadata_loader(path: str) -> Dict[str, Any]:
        return json.loads(Path(path).read_text())

    manager = RecordingMetadataManager()

    service = PreviewAssetService(
        metadata_manager=manager,
        downloader_factory=lambda: asyncio.sleep(0, result=None),
        exif_utils=FakeExifUtils(),
    )

    preview_calls: List[Dict[str, Any]] = []

    async def update_preview(model_path: str, preview_path: str, nsfw: int) -> bool:
        preview_calls.append({"model_path": model_path, "preview_path": preview_path, "nsfw": nsfw})
        return True

    model_path = str(tmp_path / "sample.safetensors")
    Path(model_path).write_bytes(b"model")

    result = asyncio.run(
        service.replace_preview(
            model_path=model_path,
            preview_data=b"image-bytes",
            content_type="image/png",
            original_filename="preview.png",
            nsfw_level=2,
            update_preview_in_cache=update_preview,
            metadata_loader=metadata_loader,
        )
    )

    assert result["preview_nsfw_level"] == 2
    assert preview_calls
    saved_metadata = json.loads(metadata_path.read_text())
    assert saved_metadata["preview_nsfw_level"] == 2


def test_download_coordinator_emits_progress() -> None:
    class WSStub:
        def __init__(self) -> None:
            self.progress_events: List[Dict[str, Any]] = []
            self.counter = 0

        def generate_download_id(self) -> str:
            self.counter += 1
            return f"dl-{self.counter}"

        async def broadcast_download_progress(self, download_id: str, payload: Dict[str, Any]) -> None:
            self.progress_events.append({"id": download_id, **payload})

    class DownloadManagerStub:
        def __init__(self) -> None:
            self.calls: List[Dict[str, Any]] = []

        async def download_from_civitai(self, **kwargs) -> Dict[str, Any]:
            self.calls.append(kwargs)
            await kwargs["progress_callback"](10)
            return {"success": True}

        async def cancel_download(self, download_id: str) -> Dict[str, Any]:
            return {"success": True, "download_id": download_id}

        async def get_active_downloads(self) -> Dict[str, Any]:
            return {"active": []}

    ws_stub = WSStub()
    manager_stub = DownloadManagerStub()

    coordinator = DownloadCoordinator(
        ws_manager=ws_stub,
        download_manager_factory=lambda: asyncio.sleep(0, result=manager_stub),
    )

    result = asyncio.run(coordinator.schedule_download({"model_id": 1}))

    assert result["success"] is True
    assert manager_stub.calls
    assert ws_stub.progress_events

    cancel_result = asyncio.run(coordinator.cancel_download(result["download_id"]))
    assert cancel_result["success"] is True

    active = asyncio.run(coordinator.list_active_downloads())
    assert active == {"active": []}


def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None:
    metadata_path = tmp_path / "model.metadata.json"
    metadata_path.write_text(json.dumps({"tags": ["Existing"]}))

    async def loader(path: str) -> Dict[str, Any]:
        return json.loads(Path(path).read_text())

    manager = RecordingMetadataManager()

    service = TagUpdateService(metadata_manager=manager)

    cache_updates: List[Dict[str, Any]] = []

    async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool:
        cache_updates.append(metadata)
        return True

    tags = asyncio.run(
        service.add_tags(
            file_path=str(tmp_path / "model.safetensors"),
            new_tags=["New", "existing"],
            metadata_loader=loader,
            update_cache=update_cache,
        )
    )

    assert tags == ["Existing", "New"]
    assert manager.saved
    assert cache_updates