mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
refactor(routes): extract route utilities into services
This commit is contained in:
@@ -8,12 +8,20 @@ import jinja2
|
|||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..services.metadata_service import get_default_metadata_provider
|
from ..services.download_coordinator import DownloadCoordinator
|
||||||
|
from ..services.downloader import get_downloader
|
||||||
|
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
||||||
|
from ..services.metadata_sync_service import MetadataSyncService
|
||||||
from ..services.model_file_service import ModelFileService, ModelMoveService
|
from ..services.model_file_service import ModelFileService, ModelMoveService
|
||||||
|
from ..services.preview_asset_service import PreviewAssetService
|
||||||
|
from ..services.server_i18n import server_i18n as default_server_i18n
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..services.settings_manager import settings as default_settings
|
from ..services.settings_manager import settings as default_settings
|
||||||
|
from ..services.tag_update_service import TagUpdateService
|
||||||
from ..services.websocket_manager import ws_manager as default_ws_manager
|
from ..services.websocket_manager import ws_manager as default_ws_manager
|
||||||
from ..services.websocket_progress_callback import WebSocketProgressCallback
|
from ..services.websocket_progress_callback import WebSocketProgressCallback
|
||||||
from ..services.server_i18n import server_i18n as default_server_i18n
|
from ..utils.exif_utils import ExifUtils
|
||||||
|
from ..utils.metadata_manager import MetadataManager
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar
|
from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar
|
||||||
from .handlers.model_handlers import (
|
from .handlers.model_handlers import (
|
||||||
@@ -64,6 +72,24 @@ class BaseModelRoutes(ABC):
|
|||||||
self._handler_set: ModelHandlerSet | None = None
|
self._handler_set: ModelHandlerSet | None = None
|
||||||
self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None
|
self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None
|
||||||
|
|
||||||
|
self._preview_service = PreviewAssetService(
|
||||||
|
metadata_manager=MetadataManager,
|
||||||
|
downloader_factory=get_downloader,
|
||||||
|
exif_utils=ExifUtils,
|
||||||
|
)
|
||||||
|
self._metadata_sync_service = MetadataSyncService(
|
||||||
|
metadata_manager=MetadataManager,
|
||||||
|
preview_service=self._preview_service,
|
||||||
|
settings=settings_service,
|
||||||
|
default_metadata_provider_factory=metadata_provider_factory,
|
||||||
|
metadata_provider_selector=get_metadata_provider,
|
||||||
|
)
|
||||||
|
self._tag_update_service = TagUpdateService(metadata_manager=MetadataManager)
|
||||||
|
self._download_coordinator = DownloadCoordinator(
|
||||||
|
ws_manager=self._ws_manager,
|
||||||
|
download_manager_factory=ServiceRegistry.get_download_manager,
|
||||||
|
)
|
||||||
|
|
||||||
if service is not None:
|
if service is not None:
|
||||||
self.attach_service(service)
|
self.attach_service(service)
|
||||||
|
|
||||||
@@ -98,9 +124,19 @@ class BaseModelRoutes(ABC):
|
|||||||
parse_specific_params=self._parse_specific_params,
|
parse_specific_params=self._parse_specific_params,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
management = ModelManagementHandler(service=service, logger=logger)
|
management = ModelManagementHandler(
|
||||||
|
service=service,
|
||||||
|
logger=logger,
|
||||||
|
metadata_sync=self._metadata_sync_service,
|
||||||
|
preview_service=self._preview_service,
|
||||||
|
tag_update_service=self._tag_update_service,
|
||||||
|
)
|
||||||
query = ModelQueryHandler(service=service, logger=logger)
|
query = ModelQueryHandler(service=service, logger=logger)
|
||||||
download = ModelDownloadHandler(ws_manager=self._ws_manager, logger=logger)
|
download = ModelDownloadHandler(
|
||||||
|
ws_manager=self._ws_manager,
|
||||||
|
logger=logger,
|
||||||
|
download_coordinator=self._download_coordinator,
|
||||||
|
)
|
||||||
civitai = ModelCivitaiHandler(
|
civitai = ModelCivitaiHandler(
|
||||||
service=service,
|
service=service,
|
||||||
settings_service=self._settings,
|
settings_service=self._settings,
|
||||||
@@ -110,6 +146,7 @@ class BaseModelRoutes(ABC):
|
|||||||
validate_model_type=self._validate_civitai_model_type,
|
validate_model_type=self._validate_civitai_model_type,
|
||||||
expected_model_types=self._get_expected_model_types,
|
expected_model_types=self._get_expected_model_types,
|
||||||
find_model_file=self._find_model_file,
|
find_model_file=self._find_model_file,
|
||||||
|
metadata_sync=self._metadata_sync_service,
|
||||||
)
|
)
|
||||||
move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger)
|
move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger)
|
||||||
auto_organize = ModelAutoOrganizeHandler(
|
auto_organize = ModelAutoOrganizeHandler(
|
||||||
|
|||||||
@@ -4,16 +4,23 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Awaitable, Callable, Dict, Iterable, Mapping, Optional
|
from typing import Awaitable, Callable, Dict, Iterable, Mapping, Optional
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import jinja2
|
import jinja2
|
||||||
|
|
||||||
|
from ...config import config
|
||||||
|
from ...services.download_coordinator import DownloadCoordinator
|
||||||
|
from ...services.metadata_sync_service import MetadataSyncService
|
||||||
from ...services.model_file_service import ModelFileService, ModelMoveService
|
from ...services.model_file_service import ModelFileService, ModelMoveService
|
||||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
from ...services.preview_asset_service import PreviewAssetService
|
||||||
from ...services.websocket_manager import WebSocketManager
|
|
||||||
from ...services.settings_manager import SettingsManager
|
from ...services.settings_manager import SettingsManager
|
||||||
|
from ...services.tag_update_service import TagUpdateService
|
||||||
|
from ...services.websocket_manager import WebSocketManager
|
||||||
|
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||||
|
from ...utils.file_utils import calculate_sha256
|
||||||
from ...utils.routes_common import ModelRouteUtils
|
from ...utils.routes_common import ModelRouteUtils
|
||||||
|
|
||||||
|
|
||||||
@@ -168,9 +175,20 @@ class ModelListingHandler:
|
|||||||
class ModelManagementHandler:
|
class ModelManagementHandler:
|
||||||
"""Handle mutation operations on models."""
|
"""Handle mutation operations on models."""
|
||||||
|
|
||||||
def __init__(self, *, service, logger: logging.Logger) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
service,
|
||||||
|
logger: logging.Logger,
|
||||||
|
metadata_sync: MetadataSyncService,
|
||||||
|
preview_service: PreviewAssetService,
|
||||||
|
tag_update_service: TagUpdateService,
|
||||||
|
) -> None:
|
||||||
self._service = service
|
self._service = service
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
|
self._metadata_sync = metadata_sync
|
||||||
|
self._preview_service = preview_service
|
||||||
|
self._tag_update_service = tag_update_service
|
||||||
|
|
||||||
async def delete_model(self, request: web.Request) -> web.Response:
|
async def delete_model(self, request: web.Request) -> web.Response:
|
||||||
return await ModelRouteUtils.handle_delete_model(request, self._service.scanner)
|
return await ModelRouteUtils.handle_delete_model(request, self._service.scanner)
|
||||||
@@ -192,7 +210,7 @@ class ModelManagementHandler:
|
|||||||
if not model_data.get("sha256"):
|
if not model_data.get("sha256"):
|
||||||
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
|
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
|
||||||
|
|
||||||
success, error = await ModelRouteUtils.fetch_and_update_model(
|
success, error = await self._metadata_sync.fetch_and_update_model(
|
||||||
sha256=model_data["sha256"],
|
sha256=model_data["sha256"],
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
model_data=model_data,
|
model_data=model_data,
|
||||||
@@ -208,16 +226,144 @@ class ModelManagementHandler:
|
|||||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
async def relink_civitai(self, request: web.Request) -> web.Response:
|
async def relink_civitai(self, request: web.Request) -> web.Response:
|
||||||
return await ModelRouteUtils.handle_relink_civitai(request, self._service.scanner)
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
file_path = data.get("file_path")
|
||||||
|
model_id = data.get("model_id")
|
||||||
|
model_version_id = data.get("model_version_id")
|
||||||
|
|
||||||
|
if not file_path or model_id is None:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Both file_path and model_id are required"},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||||
|
local_metadata = await self._metadata_sync.load_local_metadata(metadata_path)
|
||||||
|
|
||||||
|
updated_metadata = await self._metadata_sync.relink_metadata(
|
||||||
|
file_path=file_path,
|
||||||
|
metadata=local_metadata,
|
||||||
|
model_id=int(model_id),
|
||||||
|
model_version_id=int(model_version_id) if model_version_id else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._service.scanner.update_single_model_cache(
|
||||||
|
file_path, file_path, updated_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
message = (
|
||||||
|
f"Model successfully re-linked to Civitai model {model_id}"
|
||||||
|
+ (f" version {model_version_id}" if model_version_id else "")
|
||||||
|
)
|
||||||
|
return web.json_response(
|
||||||
|
{"success": True, "message": message, "hash": updated_metadata.get("sha256", "")}
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True)
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||||
return await ModelRouteUtils.handle_replace_preview(request, self._service.scanner)
|
try:
|
||||||
|
reader = await request.multipart()
|
||||||
|
|
||||||
|
field = await reader.next()
|
||||||
|
if field is None or field.name != "preview_file":
|
||||||
|
raise ValueError("Expected 'preview_file' field")
|
||||||
|
content_type = field.headers.get("Content-Type", "image/png")
|
||||||
|
content_disposition = field.headers.get("Content-Disposition", "")
|
||||||
|
|
||||||
|
original_filename = None
|
||||||
|
import re
|
||||||
|
|
||||||
|
match = re.search(r'filename="(.*?)"', content_disposition)
|
||||||
|
if match:
|
||||||
|
original_filename = match.group(1)
|
||||||
|
|
||||||
|
preview_data = await field.read()
|
||||||
|
|
||||||
|
field = await reader.next()
|
||||||
|
if field is None or field.name != "model_path":
|
||||||
|
raise ValueError("Expected 'model_path' field")
|
||||||
|
model_path = (await field.read()).decode()
|
||||||
|
|
||||||
|
nsfw_level = 0
|
||||||
|
field = await reader.next()
|
||||||
|
if field and field.name == "nsfw_level":
|
||||||
|
try:
|
||||||
|
nsfw_level = int((await field.read()).decode())
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
self._logger.warning("Invalid NSFW level format, using default 0")
|
||||||
|
|
||||||
|
result = await self._preview_service.replace_preview(
|
||||||
|
model_path=model_path,
|
||||||
|
preview_data=preview_data,
|
||||||
|
content_type=content_type,
|
||||||
|
original_filename=original_filename,
|
||||||
|
nsfw_level=nsfw_level,
|
||||||
|
update_preview_in_cache=self._service.scanner.update_preview_in_cache,
|
||||||
|
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"preview_url": config.get_preview_static_url(result["preview_path"]),
|
||||||
|
"preview_nsfw_level": result["preview_nsfw_level"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.error("Error replacing preview: %s", exc, exc_info=True)
|
||||||
|
return web.Response(text=str(exc), status=500)
|
||||||
|
|
||||||
async def save_metadata(self, request: web.Request) -> web.Response:
|
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||||
return await ModelRouteUtils.handle_save_metadata(request, self._service.scanner)
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
file_path = data.get("file_path")
|
||||||
|
if not file_path:
|
||||||
|
return web.Response(text="File path is required", status=400)
|
||||||
|
|
||||||
|
metadata_updates = {k: v for k, v in data.items() if k != "file_path"}
|
||||||
|
|
||||||
|
await self._metadata_sync.save_metadata_updates(
|
||||||
|
file_path=file_path,
|
||||||
|
updates=metadata_updates,
|
||||||
|
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||||
|
update_cache=self._service.scanner.update_single_model_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "model_name" in metadata_updates:
|
||||||
|
cache = await self._service.scanner.get_cached_data()
|
||||||
|
await cache.resort()
|
||||||
|
|
||||||
|
return web.json_response({"success": True})
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.error("Error saving metadata: %s", exc, exc_info=True)
|
||||||
|
return web.Response(text=str(exc), status=500)
|
||||||
|
|
||||||
async def add_tags(self, request: web.Request) -> web.Response:
|
async def add_tags(self, request: web.Request) -> web.Response:
|
||||||
return await ModelRouteUtils.handle_add_tags(request, self._service.scanner)
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
file_path = data.get("file_path")
|
||||||
|
new_tags = data.get("tags", [])
|
||||||
|
|
||||||
|
if not file_path:
|
||||||
|
return web.Response(text="File path is required", status=400)
|
||||||
|
|
||||||
|
if not isinstance(new_tags, list):
|
||||||
|
return web.Response(text="Tags must be a list", status=400)
|
||||||
|
|
||||||
|
tags = await self._tag_update_service.add_tags(
|
||||||
|
file_path=file_path,
|
||||||
|
new_tags=new_tags,
|
||||||
|
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||||
|
update_cache=self._service.scanner.update_single_model_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
return web.json_response({"success": True, "tags": tags})
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.error("Error adding tags: %s", exc, exc_info=True)
|
||||||
|
return web.Response(text=str(exc), status=500)
|
||||||
|
|
||||||
async def rename_model(self, request: web.Request) -> web.Response:
|
async def rename_model(self, request: web.Request) -> web.Response:
|
||||||
return await ModelRouteUtils.handle_rename_model(request, self._service.scanner)
|
return await ModelRouteUtils.handle_rename_model(request, self._service.scanner)
|
||||||
@@ -226,7 +372,27 @@ class ModelManagementHandler:
|
|||||||
return await ModelRouteUtils.handle_bulk_delete_models(request, self._service.scanner)
|
return await ModelRouteUtils.handle_bulk_delete_models(request, self._service.scanner)
|
||||||
|
|
||||||
async def verify_duplicates(self, request: web.Request) -> web.Response:
|
async def verify_duplicates(self, request: web.Request) -> web.Response:
|
||||||
return await ModelRouteUtils.handle_verify_duplicates(request, self._service.scanner)
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
file_paths = data.get("file_paths", [])
|
||||||
|
|
||||||
|
if not file_paths:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "No file paths provided for verification"},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await self._metadata_sync.verify_duplicate_hashes(
|
||||||
|
file_paths=file_paths,
|
||||||
|
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||||
|
hash_calculator=calculate_sha256,
|
||||||
|
update_cache=self._service.scanner.update_single_model_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
return web.json_response({"success": True, **results})
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.error("Error verifying duplicate models: %s", exc, exc_info=True)
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
|
||||||
class ModelQueryHandler:
|
class ModelQueryHandler:
|
||||||
@@ -429,12 +595,39 @@ class ModelQueryHandler:
|
|||||||
class ModelDownloadHandler:
|
class ModelDownloadHandler:
|
||||||
"""Coordinate downloads and progress reporting."""
|
"""Coordinate downloads and progress reporting."""
|
||||||
|
|
||||||
def __init__(self, *, ws_manager: WebSocketManager, logger: logging.Logger) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ws_manager: WebSocketManager,
|
||||||
|
logger: logging.Logger,
|
||||||
|
download_coordinator: DownloadCoordinator,
|
||||||
|
) -> None:
|
||||||
self._ws_manager = ws_manager
|
self._ws_manager = ws_manager
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
|
self._download_coordinator = download_coordinator
|
||||||
|
|
||||||
async def download_model(self, request: web.Request) -> web.Response:
|
async def download_model(self, request: web.Request) -> web.Response:
|
||||||
return await ModelRouteUtils.handle_download_model(request)
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
result = await self._download_coordinator.schedule_download(payload)
|
||||||
|
if not result.get("success", False):
|
||||||
|
return web.json_response(result, status=500)
|
||||||
|
return web.json_response(result)
|
||||||
|
except ValueError as exc:
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||||
|
except Exception as exc:
|
||||||
|
error_message = str(exc)
|
||||||
|
if "401" in error_message:
|
||||||
|
self._logger.warning("Early access error (401): %s", error_message)
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com.",
|
||||||
|
},
|
||||||
|
status=401,
|
||||||
|
)
|
||||||
|
self._logger.error("Error downloading model: %s", error_message)
|
||||||
|
return web.json_response({"success": False, "error": error_message}, status=500)
|
||||||
|
|
||||||
async def download_model_get(self, request: web.Request) -> web.Response:
|
async def download_model_get(self, request: web.Request) -> web.Response:
|
||||||
try:
|
try:
|
||||||
@@ -460,7 +653,12 @@ class ModelDownloadHandler:
|
|||||||
future.set_result(data)
|
future.set_result(data)
|
||||||
|
|
||||||
mock_request = type("MockRequest", (), {"json": lambda self=None: future})()
|
mock_request = type("MockRequest", (), {"json": lambda self=None: future})()
|
||||||
return await ModelRouteUtils.handle_download_model(mock_request)
|
result = await self._download_coordinator.schedule_download(data)
|
||||||
|
if not result.get("success", False):
|
||||||
|
return web.json_response(result, status=500)
|
||||||
|
return web.json_response(result)
|
||||||
|
except ValueError as exc:
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self._logger.error("Error downloading model via GET: %s", exc, exc_info=True)
|
self._logger.error("Error downloading model via GET: %s", exc, exc_info=True)
|
||||||
return web.Response(status=500, text=str(exc))
|
return web.Response(status=500, text=str(exc))
|
||||||
@@ -470,8 +668,8 @@ class ModelDownloadHandler:
|
|||||||
download_id = request.query.get("download_id")
|
download_id = request.query.get("download_id")
|
||||||
if not download_id:
|
if not download_id:
|
||||||
return web.json_response({"success": False, "error": "Download ID is required"}, status=400)
|
return web.json_response({"success": False, "error": "Download ID is required"}, status=400)
|
||||||
mock_request = type("MockRequest", (), {"match_info": {"download_id": download_id}})()
|
result = await self._download_coordinator.cancel_download(download_id)
|
||||||
return await ModelRouteUtils.handle_cancel_download(mock_request)
|
return web.json_response(result)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self._logger.error("Error cancelling download via GET: %s", exc, exc_info=True)
|
self._logger.error("Error cancelling download via GET: %s", exc, exc_info=True)
|
||||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
@@ -504,6 +702,7 @@ class ModelCivitaiHandler:
|
|||||||
validate_model_type: Callable[[str], bool],
|
validate_model_type: Callable[[str], bool],
|
||||||
expected_model_types: Callable[[], str],
|
expected_model_types: Callable[[], str],
|
||||||
find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]],
|
find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]],
|
||||||
|
metadata_sync: MetadataSyncService,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._service = service
|
self._service = service
|
||||||
self._settings = settings_service
|
self._settings = settings_service
|
||||||
@@ -513,6 +712,7 @@ class ModelCivitaiHandler:
|
|||||||
self._validate_model_type = validate_model_type
|
self._validate_model_type = validate_model_type
|
||||||
self._expected_model_types = expected_model_types
|
self._expected_model_types = expected_model_types
|
||||||
self._find_model_file = find_model_file
|
self._find_model_file = find_model_file
|
||||||
|
self._metadata_sync = metadata_sync
|
||||||
|
|
||||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||||
try:
|
try:
|
||||||
@@ -545,7 +745,7 @@ class ModelCivitaiHandler:
|
|||||||
for model in to_process:
|
for model in to_process:
|
||||||
try:
|
try:
|
||||||
original_name = model.get("model_name")
|
original_name = model.get("model_name")
|
||||||
result, error = await ModelRouteUtils.fetch_and_update_model(
|
result, error = await self._metadata_sync.fetch_and_update_model(
|
||||||
sha256=model["sha256"],
|
sha256=model["sha256"],
|
||||||
file_path=model["file_path"],
|
file_path=model["file_path"],
|
||||||
model_data=model,
|
model_data=model,
|
||||||
|
|||||||
100
py/services/download_coordinator.py
Normal file
100
py/services/download_coordinator.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
"""Service wrapper for coordinating download lifecycle events."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadCoordinator:
|
||||||
|
"""Manage download scheduling, cancellation and introspection."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ws_manager,
|
||||||
|
download_manager_factory: Callable[[], Awaitable],
|
||||||
|
) -> None:
|
||||||
|
self._ws_manager = ws_manager
|
||||||
|
self._download_manager_factory = download_manager_factory
|
||||||
|
|
||||||
|
async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Schedule a download using the provided payload."""
|
||||||
|
|
||||||
|
download_manager = await self._download_manager_factory()
|
||||||
|
|
||||||
|
download_id = payload.get("download_id") or self._ws_manager.generate_download_id()
|
||||||
|
payload.setdefault("download_id", download_id)
|
||||||
|
|
||||||
|
async def progress_callback(progress: Any) -> None:
|
||||||
|
await self._ws_manager.broadcast_download_progress(
|
||||||
|
download_id,
|
||||||
|
{
|
||||||
|
"status": "progress",
|
||||||
|
"progress": progress,
|
||||||
|
"download_id": download_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
model_id = self._parse_optional_int(payload.get("model_id"), "model_id")
|
||||||
|
model_version_id = self._parse_optional_int(
|
||||||
|
payload.get("model_version_id"), "model_version_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_id is None and model_version_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Missing required parameter: Please provide either 'model_id' or 'model_version_id'"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await download_manager.download_from_civitai(
|
||||||
|
model_id=model_id,
|
||||||
|
model_version_id=model_version_id,
|
||||||
|
save_dir=payload.get("model_root"),
|
||||||
|
relative_path=payload.get("relative_path", ""),
|
||||||
|
use_default_paths=payload.get("use_default_paths", False),
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
download_id=download_id,
|
||||||
|
source=payload.get("source"),
|
||||||
|
)
|
||||||
|
|
||||||
|
result["download_id"] = download_id
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||||
|
"""Cancel an active download and emit a broadcast event."""
|
||||||
|
|
||||||
|
download_manager = await self._download_manager_factory()
|
||||||
|
result = await download_manager.cancel_download(download_id)
|
||||||
|
|
||||||
|
await self._ws_manager.broadcast_download_progress(
|
||||||
|
download_id,
|
||||||
|
{
|
||||||
|
"status": "cancelled",
|
||||||
|
"progress": 0,
|
||||||
|
"download_id": download_id,
|
||||||
|
"message": "Download cancelled by user",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def list_active_downloads(self) -> Dict[str, Any]:
|
||||||
|
"""Return the active download map from the underlying manager."""
|
||||||
|
|
||||||
|
download_manager = await self._download_manager_factory()
|
||||||
|
return await download_manager.get_active_downloads()
|
||||||
|
|
||||||
|
def _parse_optional_int(self, value: Any, field: str) -> Optional[int]:
|
||||||
|
"""Parse an optional integer from user input."""
|
||||||
|
|
||||||
|
if value is None or value == "":
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError) as exc:
|
||||||
|
raise ValueError(f"Invalid {field}: Must be an integer") from exc
|
||||||
|
|
||||||
355
py/services/metadata_sync_service.py
Normal file
355
py/services/metadata_sync_service.py
Normal file
@@ -0,0 +1,355 @@
|
|||||||
|
"""Services for synchronising metadata with remote providers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
|
||||||
|
|
||||||
|
from ..services.settings_manager import SettingsManager
|
||||||
|
from ..utils.model_utils import determine_base_model
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataProviderProtocol:
|
||||||
|
"""Subset of metadata provider interface consumed by the sync service."""
|
||||||
|
|
||||||
|
async def get_model_by_hash(self, sha256: str) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_model_version(
|
||||||
|
self, model_id: int, model_version_id: Optional[int]
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataSyncService:
|
||||||
|
"""High level orchestration for metadata synchronisation flows."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
metadata_manager,
|
||||||
|
preview_service,
|
||||||
|
settings: SettingsManager,
|
||||||
|
default_metadata_provider_factory: Callable[[], Awaitable[MetadataProviderProtocol]],
|
||||||
|
metadata_provider_selector: Callable[[str], Awaitable[MetadataProviderProtocol]],
|
||||||
|
) -> None:
|
||||||
|
self._metadata_manager = metadata_manager
|
||||||
|
self._preview_service = preview_service
|
||||||
|
self._settings = settings
|
||||||
|
self._get_default_provider = default_metadata_provider_factory
|
||||||
|
self._get_provider = metadata_provider_selector
|
||||||
|
|
||||||
|
async def load_local_metadata(self, metadata_path: str) -> Dict[str, Any]:
|
||||||
|
"""Load metadata JSON from disk, returning an empty structure when missing."""
|
||||||
|
|
||||||
|
if not os.path.exists(metadata_path):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(metadata_path, "r", encoding="utf-8") as handle:
|
||||||
|
return json.load(handle)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.error("Error loading metadata from %s: %s", metadata_path, exc)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mark_not_found_on_civitai(
|
||||||
|
self, metadata_path: str, local_metadata: Dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Persist the not-found flag for a metadata payload."""
|
||||||
|
|
||||||
|
local_metadata["from_civitai"] = False
|
||||||
|
await self._metadata_manager.save_metadata(metadata_path, local_metadata)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_civitai_api_metadata(meta: Dict[str, Any]) -> bool:
|
||||||
|
"""Determine if the metadata originated from the CivitAI public API."""
|
||||||
|
|
||||||
|
if not isinstance(meta, dict):
|
||||||
|
return False
|
||||||
|
files = meta.get("files")
|
||||||
|
images = meta.get("images")
|
||||||
|
source = meta.get("source")
|
||||||
|
return bool(files) and bool(images) and source != "archive_db"
|
||||||
|
|
||||||
|
async def update_model_metadata(
|
||||||
|
self,
|
||||||
|
metadata_path: str,
|
||||||
|
local_metadata: Dict[str, Any],
|
||||||
|
civitai_metadata: Dict[str, Any],
|
||||||
|
metadata_provider: Optional[MetadataProviderProtocol] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Merge remote metadata into the local record and persist the result."""
|
||||||
|
|
||||||
|
existing_civitai = local_metadata.get("civitai") or {}
|
||||||
|
|
||||||
|
if (
|
||||||
|
civitai_metadata.get("source") == "archive_db"
|
||||||
|
and self.is_civitai_api_metadata(existing_civitai)
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"Skip civitai update for %s (%s)",
|
||||||
|
local_metadata.get("model_name", ""),
|
||||||
|
existing_civitai.get("name", ""),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
merged_civitai = existing_civitai.copy()
|
||||||
|
merged_civitai.update(civitai_metadata)
|
||||||
|
|
||||||
|
if civitai_metadata.get("source") == "archive_db":
|
||||||
|
model_name = civitai_metadata.get("model", {}).get("name", "")
|
||||||
|
version_name = civitai_metadata.get("name", "")
|
||||||
|
logger.info(
|
||||||
|
"Recovered metadata from archive_db for deleted model: %s (%s)",
|
||||||
|
model_name,
|
||||||
|
version_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "trainedWords" in existing_civitai:
|
||||||
|
existing_trained = existing_civitai.get("trainedWords", [])
|
||||||
|
new_trained = civitai_metadata.get("trainedWords", [])
|
||||||
|
merged_trained = list(set(existing_trained + new_trained))
|
||||||
|
merged_civitai["trainedWords"] = merged_trained
|
||||||
|
|
||||||
|
local_metadata["civitai"] = merged_civitai
|
||||||
|
|
||||||
|
if "model" in civitai_metadata and civitai_metadata["model"]:
|
||||||
|
model_data = civitai_metadata["model"]
|
||||||
|
|
||||||
|
if model_data.get("name"):
|
||||||
|
local_metadata["model_name"] = model_data["name"]
|
||||||
|
|
||||||
|
if not local_metadata.get("modelDescription") and model_data.get("description"):
|
||||||
|
local_metadata["modelDescription"] = model_data["description"]
|
||||||
|
|
||||||
|
if not local_metadata.get("tags") and model_data.get("tags"):
|
||||||
|
local_metadata["tags"] = model_data["tags"]
|
||||||
|
|
||||||
|
if model_data.get("creator") and not local_metadata.get("civitai", {}).get(
|
||||||
|
"creator"
|
||||||
|
):
|
||||||
|
local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"]
|
||||||
|
|
||||||
|
local_metadata["base_model"] = determine_base_model(
|
||||||
|
civitai_metadata.get("baseModel")
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._preview_service.ensure_preview_for_metadata(
|
||||||
|
metadata_path, local_metadata, civitai_metadata.get("images", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._metadata_manager.save_metadata(metadata_path, local_metadata)
|
||||||
|
return local_metadata
|
||||||
|
|
||||||
|
async def fetch_and_update_model(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
sha256: str,
|
||||||
|
file_path: str,
|
||||||
|
model_data: Dict[str, Any],
|
||||||
|
update_cache_func: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
|
||||||
|
) -> tuple[bool, Optional[str]]:
|
||||||
|
"""Fetch metadata for a model and update both disk and cache state."""
|
||||||
|
|
||||||
|
if not isinstance(model_data, dict):
|
||||||
|
error = f"Invalid model_data type: {type(model_data)}"
|
||||||
|
logger.error(error)
|
||||||
|
return False, error
|
||||||
|
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||||
|
enable_archive = self._settings.get("enable_metadata_archive_db", False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if model_data.get("civitai_deleted") is True:
|
||||||
|
if not enable_archive or model_data.get("db_checked") is True:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"CivitAI model is deleted and metadata archive DB is not enabled",
|
||||||
|
)
|
||||||
|
metadata_provider = await self._get_provider("sqlite")
|
||||||
|
else:
|
||||||
|
metadata_provider = await self._get_default_provider()
|
||||||
|
|
||||||
|
civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256)
|
||||||
|
if not civitai_metadata:
|
||||||
|
if error == "Model not found":
|
||||||
|
model_data["from_civitai"] = False
|
||||||
|
model_data["civitai_deleted"] = True
|
||||||
|
model_data["db_checked"] = enable_archive
|
||||||
|
model_data["last_checked_at"] = datetime.now().timestamp()
|
||||||
|
|
||||||
|
data_to_save = model_data.copy()
|
||||||
|
data_to_save.pop("folder", None)
|
||||||
|
await self._metadata_manager.save_metadata(file_path, data_to_save)
|
||||||
|
|
||||||
|
error_msg = (
|
||||||
|
f"Error fetching metadata: {error} (model_name={model_data.get('model_name', '')})"
|
||||||
|
)
|
||||||
|
logger.error(error_msg)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
model_data["from_civitai"] = True
|
||||||
|
model_data["civitai_deleted"] = civitai_metadata.get("source") == "archive_db"
|
||||||
|
model_data["db_checked"] = enable_archive
|
||||||
|
model_data["last_checked_at"] = datetime.now().timestamp()
|
||||||
|
|
||||||
|
local_metadata = model_data.copy()
|
||||||
|
local_metadata.pop("folder", None)
|
||||||
|
|
||||||
|
await self.update_model_metadata(
|
||||||
|
metadata_path,
|
||||||
|
local_metadata,
|
||||||
|
civitai_metadata,
|
||||||
|
metadata_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
update_payload = {
|
||||||
|
"model_name": local_metadata.get("model_name"),
|
||||||
|
"preview_url": local_metadata.get("preview_url"),
|
||||||
|
"civitai": local_metadata.get("civitai"),
|
||||||
|
}
|
||||||
|
model_data.update(update_payload)
|
||||||
|
|
||||||
|
await update_cache_func(file_path, file_path, local_metadata)
|
||||||
|
return True, None
|
||||||
|
except KeyError as exc:
|
||||||
|
error_msg = f"Error fetching metadata - Missing key: {exc} in model_data={model_data}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return False, error_msg
|
||||||
|
except Exception as exc: # pragma: no cover - error path
|
||||||
|
error_msg = f"Error fetching metadata: {exc}"
|
||||||
|
logger.error(error_msg, exc_info=True)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
async def fetch_metadata_by_sha(
|
||||||
|
self, sha256: str, metadata_provider: Optional[MetadataProviderProtocol] = None
|
||||||
|
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||||
|
"""Fetch metadata for a SHA256 hash from the configured provider."""
|
||||||
|
|
||||||
|
provider = metadata_provider or await self._get_default_provider()
|
||||||
|
return await provider.get_model_by_hash(sha256)
|
||||||
|
|
||||||
|
async def relink_metadata(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
file_path: str,
|
||||||
|
metadata: Dict[str, Any],
|
||||||
|
model_id: int,
|
||||||
|
model_version_id: Optional[int],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Relink a local metadata record to a specific CivitAI model version."""
|
||||||
|
|
||||||
|
provider = await self._get_default_provider()
|
||||||
|
civitai_metadata = await provider.get_model_version(model_id, model_version_id)
|
||||||
|
if not civitai_metadata:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model version not found on CivitAI for ID: {model_id}"
|
||||||
|
+ (f" with version: {model_version_id}" if model_version_id else "")
|
||||||
|
)
|
||||||
|
|
||||||
|
primary_model_file: Optional[Dict[str, Any]] = None
|
||||||
|
for file_info in civitai_metadata.get("files", []):
|
||||||
|
if file_info.get("primary", False) and file_info.get("type") == "Model":
|
||||||
|
primary_model_file = file_info
|
||||||
|
break
|
||||||
|
|
||||||
|
if primary_model_file and primary_model_file.get("hashes", {}).get("SHA256"):
|
||||||
|
metadata["sha256"] = primary_model_file["hashes"]["SHA256"].lower()
|
||||||
|
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||||
|
await self.update_model_metadata(
|
||||||
|
metadata_path,
|
||||||
|
metadata,
|
||||||
|
civitai_metadata,
|
||||||
|
provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
async def save_metadata_updates(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
file_path: str,
|
||||||
|
updates: Dict[str, Any],
|
||||||
|
metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]],
|
||||||
|
update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Apply metadata updates and persist to disk and cache."""
|
||||||
|
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||||
|
metadata = await metadata_loader(metadata_path)
|
||||||
|
|
||||||
|
for key, value in updates.items():
|
||||||
|
if isinstance(value, dict) and isinstance(metadata.get(key), dict):
|
||||||
|
metadata[key].update(value)
|
||||||
|
else:
|
||||||
|
metadata[key] = value
|
||||||
|
|
||||||
|
await self._metadata_manager.save_metadata(file_path, metadata)
|
||||||
|
await update_cache(file_path, file_path, metadata)
|
||||||
|
|
||||||
|
if "model_name" in updates:
|
||||||
|
logger.debug("Metadata update touched model_name; cache resort required")
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
async def verify_duplicate_hashes(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
file_paths: Iterable[str],
|
||||||
|
metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]],
|
||||||
|
hash_calculator: Callable[[str], Awaitable[str]],
|
||||||
|
update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Verify a collection of files share the same SHA256 hash."""
|
||||||
|
|
||||||
|
file_paths = list(file_paths)
|
||||||
|
if not file_paths:
|
||||||
|
raise ValueError("No file paths provided for verification")
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"verified_as_duplicates": True,
|
||||||
|
"mismatched_files": [],
|
||||||
|
"new_hash_map": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
expected_hash: Optional[str] = None
|
||||||
|
first_metadata_path = os.path.splitext(file_paths[0])[0] + ".metadata.json"
|
||||||
|
first_metadata = await metadata_loader(first_metadata_path)
|
||||||
|
if first_metadata and "sha256" in first_metadata:
|
||||||
|
expected_hash = first_metadata["sha256"].lower()
|
||||||
|
|
||||||
|
for path in file_paths:
|
||||||
|
if not os.path.exists(path):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
actual_hash = await hash_calculator(path)
|
||||||
|
metadata_path = os.path.splitext(path)[0] + ".metadata.json"
|
||||||
|
metadata = await metadata_loader(metadata_path)
|
||||||
|
stored_hash = metadata.get("sha256", "").lower()
|
||||||
|
|
||||||
|
if not expected_hash:
|
||||||
|
expected_hash = stored_hash
|
||||||
|
|
||||||
|
if actual_hash != expected_hash:
|
||||||
|
results["verified_as_duplicates"] = False
|
||||||
|
results["mismatched_files"].append(path)
|
||||||
|
results["new_hash_map"][path] = actual_hash
|
||||||
|
|
||||||
|
if actual_hash != stored_hash:
|
||||||
|
metadata["sha256"] = actual_hash
|
||||||
|
await self._metadata_manager.save_metadata(path, metadata)
|
||||||
|
await update_cache(path, path, metadata)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive path
|
||||||
|
logger.error("Error verifying hash for %s: %s", path, exc)
|
||||||
|
results["mismatched_files"].append(path)
|
||||||
|
results["new_hash_map"][path] = "error_calculating_hash"
|
||||||
|
results["verified_as_duplicates"] = False
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
168
py/services/preview_asset_service.py
Normal file
168
py/services/preview_asset_service.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""Service for processing preview assets for models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Awaitable, Callable, Dict, Optional, Sequence
|
||||||
|
|
||||||
|
from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewAssetService:
|
||||||
|
"""Manage fetching and persisting preview assets."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
metadata_manager,
|
||||||
|
downloader_factory: Callable[[], Awaitable],
|
||||||
|
exif_utils,
|
||||||
|
) -> None:
|
||||||
|
self._metadata_manager = metadata_manager
|
||||||
|
self._downloader_factory = downloader_factory
|
||||||
|
self._exif_utils = exif_utils
|
||||||
|
|
||||||
|
async def ensure_preview_for_metadata(
|
||||||
|
self,
|
||||||
|
metadata_path: str,
|
||||||
|
local_metadata: Dict[str, object],
|
||||||
|
images: Sequence[Dict[str, object]] | None,
|
||||||
|
) -> None:
|
||||||
|
"""Ensure preview assets exist for the supplied metadata entry."""
|
||||||
|
|
||||||
|
if local_metadata.get("preview_url") and os.path.exists(
|
||||||
|
str(local_metadata["preview_url"])
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
if not images:
|
||||||
|
return
|
||||||
|
|
||||||
|
first_preview = images[0]
|
||||||
|
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
|
||||||
|
preview_dir = os.path.dirname(metadata_path)
|
||||||
|
is_video = first_preview.get("type") == "video"
|
||||||
|
|
||||||
|
if is_video:
|
||||||
|
extension = ".mp4"
|
||||||
|
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||||
|
downloader = await self._downloader_factory()
|
||||||
|
success, result = await downloader.download_file(
|
||||||
|
first_preview["url"], preview_path, use_auth=False
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||||
|
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||||
|
else:
|
||||||
|
extension = ".webp"
|
||||||
|
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||||
|
downloader = await self._downloader_factory()
|
||||||
|
success, content, _headers = await downloader.download_to_memory(
|
||||||
|
first_preview["url"], use_auth=False
|
||||||
|
)
|
||||||
|
if not success:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
optimized_data, _ = self._exif_utils.optimize_image(
|
||||||
|
image_data=content,
|
||||||
|
target_width=CARD_PREVIEW_WIDTH,
|
||||||
|
format="webp",
|
||||||
|
quality=85,
|
||||||
|
preserve_metadata=False,
|
||||||
|
)
|
||||||
|
with open(preview_path, "wb") as handle:
|
||||||
|
handle.write(optimized_data)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive path
|
||||||
|
logger.error("Error optimizing preview image: %s", exc)
|
||||||
|
try:
|
||||||
|
with open(preview_path, "wb") as handle:
|
||||||
|
handle.write(content)
|
||||||
|
except Exception as save_exc:
|
||||||
|
logger.error("Error saving preview image: %s", save_exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||||
|
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||||
|
|
||||||
|
async def replace_preview(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
model_path: str,
|
||||||
|
preview_data: bytes,
|
||||||
|
content_type: str,
|
||||||
|
original_filename: Optional[str],
|
||||||
|
nsfw_level: int,
|
||||||
|
update_preview_in_cache: Callable[[str, str, int], Awaitable[bool]],
|
||||||
|
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
||||||
|
) -> Dict[str, object]:
|
||||||
|
"""Replace an existing preview asset for a model."""
|
||||||
|
|
||||||
|
base_name = os.path.splitext(os.path.basename(model_path))[0]
|
||||||
|
folder = os.path.dirname(model_path)
|
||||||
|
|
||||||
|
extension, optimized_data = await self._convert_preview(
|
||||||
|
preview_data, content_type, original_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
for ext in PREVIEW_EXTENSIONS:
|
||||||
|
existing_preview = os.path.join(folder, base_name + ext)
|
||||||
|
if os.path.exists(existing_preview):
|
||||||
|
try:
|
||||||
|
os.remove(existing_preview)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive path
|
||||||
|
logger.warning(
|
||||||
|
"Failed to delete existing preview %s: %s", existing_preview, exc
|
||||||
|
)
|
||||||
|
|
||||||
|
preview_path = os.path.join(folder, base_name + extension).replace(os.sep, "/")
|
||||||
|
with open(preview_path, "wb") as handle:
|
||||||
|
handle.write(optimized_data)
|
||||||
|
|
||||||
|
metadata_path = os.path.splitext(model_path)[0] + ".metadata.json"
|
||||||
|
metadata = await metadata_loader(metadata_path)
|
||||||
|
metadata["preview_url"] = preview_path
|
||||||
|
metadata["preview_nsfw_level"] = nsfw_level
|
||||||
|
await self._metadata_manager.save_metadata(model_path, metadata)
|
||||||
|
|
||||||
|
await update_preview_in_cache(model_path, preview_path, nsfw_level)
|
||||||
|
|
||||||
|
return {"preview_path": preview_path, "preview_nsfw_level": nsfw_level}
|
||||||
|
|
||||||
|
async def _convert_preview(
|
||||||
|
self, data: bytes, content_type: str, original_filename: Optional[str]
|
||||||
|
) -> tuple[str, bytes]:
|
||||||
|
"""Convert preview bytes to the persisted representation."""
|
||||||
|
|
||||||
|
if content_type.startswith("video/"):
|
||||||
|
extension = self._resolve_video_extension(content_type, original_filename)
|
||||||
|
return extension, data
|
||||||
|
|
||||||
|
original_ext = (original_filename or "").lower()
|
||||||
|
if original_ext.endswith(".gif") or content_type.lower() == "image/gif":
|
||||||
|
return ".gif", data
|
||||||
|
|
||||||
|
optimized_data, _ = self._exif_utils.optimize_image(
|
||||||
|
image_data=data,
|
||||||
|
target_width=CARD_PREVIEW_WIDTH,
|
||||||
|
format="webp",
|
||||||
|
quality=85,
|
||||||
|
preserve_metadata=False,
|
||||||
|
)
|
||||||
|
return ".webp", optimized_data
|
||||||
|
|
||||||
|
def _resolve_video_extension(self, content_type: str, original_filename: Optional[str]) -> str:
|
||||||
|
"""Infer the best extension for a video preview."""
|
||||||
|
|
||||||
|
if original_filename:
|
||||||
|
extension = os.path.splitext(original_filename)[1].lower()
|
||||||
|
if extension in {".mp4", ".webm", ".mov", ".avi"}:
|
||||||
|
return extension
|
||||||
|
|
||||||
|
if "webm" in content_type:
|
||||||
|
return ".webm"
|
||||||
|
return ".mp4"
|
||||||
|
|
||||||
47
py/services/tag_update_service.py
Normal file
47
py/services/tag_update_service.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""Service for updating tag collections on metadata records."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from typing import Awaitable, Callable, Dict, List, Sequence
|
||||||
|
|
||||||
|
|
||||||
|
class TagUpdateService:
|
||||||
|
"""Encapsulate tag manipulation for models."""
|
||||||
|
|
||||||
|
def __init__(self, *, metadata_manager) -> None:
|
||||||
|
self._metadata_manager = metadata_manager
|
||||||
|
|
||||||
|
async def add_tags(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
file_path: str,
|
||||||
|
new_tags: Sequence[str],
|
||||||
|
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
||||||
|
update_cache: Callable[[str, str, Dict[str, object]], Awaitable[bool]],
|
||||||
|
) -> List[str]:
|
||||||
|
"""Add tags to a metadata entry while keeping case-insensitive uniqueness."""
|
||||||
|
|
||||||
|
base, _ = os.path.splitext(file_path)
|
||||||
|
metadata_path = f"{base}.metadata.json"
|
||||||
|
metadata = await metadata_loader(metadata_path)
|
||||||
|
|
||||||
|
existing_tags = list(metadata.get("tags", []))
|
||||||
|
existing_lower = [tag.lower() for tag in existing_tags]
|
||||||
|
|
||||||
|
tags_added: List[str] = []
|
||||||
|
for tag in new_tags:
|
||||||
|
if isinstance(tag, str) and tag.strip():
|
||||||
|
normalized = tag.strip()
|
||||||
|
if normalized.lower() not in existing_lower:
|
||||||
|
existing_tags.append(normalized)
|
||||||
|
existing_lower.append(normalized.lower())
|
||||||
|
tags_added.append(normalized)
|
||||||
|
|
||||||
|
metadata["tags"] = existing_tags
|
||||||
|
await self._metadata_manager.save_metadata(file_path, metadata)
|
||||||
|
await update_cache(file_path, file_path, metadata)
|
||||||
|
|
||||||
|
return existing_tags
|
||||||
|
|
||||||
@@ -1,14 +1,34 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from ..utils.metadata_manager import MetadataManager
|
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
from ..recipes.constants import GEN_PARAM_KEYS
|
||||||
|
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
||||||
|
from ..services.metadata_sync_service import MetadataSyncService
|
||||||
|
from ..services.preview_asset_service import PreviewAssetService
|
||||||
|
from ..services.settings_manager import settings
|
||||||
|
from ..services.downloader import get_downloader
|
||||||
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
|
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
|
||||||
from ..utils.exif_utils import ExifUtils
|
from ..utils.exif_utils import ExifUtils
|
||||||
from ..recipes.constants import GEN_PARAM_KEYS
|
from ..utils.metadata_manager import MetadataManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_preview_service = PreviewAssetService(
|
||||||
|
metadata_manager=MetadataManager,
|
||||||
|
downloader_factory=get_downloader,
|
||||||
|
exif_utils=ExifUtils,
|
||||||
|
)
|
||||||
|
|
||||||
|
_metadata_sync_service = MetadataSyncService(
|
||||||
|
metadata_manager=MetadataManager,
|
||||||
|
preview_service=_preview_service,
|
||||||
|
settings=settings,
|
||||||
|
default_metadata_provider_factory=get_default_metadata_provider,
|
||||||
|
metadata_provider_selector=get_metadata_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MetadataUpdater:
|
class MetadataUpdater:
|
||||||
"""Handles updating model metadata related to example images"""
|
"""Handles updating model metadata related to example images"""
|
||||||
|
|
||||||
@@ -53,11 +73,11 @@ class MetadataUpdater:
|
|||||||
async def update_cache_func(old_path, new_path, metadata):
|
async def update_cache_func(old_path, new_path, metadata):
|
||||||
return await scanner.update_single_model_cache(old_path, new_path, metadata)
|
return await scanner.update_single_model_cache(old_path, new_path, metadata)
|
||||||
|
|
||||||
success, error = await ModelRouteUtils.fetch_and_update_model(
|
success, error = await _metadata_sync_service.fetch_and_update_model(
|
||||||
model_hash,
|
sha256=model_hash,
|
||||||
file_path,
|
file_path=file_path,
|
||||||
model_data,
|
model_data=model_data,
|
||||||
update_cache_func
|
update_cache_func=update_cache_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
|
|||||||
@@ -4,5 +4,8 @@ testpaths = tests
|
|||||||
python_files = test_*.py
|
python_files = test_*.py
|
||||||
python_classes = Test*
|
python_classes = Test*
|
||||||
python_functions = test_*
|
python_functions = test_*
|
||||||
|
# Register async marker for coroutine-style tests
|
||||||
|
markers =
|
||||||
|
asyncio: execute test within asyncio event loop
|
||||||
# Skip problematic directories to avoid import conflicts
|
# Skip problematic directories to avoid import conflicts
|
||||||
norecursedirs = .git .tox dist build *.egg __pycache__ py
|
norecursedirs = .git .tox dist build *.egg __pycache__ py
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
import types
|
import types
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional, Sequence
|
from typing import Any, Dict, List, Optional, Sequence
|
||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@@ -39,6 +41,13 @@ nodes_mock.NODE_CLASS_MAPPINGS = {}
|
|||||||
sys.modules['nodes'] = nodes_mock
|
sys.modules['nodes'] = nodes_mock
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_pyfunc_call(pyfuncitem):
|
||||||
|
if inspect.iscoroutinefunction(pyfuncitem.function):
|
||||||
|
asyncio.run(pyfuncitem.obj(**pyfuncitem.funcargs))
|
||||||
|
return True
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MockHashIndex:
|
class MockHashIndex:
|
||||||
"""Minimal hash index stub mirroring the scanner contract."""
|
"""Minimal hash index stub mirroring the scanner contract."""
|
||||||
|
|||||||
@@ -1,8 +1,35 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from py.services.base_model_service import BaseModelService
|
import importlib
|
||||||
from py.services.model_query import ModelCacheRepository, ModelFilterSet, SearchStrategy, SortParams
|
import importlib.util
|
||||||
from py.utils.models import BaseModelMetadata
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
if str(ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(ROOT))
|
||||||
|
|
||||||
|
|
||||||
|
def import_from(module_name: str):
|
||||||
|
existing = sys.modules.get("py")
|
||||||
|
if existing is None or getattr(existing, "__file__", "") != str(ROOT / "py/__init__.py"):
|
||||||
|
sys.modules.pop("py", None)
|
||||||
|
spec = importlib.util.spec_from_file_location("py", ROOT / "py/__init__.py")
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
assert spec and spec.loader
|
||||||
|
spec.loader.exec_module(module) # type: ignore[union-attr]
|
||||||
|
module.__path__ = [str(ROOT / "py")]
|
||||||
|
sys.modules["py"] = module
|
||||||
|
return importlib.import_module(module_name)
|
||||||
|
|
||||||
|
|
||||||
|
BaseModelService = import_from("py.services.base_model_service").BaseModelService
|
||||||
|
model_query_module = import_from("py.services.model_query")
|
||||||
|
ModelCacheRepository = model_query_module.ModelCacheRepository
|
||||||
|
ModelFilterSet = model_query_module.ModelFilterSet
|
||||||
|
SearchStrategy = model_query_module.SearchStrategy
|
||||||
|
SortParams = model_query_module.SortParams
|
||||||
|
BaseModelMetadata = import_from("py.utils.models").BaseModelMetadata
|
||||||
|
|
||||||
|
|
||||||
class StubSettings:
|
class StubSettings:
|
||||||
|
|||||||
273
tests/services/test_route_support_services.py
Normal file
273
tests/services/test_route_support_services.py
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
if str(ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(ROOT))
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import importlib.util
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def import_from(module_name: str):
|
||||||
|
existing = sys.modules.get("py")
|
||||||
|
if existing is None or getattr(existing, "__file__", "") != str(ROOT / "py/__init__.py"):
|
||||||
|
sys.modules.pop("py", None)
|
||||||
|
spec = importlib.util.spec_from_file_location("py", ROOT / "py/__init__.py")
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
assert spec and spec.loader
|
||||||
|
spec.loader.exec_module(module) # type: ignore[union-attr]
|
||||||
|
module.__path__ = [str(ROOT / "py")]
|
||||||
|
sys.modules["py"] = module
|
||||||
|
return importlib.import_module(module_name)
|
||||||
|
|
||||||
|
|
||||||
|
DownloadCoordinator = import_from("py.services.download_coordinator").DownloadCoordinator
|
||||||
|
MetadataSyncService = import_from("py.services.metadata_sync_service").MetadataSyncService
|
||||||
|
PreviewAssetService = import_from("py.services.preview_asset_service").PreviewAssetService
|
||||||
|
TagUpdateService = import_from("py.services.tag_update_service").TagUpdateService
|
||||||
|
|
||||||
|
|
||||||
|
class DummySettings:
|
||||||
|
def __init__(self, values: Dict[str, Any] | None = None) -> None:
|
||||||
|
self._values = values or {}
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
return self._values.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
class RecordingMetadataManager:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.saved: List[tuple[str, Dict[str, Any]]] = []
|
||||||
|
|
||||||
|
async def save_metadata(self, path: str, metadata: Dict[str, Any]) -> bool:
|
||||||
|
self.saved.append((path, json.loads(json.dumps(metadata))))
|
||||||
|
metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json"
|
||||||
|
Path(metadata_path).write_text(json.dumps(metadata))
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class RecordingPreviewService:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.calls: List[tuple[str, List[Dict[str, Any]]]] = []
|
||||||
|
|
||||||
|
async def ensure_preview_for_metadata(
|
||||||
|
self, metadata_path: str, local_metadata: Dict[str, Any], images
|
||||||
|
) -> None:
|
||||||
|
self.calls.append((metadata_path, list(images or [])))
|
||||||
|
local_metadata["preview_url"] = "preview.webp"
|
||||||
|
local_metadata["preview_nsfw_level"] = 1
|
||||||
|
|
||||||
|
|
||||||
|
class DummyProvider:
|
||||||
|
def __init__(self, payload: Dict[str, Any]) -> None:
|
||||||
|
self.payload = payload
|
||||||
|
|
||||||
|
async def get_model_by_hash(self, sha256: str):
|
||||||
|
return self.payload, None
|
||||||
|
|
||||||
|
async def get_model_version(self, model_id: int, model_version_id: int | None):
|
||||||
|
return self.payload
|
||||||
|
|
||||||
|
|
||||||
|
class FakeExifUtils:
|
||||||
|
@staticmethod
|
||||||
|
def optimize_image(**kwargs):
|
||||||
|
return kwargs["image_data"], {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_metadata_sync_merges_remote_fields(tmp_path: Path) -> None:
|
||||||
|
manager = RecordingMetadataManager()
|
||||||
|
preview = RecordingPreviewService()
|
||||||
|
provider = DummyProvider({
|
||||||
|
"baseModel": "SD15",
|
||||||
|
"model": {"name": "Merged", "description": "desc", "tags": ["tag"], "creator": {"username": "user"}},
|
||||||
|
"trainedWords": ["word"],
|
||||||
|
"images": [{"url": "http://example", "nsfwLevel": 2, "type": "image"}],
|
||||||
|
})
|
||||||
|
|
||||||
|
service = MetadataSyncService(
|
||||||
|
metadata_manager=manager,
|
||||||
|
preview_service=preview,
|
||||||
|
settings=DummySettings(),
|
||||||
|
default_metadata_provider_factory=lambda: asyncio.sleep(0, result=provider),
|
||||||
|
metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=provider),
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata_path = str(tmp_path / "model.metadata.json")
|
||||||
|
local_metadata = {"civitai": {"trainedWords": ["existing"]}}
|
||||||
|
|
||||||
|
updated = asyncio.run(service.update_model_metadata(metadata_path, local_metadata, provider.payload))
|
||||||
|
|
||||||
|
assert updated["model_name"] == "Merged"
|
||||||
|
assert updated["modelDescription"] == "desc"
|
||||||
|
assert set(updated["civitai"]["trainedWords"]) == {"existing", "word"}
|
||||||
|
assert manager.saved
|
||||||
|
assert preview.calls
|
||||||
|
|
||||||
|
|
||||||
|
def test_metadata_sync_fetch_and_update_updates_cache(tmp_path: Path) -> None:
|
||||||
|
manager = RecordingMetadataManager()
|
||||||
|
preview = RecordingPreviewService()
|
||||||
|
provider = DummyProvider({
|
||||||
|
"baseModel": "SDXL",
|
||||||
|
"model": {"name": "Updated"},
|
||||||
|
"images": [],
|
||||||
|
})
|
||||||
|
|
||||||
|
update_cache_calls: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool:
|
||||||
|
update_cache_calls.append({"original": original, "metadata": metadata})
|
||||||
|
return True
|
||||||
|
|
||||||
|
service = MetadataSyncService(
|
||||||
|
metadata_manager=manager,
|
||||||
|
preview_service=preview,
|
||||||
|
settings=DummySettings(),
|
||||||
|
default_metadata_provider_factory=lambda: asyncio.sleep(0, result=provider),
|
||||||
|
metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=provider),
|
||||||
|
)
|
||||||
|
|
||||||
|
model_data = {"sha256": "abc", "file_path": str(tmp_path / "model.safetensors")}
|
||||||
|
success, error = asyncio.run(
|
||||||
|
service.fetch_and_update_model(
|
||||||
|
sha256="abc",
|
||||||
|
file_path=str(tmp_path / "model.safetensors"),
|
||||||
|
model_data=model_data,
|
||||||
|
update_cache_func=update_cache,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert success is True
|
||||||
|
assert error is None
|
||||||
|
assert update_cache_calls
|
||||||
|
assert manager.saved
|
||||||
|
|
||||||
|
|
||||||
|
def test_preview_asset_service_replace_preview(tmp_path: Path) -> None:
|
||||||
|
metadata_path = tmp_path / "sample.metadata.json"
|
||||||
|
metadata_path.write_text(json.dumps({}))
|
||||||
|
|
||||||
|
async def metadata_loader(path: str) -> Dict[str, Any]:
|
||||||
|
return json.loads(Path(path).read_text())
|
||||||
|
|
||||||
|
manager = RecordingMetadataManager()
|
||||||
|
|
||||||
|
service = PreviewAssetService(
|
||||||
|
metadata_manager=manager,
|
||||||
|
downloader_factory=lambda: asyncio.sleep(0, result=None),
|
||||||
|
exif_utils=FakeExifUtils(),
|
||||||
|
)
|
||||||
|
|
||||||
|
preview_calls: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def update_preview(model_path: str, preview_path: str, nsfw: int) -> bool:
|
||||||
|
preview_calls.append({"model_path": model_path, "preview_path": preview_path, "nsfw": nsfw})
|
||||||
|
return True
|
||||||
|
|
||||||
|
model_path = str(tmp_path / "sample.safetensors")
|
||||||
|
Path(model_path).write_bytes(b"model")
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
service.replace_preview(
|
||||||
|
model_path=model_path,
|
||||||
|
preview_data=b"image-bytes",
|
||||||
|
content_type="image/png",
|
||||||
|
original_filename="preview.png",
|
||||||
|
nsfw_level=2,
|
||||||
|
update_preview_in_cache=update_preview,
|
||||||
|
metadata_loader=metadata_loader,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["preview_nsfw_level"] == 2
|
||||||
|
assert preview_calls
|
||||||
|
saved_metadata = json.loads(metadata_path.read_text())
|
||||||
|
assert saved_metadata["preview_nsfw_level"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_download_coordinator_emits_progress() -> None:
|
||||||
|
class WSStub:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.progress_events: List[Dict[str, Any]] = []
|
||||||
|
self.counter = 0
|
||||||
|
|
||||||
|
def generate_download_id(self) -> str:
|
||||||
|
self.counter += 1
|
||||||
|
return f"dl-{self.counter}"
|
||||||
|
|
||||||
|
async def broadcast_download_progress(self, download_id: str, payload: Dict[str, Any]) -> None:
|
||||||
|
self.progress_events.append({"id": download_id, **payload})
|
||||||
|
|
||||||
|
class DownloadManagerStub:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.calls: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def download_from_civitai(self, **kwargs) -> Dict[str, Any]:
|
||||||
|
self.calls.append(kwargs)
|
||||||
|
await kwargs["progress_callback"](10)
|
||||||
|
return {"success": True}
|
||||||
|
|
||||||
|
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||||
|
return {"success": True, "download_id": download_id}
|
||||||
|
|
||||||
|
async def get_active_downloads(self) -> Dict[str, Any]:
|
||||||
|
return {"active": []}
|
||||||
|
|
||||||
|
ws_stub = WSStub()
|
||||||
|
manager_stub = DownloadManagerStub()
|
||||||
|
|
||||||
|
coordinator = DownloadCoordinator(
|
||||||
|
ws_manager=ws_stub,
|
||||||
|
download_manager_factory=lambda: asyncio.sleep(0, result=manager_stub),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = asyncio.run(coordinator.schedule_download({"model_id": 1}))
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert manager_stub.calls
|
||||||
|
assert ws_stub.progress_events
|
||||||
|
|
||||||
|
cancel_result = asyncio.run(coordinator.cancel_download(result["download_id"]))
|
||||||
|
assert cancel_result["success"] is True
|
||||||
|
|
||||||
|
active = asyncio.run(coordinator.list_active_downloads())
|
||||||
|
assert active == {"active": []}
|
||||||
|
|
||||||
|
|
||||||
|
def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None:
|
||||||
|
metadata_path = tmp_path / "model.metadata.json"
|
||||||
|
metadata_path.write_text(json.dumps({"tags": ["Existing"]}))
|
||||||
|
|
||||||
|
async def loader(path: str) -> Dict[str, Any]:
|
||||||
|
return json.loads(Path(path).read_text())
|
||||||
|
|
||||||
|
manager = RecordingMetadataManager()
|
||||||
|
|
||||||
|
service = TagUpdateService(metadata_manager=manager)
|
||||||
|
|
||||||
|
cache_updates: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool:
|
||||||
|
cache_updates.append(metadata)
|
||||||
|
return True
|
||||||
|
|
||||||
|
tags = asyncio.run(
|
||||||
|
service.add_tags(
|
||||||
|
file_path=str(tmp_path / "model.safetensors"),
|
||||||
|
new_tags=["New", "existing"],
|
||||||
|
metadata_loader=loader,
|
||||||
|
update_cache=update_cache,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tags == ["Existing", "New"]
|
||||||
|
assert manager.saved
|
||||||
|
assert cache_updates
|
||||||
Reference in New Issue
Block a user