mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 13:12:12 -03:00
Add name pattern filtering to LoRA Pool node, allowing users to filter LoRAs by filename or model name using either plain text or regex patterns. Features: - Include patterns: only show LoRAs matching at least one pattern - Exclude patterns: exclude LoRAs matching any pattern - Regex toggle: switch between substring and regex matching - Case-insensitive matching for both modes - Invalid regex automatically falls back to substring matching - Filters apply to both file_name and model_name fields. Backend: - Update LoraPoolLM._default_config() with namePatterns structure - Add name pattern filtering to _apply_pool_filters() and _apply_specific_filters() - Add API parameter parsing for name_pattern_include/exclude/use_regex - Update LoraPoolConfig type with namePatterns field. Frontend: - Add NamePatternsSection.vue component with pattern input UI - Update useLoraPoolState to manage pattern state and API integration - Update LoraPoolSummaryView to display NamePatternsSection - Increase LORA_POOL_WIDGET_MIN_HEIGHT to accommodate new UI. Tests: - Add 7 test cases covering text/regex include, exclude, combined filtering, model name fallback, and invalid regex handling. Closes #839
2443 lines
98 KiB
Python
2443 lines
98 KiB
Python
"""Handlers for base model routes."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional
|
|
|
|
from aiohttp import web
|
|
import jinja2
|
|
|
|
from ...config import config
|
|
from ...services.download_coordinator import DownloadCoordinator
|
|
from ...services.metadata_sync_service import MetadataSyncService
|
|
from ...services.model_file_service import ModelMoveService
|
|
from ...services.preview_asset_service import PreviewAssetService
|
|
from ...services.settings_manager import SettingsManager, get_settings_manager
|
|
from ...services.tag_update_service import TagUpdateService
|
|
from ...services.use_cases import (
|
|
AutoOrganizeInProgressError,
|
|
AutoOrganizeUseCase,
|
|
BulkMetadataRefreshUseCase,
|
|
DownloadModelEarlyAccessError,
|
|
DownloadModelUseCase,
|
|
DownloadModelValidationError,
|
|
MetadataRefreshProgressReporter,
|
|
)
|
|
from ...services.websocket_manager import WebSocketManager
|
|
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
|
from ...services.errors import RateLimitError, ResourceNotFoundError
|
|
from ...utils.civitai_utils import resolve_license_payload
|
|
from ...utils.file_utils import calculate_sha256
|
|
from ...utils.metadata_manager import MetadataManager
|
|
|
|
# Camel-case license permission keys. They are not referenced in the visible part
# of this module; presumably they mirror Civitai's license payload -- verify at use sites.
LICENSE_FIELDS = tuple(
    [
        "allowNoCredit",
        "allowCommercialUse",
        "allowDerivatives",
        "allowDifferentLicense",
    ]
)
|
|
|
|
|
|
class ModelPageView:
    """Render the HTML view for model listings."""

    # Fallbacks used when pyproject.toml / git metadata cannot be read.
    _DEFAULT_VERSION = "1.0.0"
    _DEFAULT_BUILD = "stable"

    def __init__(
        self,
        *,
        template_env: jinja2.Environment,
        template_name: str,
        service,
        settings_service: SettingsManager,
        server_i18n,
        logger: logging.Logger,
    ) -> None:
        """Store rendering collaborators and compute the cache-busting version once.

        Args:
            template_env: Jinja environment used to load ``template_name``.
            template_name: name of the page template to render.
            service: model service; its ``scanner`` is consulted for cache state.
            settings_service: settings store (``language`` is read per request).
            server_i18n: translation helper (``set_locale``,
                ``create_template_filter``, ``get_translation``).
            logger: logger for diagnostics.
        """
        self._template_env = template_env
        self._template_name = template_name
        self._service = service
        self._settings = settings_service
        self._server_i18n = server_i18n
        self._logger = logger
        self._app_version = self._get_app_version()

    @staticmethod
    def _project_root() -> str:
        """Return the project root (four directory levels above this file)."""
        # Navigate up from py/routes/handlers/model_handlers.py to project root.
        current_file = os.path.abspath(__file__)
        return os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.dirname(current_file)))
        )

    def _load_supporters(self) -> dict:
        """Load supporters data from ``<root>/data/supporters.json``.

        Returns an empty supporters structure when the file is missing or unreadable;
        read/parse failures are logged at debug level only.
        """
        try:
            supporters_path = os.path.join(
                self._project_root(), "data", "supporters.json"
            )
            if os.path.exists(supporters_path):
                with open(supporters_path, "r", encoding="utf-8") as f:
                    return json.load(f)
        except Exception as e:
            self._logger.debug("Failed to load supporters data: %s", e)

        return {"specialThanks": [], "allSupporters": [], "totalCount": 0}

    @classmethod
    def _normalize_version(cls, raw: Any) -> str:
        """Return *raw* as a version string without a single leading ``v``/``V``.

        Only the prefix is stripped, so pre-release suffixes such as ``-dev`` are
        kept intact. Empty or missing values fall back to the default version.
        """
        text = str(raw).strip() if raw else ""
        if text[:1] in ("v", "V"):
            text = text[1:]
        return text or cls._DEFAULT_VERSION

    def _read_project_version(self, pyproject_path: str) -> str:
        """Read ``[project].version`` from *pyproject_path*, or return the default."""
        if not os.path.exists(pyproject_path):
            return self._DEFAULT_VERSION
        try:
            try:
                import tomllib as toml_reader  # stdlib on Python 3.11+, binary mode

                mode, encoding = "rb", None
            except ModuleNotFoundError:
                import toml as toml_reader  # third-party fallback, text mode

                mode, encoding = "r", "utf-8"

            with open(pyproject_path, mode, encoding=encoding) as f:
                data = toml_reader.load(f)
            return self._normalize_version(
                data.get("project", {}).get("version", self._DEFAULT_VERSION)
            )
        except Exception as e:
            self._logger.debug("Failed to read version info for cache busting: %s", e)
            return self._DEFAULT_VERSION

    def _read_git_short_hash(self, root_dir: str) -> str:
        """Return the 7-character HEAD commit hash, or ``"stable"`` if unavailable."""
        if not os.path.exists(os.path.join(root_dir, ".git")):
            return self._DEFAULT_BUILD
        try:
            import git

            repo = git.Repo(root_dir)
            return repo.head.commit.hexsha[:7]
        except Exception:
            # GitPython missing, not a repository, etc. -- use the fallback.
            return self._DEFAULT_BUILD

    def _get_app_version(self) -> str:
        """Return ``"<version>-<short hash>"`` used for static-asset cache busting.

        The version and the git hash are looked up independently, so a missing TOML
        library no longer prevents the git hash from being used.
        """
        root_dir = self._project_root()
        version = self._read_project_version(os.path.join(root_dir, "pyproject.toml"))
        short_hash = self._read_git_short_hash(root_dir)
        return f"{version}-{short_hash}"

    def _scanner_is_initializing(self) -> bool:
        """Return True while the scanner has no cache or reports it is initializing."""
        scanner = self._service.scanner
        if scanner._cache is None:
            return True
        is_initializing = getattr(scanner, "is_initializing", None)
        if callable(is_initializing) and is_initializing():
            return True
        return bool(getattr(scanner, "_is_initializing", False))

    async def handle(self, request: web.Request) -> web.Response:
        """Render the model listing page.

        While the scanner is initializing (or the cache cannot be loaded) the
        template receives ``is_initializing=True`` and an empty folder list. Any
        other failure is logged and answered with a plain-text 500 response.
        """
        try:
            is_initializing = self._scanner_is_initializing()

            if not self._template_env or not self._template_name:
                return web.Response(
                    text="Template environment or template name not set",
                    status=500,
                )

            user_language = self._settings.get("language", "en")
            self._server_i18n.set_locale(user_language)

            # Register the ``t`` filter once per environment; the flag lives on the env.
            if not hasattr(self._template_env, "_i18n_filter_added"):
                self._template_env.filters["t"] = (
                    self._server_i18n.create_template_filter()
                )
                self._template_env._i18n_filter_added = True  # type: ignore[attr-defined]

            template_context = {
                "is_initializing": is_initializing,
                "settings": self._settings,
                "request": request,
                "folders": [],
                "t": self._server_i18n.get_translation,
                "version": self._app_version,
            }

            if not is_initializing:
                try:
                    cache = await self._service.scanner.get_cached_data(
                        force_refresh=False
                    )
                    template_context["folders"] = getattr(cache, "folders", [])
                except Exception as cache_error:  # pragma: no cover - logging path
                    self._logger.error("Error loading cache data: %s", cache_error)
                    template_context["is_initializing"] = True

            rendered = self._template_env.get_template(self._template_name).render(
                **template_context
            )
            return web.Response(text=rendered, content_type="text/html")
        except Exception as exc:  # pragma: no cover - logging path
            self._logger.error("Error handling models page: %s", exc, exc_info=True)
            return web.Response(text="Error loading models page", status=500)
|
|
|
|
|
|
class ModelListingHandler:
    """Provide paginated model listings."""

    # Hard cap on ``page_size`` so one request cannot fetch an unbounded page.
    MAX_PAGE_SIZE = 100

    def __init__(
        self,
        *,
        service,
        parse_specific_params: Callable[[web.Request], Dict],
        logger: logging.Logger,
    ) -> None:
        """Store collaborators.

        Args:
            service: model service providing ``get_paginated_data``,
                ``format_response`` and ``model_type``.
            parse_specific_params: callback returning extra, model-type-specific
                keyword arguments; its keys override the common ones.
            logger: logger for timing and error messages.
        """
        self._service = service
        self._parse_specific_params = parse_specific_params
        self._logger = logger

    async def get_models(self, request: web.Request) -> web.Response:
        """Return one page of formatted models as JSON.

        Response keys: ``items``, ``total``, ``page``, ``page_size`` and
        ``total_pages``. Any failure (including malformed ``page``/``page_size``)
        is logged and answered with ``{"error": ...}`` and status 500.
        """
        start_time = time.perf_counter()
        try:
            params = self._parse_common_params(request)
            result = await self._service.get_paginated_data(**params)

            format_start = time.perf_counter()
            formatted_result = {
                "items": [
                    await self._service.format_response(item)
                    for item in result["items"]
                ],
                "total": result["total"],
                "page": result["page"],
                "page_size": result["page_size"],
                "total_pages": result["total_pages"],
            }
            format_duration = time.perf_counter() - format_start

            duration = time.perf_counter() - start_time
            self._logger.debug(
                "Request for %s/list took %.3fs (formatting: %.3fs)",
                self._service.model_type,
                duration,
                format_duration,
            )
            return web.json_response(formatted_result)
        except Exception as exc:
            self._logger.error(
                "Error retrieving %ss: %s", self._service.model_type, exc, exc_info=True
            )
            return web.json_response({"error": str(exc)}, status=500)

    @staticmethod
    def _as_bool(raw: str) -> bool:
        """Return True only for a case-insensitive ``"true"``."""
        return raw.lower() == "true"

    @staticmethod
    def _as_optional_bool(raw: Optional[str]) -> Optional[bool]:
        """Tri-state flag: None when absent, False for "false"/"0"/"", else True."""
        if raw is None:
            return None  # None means no filter applied
        return raw.lower() not in ("false", "0", "")

    @staticmethod
    def _clean_patterns(patterns: Iterable[str]) -> List[str]:
        """Drop empty/whitespace-only name patterns.

        An empty pattern substring-matches every name, so an accidental empty
        exclude entry would otherwise hide the entire pool.
        """
        return [pattern for pattern in patterns if pattern and pattern.strip()]

    @staticmethod
    def _parse_hash_filters(query) -> Dict[str, object]:
        """Build hash filters from ``?hash=`` (single) or ``?hashes=`` (JSON list)."""
        hash_filters: Dict[str, object] = {}
        if "hash" in query:
            hash_filters["single_hash"] = query["hash"]
        elif "hashes" in query:
            try:
                hash_list = json.loads(query["hashes"])
            except (json.JSONDecodeError, TypeError):
                # Malformed JSON is ignored rather than failing the whole listing.
                hash_list = None
            if isinstance(hash_list, list):
                hash_filters["multiple_hashes"] = hash_list
        return hash_filters

    @staticmethod
    def _parse_tag_filters(query) -> Dict[str, str]:
        """Map tag -> "include"/"exclude" from legacy and explicit tag parameters.

        Supports legacy ``?tag=foo`` (or comma-separated ``?tags=a,b``) plus
        ``?tag_include=foo`` and ``?tag_exclude=foo``; later sources win, so an
        excluded tag overrides an included one.
        """
        legacy_tags = query.getall("tag", [])
        if not legacy_tags:
            legacy_csv = query.get("tags")
            if legacy_csv:
                legacy_tags = [
                    tag.strip() for tag in legacy_csv.split(",") if tag.strip()
                ]

        tag_filters: Dict[str, str] = {}
        sources = (
            (legacy_tags, "include"),
            (query.getall("tag_include", []), "include"),
            (query.getall("tag_exclude", []), "exclude"),
        )
        for tags, mode in sources:
            for tag in tags:
                if tag:
                    tag_filters[tag] = mode
        return tag_filters

    def _parse_common_params(self, request: web.Request) -> Dict:
        """Translate the listing query string into ``get_paginated_data`` kwargs.

        ``page`` is clamped to >= 1 and ``page_size`` to ``[1, MAX_PAGE_SIZE]``.
        Model-type-specific parameters are merged last and override common keys.

        Raises:
            ValueError: if ``page`` or ``page_size`` is not an integer.
        """
        query = request.query

        page = max(1, int(query.get("page", "1")))
        page_size = max(1, min(int(query.get("page_size", "20")), self.MAX_PAGE_SIZE))
        sort_by = query.get("sort_by", "name")
        folder = query.get("folder")
        folder_include = list(query.getall("folder_include", []))
        folder_exclude = list(query.getall("folder_exclude", []))
        search = query.get("search")
        fuzzy_search = self._as_bool(query.get("fuzzy_search", "false"))

        base_models = query.getall("base_model", [])
        model_types = list(query.getall("model_type", []))
        model_types.extend(query.getall("civitai_model_type", []))

        tag_filters = self._parse_tag_filters(query)
        favorites_only = self._as_bool(query.get("favorites_only", "false"))

        search_options = {
            "filename": self._as_bool(query.get("search_filename", "true")),
            "modelname": self._as_bool(query.get("search_modelname", "true")),
            "tags": self._as_bool(query.get("search_tags", "false")),
            "creator": self._as_bool(query.get("search_creator", "false")),
            "recursive": self._as_bool(query.get("recursive", "true")),
        }

        hash_filters = self._parse_hash_filters(query)
        update_available_only = self._as_bool(
            query.get("update_available_only", "false")
        )

        # Tag logic: "any" (OR) or "all" (AND) for include tags
        tag_logic = query.get("tag_logic", "any").lower()
        if tag_logic not in ("any", "all"):
            tag_logic = "any"

        # License-based filters: None means the filter is not applied.
        credit_required = self._as_optional_bool(query.get("credit_required"))
        allow_selling_generated_content = self._as_optional_bool(
            query.get("allow_selling_generated_content")
        )

        # Name pattern filters for LoRA Pool (empty patterns are ignored).
        name_pattern_include = self._clean_patterns(
            query.getall("name_pattern_include", [])
        )
        name_pattern_exclude = self._clean_patterns(
            query.getall("name_pattern_exclude", [])
        )
        name_pattern_use_regex = self._as_bool(
            query.get("name_pattern_use_regex", "false")
        )

        return {
            "page": page,
            "page_size": page_size,
            "sort_by": sort_by,
            "folder": folder,
            "folder_include": folder_include,
            "folder_exclude": folder_exclude,
            "search": search,
            "fuzzy_search": fuzzy_search,
            "base_models": base_models,
            "tags": tag_filters,
            "tag_logic": tag_logic,
            "search_options": search_options,
            "hash_filters": hash_filters,
            "favorites_only": favorites_only,
            "update_available_only": update_available_only,
            "credit_required": credit_required,
            "allow_selling_generated_content": allow_selling_generated_content,
            "model_types": model_types,
            "name_pattern_include": name_pattern_include,
            "name_pattern_exclude": name_pattern_exclude,
            "name_pattern_use_regex": name_pattern_use_regex,
            **self._parse_specific_params(request),
        }
|
|
|
|
|
|
class ModelManagementHandler:
    """Handle mutation operations on models."""

    def __init__(
        self,
        *,
        service,
        logger: logging.Logger,
        metadata_sync: MetadataSyncService,
        preview_service: PreviewAssetService,
        tag_update_service: TagUpdateService,
        lifecycle_service,
    ) -> None:
        """Store the services used by the mutation endpoints."""
        self._service = service
        self._logger = logger
        self._metadata_sync = metadata_sync
        self._preview_service = preview_service
        self._tag_update_service = tag_update_service
        self._lifecycle_service = lifecycle_service

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _filename_from_url(url: str) -> Optional[str]:
        """Return the last path segment of *url* (query and fragment dropped).

        Returns None when the URL has no ``/`` or ends with one.
        """
        path = url.split("?", 1)[0].split("#", 1)[0]
        if "/" not in path:
            return None
        return path.rsplit("/", 1)[-1] or None

    @staticmethod
    def _coerce_nsfw_level(raw: Any) -> int:
        """Return *raw* as an int NSFW level, defaulting to 0 when not numeric."""
        try:
            return int(raw)
        except (TypeError, ValueError):
            return 0

    @staticmethod
    def _preview_response(result: Mapping[str, Any]) -> web.Response:
        """Build the success payload for a replaced preview."""
        return web.json_response(
            {
                "success": True,
                "preview_url": config.get_preview_static_url(result["preview_path"]),
                "preview_nsfw_level": result["preview_nsfw_level"],
            }
        )

    # ---------------------------------------------------------------- endpoints

    async def delete_model(self, request: web.Request) -> web.Response:
        """Delete the model at JSON ``file_path`` via the lifecycle service."""
        try:
            data = await request.json()
            file_path = data.get("file_path")
            if not file_path:
                return web.Response(text="Model path is required", status=400)

            result = await self._lifecycle_service.delete_model(file_path)
            return web.json_response(result)
        except ValueError as exc:
            return web.json_response({"success": False, "error": str(exc)}, status=400)
        except Exception as exc:
            self._logger.error("Error deleting model: %s", exc, exc_info=True)
            return web.Response(text=str(exc), status=500)

    async def exclude_model(self, request: web.Request) -> web.Response:
        """Exclude the model at JSON ``file_path`` via the lifecycle service."""
        try:
            data = await request.json()
            file_path = data.get("file_path")
            if not file_path:
                return web.Response(text="Model path is required", status=400)

            result = await self._lifecycle_service.exclude_model(file_path)
            return web.json_response(result)
        except ValueError as exc:
            return web.json_response({"success": False, "error": str(exc)}, status=400)
        except Exception as exc:
            self._logger.error("Error excluding model: %s", exc, exc_info=True)
            return web.Response(text=str(exc), status=500)

    async def fetch_civitai(self, request: web.Request) -> web.Response:
        """Refresh a cached model's metadata from CivitAI.

        If the cached entry has no completed SHA256 and the scanner supports
        ``calculate_hash_for_model`` (lazy hashing), the hash is computed first and
        written back into the cached entry.
        """
        try:
            data = await request.json()
            file_path = data.get("file_path")
            if not file_path:
                return web.json_response(
                    {"success": False, "error": "File path is required"}, status=400
                )

            cache = await self._service.scanner.get_cached_data()
            model_data = next(
                (item for item in cache.raw_data if item["file_path"] == file_path),
                None,
            )
            if not model_data:
                return web.json_response(
                    {"success": False, "error": "Model not found in cache"}, status=404
                )

            sha256 = model_data.get("sha256")
            hash_status = model_data.get("hash_status", "completed")
            if not sha256 or hash_status != "completed":
                scanner = self._service.scanner
                if not hasattr(scanner, "calculate_hash_for_model"):
                    return web.json_response(
                        {"success": False, "error": "No SHA256 hash found"}, status=400
                    )
                self._logger.info("Lazy hash calculation triggered for %s", file_path)
                sha256 = await scanner.calculate_hash_for_model(file_path)
                if not sha256:
                    return web.json_response(
                        {
                            "success": False,
                            "error": "Failed to calculate SHA256 hash",
                        },
                        status=500,
                    )
                # Update the cached entry in place with the new hash.
                model_data["sha256"] = sha256
                model_data["hash_status"] = "completed"

            await MetadataManager.hydrate_model_data(model_data)

            success, error = await self._metadata_sync.fetch_and_update_model(
                sha256=model_data["sha256"],
                file_path=file_path,
                model_data=model_data,
                update_cache_func=self._service.scanner.update_single_model_cache,
            )
            if not success:
                return web.json_response({"success": False, "error": error})

            formatted_metadata = await self._service.format_response(model_data)
            return web.json_response({"success": True, "metadata": formatted_metadata})
        except Exception as exc:
            self._logger.error("Error fetching from CivitAI: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def relink_civitai(self, request: web.Request) -> web.Response:
        """Re-link a local model to a specific Civitai model (and optional version).

        Non-integer ``model_id``/``model_version_id`` values are rejected with 400.
        """
        try:
            data = await request.json()
            file_path = data.get("file_path")
            model_id = data.get("model_id")
            model_version_id = data.get("model_version_id")

            if not file_path or model_id is None:
                return web.json_response(
                    {
                        "success": False,
                        "error": "Both file_path and model_id are required",
                    },
                    status=400,
                )

            try:
                model_id_int = int(model_id)
                version_id_int = int(model_version_id) if model_version_id else None
            except (TypeError, ValueError):
                # Bad ids are a client error, not a server failure.
                return web.json_response(
                    {
                        "success": False,
                        "error": "model_id and model_version_id must be integers",
                    },
                    status=400,
                )

            metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
            local_metadata = await self._metadata_sync.load_local_metadata(
                metadata_path
            )

            updated_metadata = await self._metadata_sync.relink_metadata(
                file_path=file_path,
                metadata=local_metadata,
                model_id=model_id_int,
                model_version_id=version_id_int,
            )

            await self._service.scanner.update_single_model_cache(
                file_path, file_path, updated_metadata
            )

            message = f"Model successfully re-linked to Civitai model {model_id}" + (
                f" version {model_version_id}" if model_version_id else ""
            )
            return web.json_response(
                {
                    "success": True,
                    "message": message,
                    "hash": updated_metadata.get("sha256", ""),
                }
            )
        except Exception as exc:
            self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def replace_preview(self, request: web.Request) -> web.Response:
        """Replace a model preview from a multipart upload.

        Fields are read in order: ``preview_file``, ``model_path`` and an optional
        ``nsfw_level``. Missing required fields yield a 400 text response.
        """
        try:
            reader = await request.multipart()

            field = await reader.next()
            if field is None or field.name != "preview_file":
                return web.Response(text="Expected 'preview_file' field", status=400)
            content_type = field.headers.get("Content-Type", "image/png")
            content_disposition = field.headers.get("Content-Disposition", "")
            match = re.search(r'filename="(.*?)"', content_disposition)
            original_filename = match.group(1) if match else None
            preview_data = await field.read()

            field = await reader.next()
            if field is None or field.name != "model_path":
                return web.Response(text="Expected 'model_path' field", status=400)
            model_path = (await field.read()).decode()

            nsfw_level = 0
            field = await reader.next()
            if field and field.name == "nsfw_level":
                try:
                    nsfw_level = int((await field.read()).decode())
                except (ValueError, TypeError):
                    self._logger.warning("Invalid NSFW level format, using default 0")

            result = await self._preview_service.replace_preview(
                model_path=model_path,
                preview_data=preview_data,
                content_type=content_type,
                original_filename=original_filename,
                nsfw_level=nsfw_level,
                update_preview_in_cache=self._service.scanner.update_preview_in_cache,
                metadata_loader=self._metadata_sync.load_local_metadata,
            )
            return self._preview_response(result)
        except Exception as exc:
            self._logger.error("Error replacing preview: %s", exc, exc_info=True)
            return web.Response(text=str(exc), status=500)

    async def set_preview_from_url(self, request: web.Request) -> web.Response:
        """Set a preview image from a remote URL (e.g., CivitAI).

        Civitai URLs are rewritten to an optimized rendition when possible. The
        downloaded bytes, the response Content-Type (default ``image/jpeg``) and the
        filename taken from the original URL are passed to the preview service.
        """
        try:
            from ...utils.civitai_utils import rewrite_preview_url
            from ...services.downloader import get_downloader

            data = await request.json()
            model_path = data.get("model_path")
            image_url = data.get("image_url")
            nsfw_level = self._coerce_nsfw_level(data.get("nsfw_level", 0))

            if not model_path:
                return web.json_response(
                    {"success": False, "error": "Model path is required"}, status=400
                )
            if not image_url:
                return web.json_response(
                    {"success": False, "error": "Image URL is required"}, status=400
                )

            # Rewrite URL to use optimized rendition if it's a Civitai URL
            optimized_url, was_rewritten = rewrite_preview_url(
                image_url, media_type="image"
            )
            if was_rewritten and optimized_url:
                self._logger.info(
                    "Rewritten preview URL to optimized version: %s", optimized_url
                )
            else:
                optimized_url = image_url

            self._logger.info(
                "Downloading preview from %s for %s", optimized_url, model_path
            )
            downloader = await get_downloader()
            success, preview_data, headers = await downloader.download_to_memory(
                optimized_url, use_auth=False, return_headers=True
            )
            if not success:
                return web.json_response(
                    {
                        "success": False,
                        "error": f"Failed to download image: {preview_data}",
                    },
                    status=502,
                )

            # Normalize the payload to bytes; this is what gets stored.
            if isinstance(preview_data, (bytes, bytearray)):
                preview_bytes = bytes(preview_data)
            elif isinstance(preview_data, str):
                preview_bytes = preview_data.encode("utf-8")
            else:
                preview_bytes = b""
            if not preview_bytes:
                return web.json_response(
                    {"success": False, "error": "Downloaded image is empty"},
                    status=502,
                )

            content_type = (
                headers.get("Content-Type", "image/jpeg") if headers else "image/jpeg"
            )

            result = await self._preview_service.replace_preview(
                model_path=model_path,
                preview_data=preview_bytes,
                content_type=content_type,
                original_filename=self._filename_from_url(image_url),
                nsfw_level=nsfw_level,
                update_preview_in_cache=self._service.scanner.update_preview_in_cache,
                metadata_loader=self._metadata_sync.load_local_metadata,
            )
            return self._preview_response(result)
        except Exception as exc:
            self._logger.error("Error setting preview from URL: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def save_metadata(self, request: web.Request) -> web.Response:
        """Persist every JSON key except ``file_path`` as a metadata update.

        Renaming (``model_name``) triggers a cache re-sort.
        """
        try:
            data = await request.json()
            file_path = data.get("file_path")
            if not file_path:
                return web.Response(text="File path is required", status=400)

            metadata_updates = {k: v for k, v in data.items() if k != "file_path"}

            await self._metadata_sync.save_metadata_updates(
                file_path=file_path,
                updates=metadata_updates,
                metadata_loader=self._metadata_sync.load_local_metadata,
                update_cache=self._service.scanner.update_single_model_cache,
            )

            if "model_name" in metadata_updates:
                cache = await self._service.scanner.get_cached_data()
                await cache.resort()

            return web.json_response({"success": True})
        except Exception as exc:
            self._logger.error("Error saving metadata: %s", exc, exc_info=True)
            return web.Response(text=str(exc), status=500)

    async def add_tags(self, request: web.Request) -> web.Response:
        """Add the JSON ``tags`` list to a model and return the resulting tags."""
        try:
            data = await request.json()
            file_path = data.get("file_path")
            new_tags = data.get("tags", [])

            if not file_path:
                return web.Response(text="File path is required", status=400)
            if not isinstance(new_tags, list):
                return web.Response(text="Tags must be a list", status=400)

            tags = await self._tag_update_service.add_tags(
                file_path=file_path,
                new_tags=new_tags,
                metadata_loader=self._metadata_sync.load_local_metadata,
                update_cache=self._service.scanner.update_single_model_cache,
            )
            return web.json_response({"success": True, "tags": tags})
        except Exception as exc:
            self._logger.error("Error adding tags: %s", exc, exc_info=True)
            return web.Response(text=str(exc), status=500)

    async def rename_model(self, request: web.Request) -> web.Response:
        """Rename a model file; the new preview path is returned as a static URL."""
        try:
            data = await request.json()
            file_path = data.get("file_path")
            new_file_name = data.get("new_file_name")

            if not file_path or not new_file_name:
                return web.json_response(
                    {
                        "success": False,
                        "error": "File path and new file name are required",
                    },
                    status=400,
                )

            result = await self._lifecycle_service.rename_model(
                file_path=file_path, new_file_name=new_file_name
            )
            return web.json_response(
                {
                    **result,
                    "new_preview_path": config.get_preview_static_url(
                        result.get("new_preview_path")
                    ),
                }
            )
        except ValueError as exc:
            return web.json_response({"success": False, "error": str(exc)}, status=400)
        except Exception as exc:
            self._logger.error("Error renaming model: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def bulk_delete_models(self, request: web.Request) -> web.Response:
        """Delete every model listed in JSON ``file_paths``."""
        try:
            data = await request.json()
            file_paths = data.get("file_paths", [])
            if not file_paths:
                return web.json_response(
                    {
                        "success": False,
                        "error": "No file paths provided for deletion",
                    },
                    status=400,
                )

            result = await self._lifecycle_service.bulk_delete_models(file_paths)
            return web.json_response(result)
        except ValueError as exc:
            return web.json_response({"success": False, "error": str(exc)}, status=400)
        except Exception as exc:
            self._logger.error("Error in bulk delete: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def verify_duplicates(self, request: web.Request) -> web.Response:
        """Recompute SHA256 for JSON ``file_paths`` to confirm duplicate hashes."""
        try:
            data = await request.json()
            file_paths = data.get("file_paths", [])
            if not file_paths:
                return web.json_response(
                    {
                        "success": False,
                        "error": "No file paths provided for verification",
                    },
                    status=400,
                )

            results = await self._metadata_sync.verify_duplicate_hashes(
                file_paths=file_paths,
                metadata_loader=self._metadata_sync.load_local_metadata,
                hash_calculator=calculate_sha256,
                update_cache=self._service.scanner.update_single_model_cache,
            )
            return web.json_response({"success": True, **results})
        except Exception as exc:
            self._logger.error(
                "Error verifying duplicate models: %s", exc, exc_info=True
            )
            return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
|
class ModelQueryHandler:
|
|
"""Serve read-only model queries."""
|
|
|
|
def __init__(self, *, service, logger: logging.Logger) -> None:
|
|
self._service = service
|
|
self._logger = logger
|
|
|
|
async def get_top_tags(self, request: web.Request) -> web.Response:
|
|
try:
|
|
limit = int(request.query.get("limit", "20"))
|
|
if limit < 0:
|
|
limit = 20
|
|
top_tags = await self._service.get_top_tags(limit)
|
|
return web.json_response({"success": True, "tags": top_tags})
|
|
except Exception as exc:
|
|
self._logger.error("Error getting top tags: %s", exc, exc_info=True)
|
|
return web.json_response(
|
|
{"success": False, "error": "Internal server error"}, status=500
|
|
)
|
|
|
|
async def get_base_models(self, request: web.Request) -> web.Response:
|
|
try:
|
|
limit = int(request.query.get("limit", "20"))
|
|
if limit < 1 or limit > 100:
|
|
limit = 20
|
|
base_models = await self._service.get_base_models(limit)
|
|
return web.json_response({"success": True, "base_models": base_models})
|
|
except Exception as exc:
|
|
self._logger.error("Error retrieving base models: %s", exc)
|
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
async def get_model_types(self, request: web.Request) -> web.Response:
|
|
try:
|
|
limit = int(request.query.get("limit", "20"))
|
|
if limit < 1 or limit > 100:
|
|
limit = 20
|
|
model_types = await self._service.get_model_types(limit)
|
|
return web.json_response({"success": True, "model_types": model_types})
|
|
except Exception as exc:
|
|
self._logger.error("Error retrieving model types: %s", exc)
|
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
async def scan_models(self, request: web.Request) -> web.Response:
|
|
try:
|
|
full_rebuild = request.query.get("full_rebuild", "false").lower() == "true"
|
|
await self._service.scan_models(
|
|
force_refresh=True, rebuild_cache=full_rebuild
|
|
)
|
|
if self._service.scanner.is_cancelled():
|
|
return web.json_response(
|
|
{
|
|
"status": "cancelled",
|
|
"message": f"{self._service.model_type.capitalize()} scan cancelled",
|
|
}
|
|
)
|
|
return web.json_response(
|
|
{
|
|
"status": "success",
|
|
"message": f"{self._service.model_type.capitalize()} scan completed",
|
|
}
|
|
)
|
|
except Exception as exc:
|
|
self._logger.error(
|
|
"Error scanning %ss: %s", self._service.model_type, exc, exc_info=True
|
|
)
|
|
return web.json_response({"error": str(exc)}, status=500)
|
|
|
|
async def get_model_roots(self, request: web.Request) -> web.Response:
|
|
try:
|
|
roots = self._service.get_model_roots()
|
|
return web.json_response({"success": True, "roots": roots})
|
|
except Exception as exc:
|
|
self._logger.error(
|
|
"Error getting %s roots: %s",
|
|
self._service.model_type,
|
|
exc,
|
|
exc_info=True,
|
|
)
|
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
async def get_folders(self, request: web.Request) -> web.Response:
    """Return the folder list stored on the scanner's cached data.

    Responds with ``{"folders": ...}``; failures produce a 500 with
    ``success: False`` and the error text.
    """
    try:
        cached = await self._service.scanner.get_cached_data()
        payload = {"folders": cached.folders}
        return web.json_response(payload)
    except Exception as exc:
        self._logger.error("Error getting folders: %s", exc)
        return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
async def cancel_task(self, request: web.Request) -> web.Response:
    """Ask the scanner to cancel its current task.

    Only requests cancellation; the response does not wait for the task to
    stop. Errors are logged and returned as a 500.
    """
    try:
        scanner = self._service.scanner
        scanner.cancel_task()
        ack = {"status": "success", "message": "Cancellation requested"}
        return web.json_response(ack)
    except Exception as exc:
        self._logger.error(
            "Error cancelling task for %s: %s", self._service.model_type, exc
        )
        return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
async def get_folder_tree(self, request: web.Request) -> web.Response:
    """Return the folder tree under the ``model_root`` query parameter.

    A missing or empty ``model_root`` yields a 400; service errors a 500.
    """
    try:
        root = request.query.get("model_root")
        if root:
            tree = await self._service.get_folder_tree(root)
            return web.json_response({"success": True, "tree": tree})
        return web.json_response(
            {"success": False, "error": "model_root parameter is required"},
            status=400,
        )
    except Exception as exc:
        self._logger.error("Error getting folder tree: %s", exc)
        return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
async def get_unified_folder_tree(self, request: web.Request) -> web.Response:
    """Return the service's unified folder tree.

    The tree is whatever ``get_unified_folder_tree`` produces; failures are
    logged and returned as a 500.
    """
    try:
        merged = await self._service.get_unified_folder_tree()
        return web.json_response({"success": True, "tree": merged})
    except Exception as exc:
        self._logger.error("Error getting unified folder tree: %s", exc)
        return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
async def find_duplicate_models(self, request: web.Request) -> web.Response:
    """Return groups of cached models that share the same SHA256 hash.

    For each duplicate hash the group is built from the cached entries of its
    paths. The hash index's primary path is put first when it is not already
    listed. The group is then narrowed with the request's filters (see
    ``_parse_duplicate_filters``) and ordered so likely originals come first.
    Only groups that still contain two or more models are returned, as
    ``{"success": True, "duplicates": [...], "count": N}``.
    Errors are logged with a traceback and returned as a 500.
    """
    try:
        filters = self._parse_duplicate_filters(request)
        duplicates = self._service.find_duplicate_hashes()
        cache = await self._service.scanner.get_cached_data()

        # Index cached entries by path once. The previous per-path next()
        # scan over raw_data made this O(groups * paths * cache size).
        # setdefault keeps the first entry per path, matching next().
        models_by_path: Dict[str, Dict[str, Any]] = {}
        for entry in cache.raw_data:
            models_by_path.setdefault(entry.get("file_path"), entry)

        result = []
        for sha256, paths in duplicates.items():
            all_models = [
                model
                for model in (models_by_path.get(path) for path in paths)
                if model
            ]

            # The primary path for this hash may not be in the duplicate list.
            primary_path = self._service.get_path_by_hash(sha256)
            if primary_path and primary_path not in paths:
                primary_model = models_by_path.get(primary_path)
                if primary_model:
                    all_models.insert(0, primary_model)

            filtered = self._apply_duplicate_filters(all_models, filters)
            sorted_models = self._sort_duplicate_group(filtered)

            # Skip formatting work for groups that filtering reduced below 2.
            if len(sorted_models) < 2:
                continue

            group = {
                "hash": sha256,
                "models": [
                    await self._service.format_response(model)
                    for model in sorted_models
                ],
            }
            result.append(group)

        return web.json_response(
            {"success": True, "duplicates": result, "count": len(result)}
        )
    except Exception as exc:
        self._logger.error(
            "Error finding duplicate %ss: %s",
            self._service.model_type,
            exc,
            exc_info=True,
        )
        return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
def _parse_duplicate_filters(self, request: web.Request) -> Dict[str, Any]:
    """Collect the duplicate-finder filter options from the query string.

    Repeated parameters (``base_model``, ``tag_include``, ``tag_exclude``,
    ``model_type``) are returned as lists, defaulting to ``[]``. ``folder``
    is the raw value or ``None``. ``favorites_only`` is ``True`` only for a
    case-insensitive ``"true"``.
    """
    query = request.query
    # (result key, query parameter) pairs for the multi-valued filters.
    multi_valued = (
        ("base_models", "base_model"),
        ("tag_include", "tag_include"),
        ("tag_exclude", "tag_exclude"),
        ("model_types", "model_type"),
    )
    parsed: Dict[str, Any] = {
        name: query.getall(param, []) for name, param in multi_valued
    }
    parsed["folder"] = query.get("folder")
    parsed["favorites_only"] = query.get("favorites_only", "").lower() == "true"
    return parsed
|
|
|
|
def _apply_duplicate_filters(
    self, models: List[Dict[str, Any]], filters: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Narrow one duplicate group with the filters from ``_parse_duplicate_filters``.

    Filters are applied in this order:
    - ``base_models``: keep models whose ``base_model`` is in the list.
    - ``tag_include``: each tag must be present. ``"__no_tags__"`` keeps
      untagged models.
    - ``tag_exclude``: each tag must be absent. ``"__no_tags__"`` drops
      untagged models.
    - ``model_types``: case-insensitive match on ``model_type``.
    - ``folder``: keep models whose folder starts with this prefix.
    - ``favorites_only``: keep models flagged ``favorite``.

    Missing or ``None`` fields on a model are treated as empty, so partially
    populated cache entries are filtered out instead of raising.

    Returns:
        A new list, or ``models`` itself when no filter applies.
    """
    result = models

    base_models = filters.get("base_models")
    if base_models:
        base_set = set(base_models)
        result = [m for m in result if m.get("base_model") in base_set]

    # `or []` also tolerates a filter key explicitly set to None.
    for tag in filters.get("tag_include") or []:
        if tag == "__no_tags__":
            result = [m for m in result if not m.get("tags")]
        else:
            result = [m for m in result if tag in (m.get("tags") or [])]

    for tag in filters.get("tag_exclude") or []:
        if tag == "__no_tags__":
            result = [m for m in result if m.get("tags")]
        else:
            result = [m for m in result if tag not in (m.get("tags") or [])]

    model_types = filters.get("model_types")
    if model_types:
        type_set = {t.lower() for t in model_types}
        result = [
            m for m in result if (m.get("model_type") or "").lower() in type_set
        ]

    folder = filters.get("folder")
    if folder:
        # A stored folder of None previously crashed on .startswith(); treat
        # it as an empty folder string instead.
        result = [m for m in result if (m.get("folder") or "").startswith(folder)]

    if filters.get("favorites_only"):
        result = [m for m in result if m.get("favorite", False)]

    return result
|
|
|
|
def _sort_duplicate_group(
    self, models: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Order a duplicate group so likely originals come first and copies last.

    Each model gets a "copy score" from its ``file_name``:
    - +100 for a ``-NNNN.`` suffix such as ``-0001.safetensors``;
    - +50 for a parenthesised counter such as ``(1)``;
    - +50 if the name contains "copy" (case-insensitive);
    - +1 per character longer than the shortest name in the group.

    Ties are broken by the lower-cased file name. A missing or ``None``
    ``file_name`` is treated as an empty string. Groups of zero or one model
    are returned unchanged; otherwise a new sorted list is returned.
    """
    if len(models) <= 1:
        return models

    def _name(model: Dict[str, Any]) -> str:
        # `or ""` guards against an explicit None, which crashed len()/re.
        return model.get("file_name") or ""

    min_len = min(len(_name(m)) for m in models)

    def copy_score(model: Dict[str, Any]):
        fn = _name(model)
        score = 0
        if re.search(r"-\d{4}\.", fn):
            score += 100
        if re.search(r"\(\d+\)", fn):
            score += 50
        if "copy" in fn.lower():
            score += 50
        # Longer names are more likely to be decorated copies.
        score += len(fn) - min_len
        return (score, fn.lower())

    return sorted(models, key=copy_score)
|
|
|
|
async def find_filename_conflicts(self, request: web.Request) -> web.Response:
    """Return groups of cached models that share the same file name.

    For each duplicated file name, the cached entries of its paths are
    formatted in order. If the scanner maps the file name to a hash whose
    primary path is not among those paths, that model is placed first.
    Groups with at least one resolvable model are returned as
    ``{"success": True, "conflicts": [...], "count": N}``.
    Errors are logged with a traceback and returned as a 500.
    """
    try:
        duplicates = self._service.find_duplicate_filenames()
        cache = await self._service.scanner.get_cached_data()

        # Build the path lookup once instead of scanning raw_data per path.
        # setdefault keeps the first entry per path, matching next().
        models_by_path: Dict[str, Dict[str, Any]] = {}
        for entry in cache.raw_data:
            models_by_path.setdefault(entry.get("file_path"), entry)

        result = []
        for filename, paths in duplicates.items():
            group = {"filename": filename, "models": []}
            for path in paths:
                model = models_by_path.get(path)
                if model:
                    group["models"].append(await self._service.format_response(model))

            hash_val = self._service.scanner.get_hash_by_filename(filename)
            if hash_val:
                main_path = self._service.get_path_by_hash(hash_val)
                if main_path and main_path not in paths:
                    main_model = models_by_path.get(main_path)
                    if main_model:
                        group["models"].insert(
                            0, await self._service.format_response(main_model)
                        )

            if group["models"]:
                result.append(group)

        return web.json_response(
            {"success": True, "conflicts": result, "count": len(result)}
        )
    except Exception as exc:
        self._logger.error(
            "Error finding filename conflicts for %ss: %s",
            self._service.model_type,
            exc,
            exc_info=True,
        )
        return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
async def get_model_notes(self, request: web.Request) -> web.Response:
    """Return the user notes for the model named by the ``name`` query param.

    Status codes: 400 when ``name`` is missing or empty; 404 when the service
    returns ``None``; 500 on unexpected errors (logged with traceback).
    """
    try:
        name = request.query.get("name")
        kind = self._service.model_type.capitalize()
        if not name:
            return web.Response(text=f"{kind} file name is required", status=400)

        notes = await self._service.get_model_notes(name)
        if notes is None:
            return web.json_response(
                {"success": False, "error": f"{kind} not found in cache"},
                status=404,
            )
        return web.json_response({"success": True, "notes": notes})
    except Exception as exc:
        self._logger.error(
            "Error getting %s notes: %s", self._service.model_type, exc, exc_info=True
        )
        return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
async def get_model_preview_url(self, request: web.Request) -> web.Response:
    """Resolve the preview URL for a model looked up by file name.

    Query parameters:
        name: model file name. Required; 400 if missing or empty.
        license_flags: when "1", "true", "yes" or "on" (trimmed,
            case-insensitive), also include the model's ``license_flags``
            as an int if the model info provides a non-None value.

    Returns 404 when no preview URL is found and 500 on unexpected errors.
    """
    truthy = {"1", "true", "yes", "on"}
    try:
        name = request.query.get("name")
        if not name:
            return web.Response(
                text=f"{self._service.model_type.capitalize()} file name is required",
                status=400,
            )

        want_flags = request.query.get("license_flags", "").strip().lower() in truthy
        url = await self._service.get_model_preview_url(name)
        if not url:
            return web.json_response(
                {
                    "success": False,
                    "error": f"No preview URL found for the specified {self._service.model_type}",
                },
                status=404,
            )

        body: dict[str, object] = {"success": True, "preview_url": url}
        if want_flags:
            info = await self._service.get_model_info_by_name(name)
            flags = (info or {}).get("license_flags")
            if flags is not None:
                body["license_flags"] = int(flags)
        return web.json_response(body)
    except Exception as exc:
        self._logger.error(
            "Error getting %s preview URL: %s",
            self._service.model_type,
            exc,
            exc_info=True,
        )
        return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
async def get_model_civitai_url(self, request: web.Request) -> web.Response:
    """Return the Civitai URL info for the model named by ``name``.

    On success the service result's keys are merged into the response next
    to ``success: True``. A falsy ``civitai_url`` yields a 404, a missing
    ``name`` a 400, and unexpected errors a 500.
    """
    try:
        name = request.query.get("name")
        if not name:
            return web.Response(
                text=f"{self._service.model_type.capitalize()} file name is required",
                status=400,
            )

        info = await self._service.get_model_civitai_url(name)
        if not info["civitai_url"]:
            return web.json_response(
                {
                    "success": False,
                    "error": f"No Civitai data found for the specified {self._service.model_type}",
                },
                status=404,
            )
        return web.json_response({"success": True, **info})
    except Exception as exc:
        self._logger.error(
            "Error getting %s Civitai URL: %s",
            self._service.model_type,
            exc,
            exc_info=True,
        )
        return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
async def get_model_metadata(self, request: web.Request) -> web.Response:
    """Return the stored metadata for the model at ``file_path``.

    400 when ``file_path`` is missing or empty, 404 when the service returns
    ``None``, 500 on unexpected errors (logged with traceback).
    """
    try:
        path = request.query.get("file_path")
        if not path:
            return web.Response(text="File path is required", status=400)

        metadata = await self._service.get_model_metadata(path)
        if metadata is None:
            kind = self._service.model_type.capitalize()
            return web.json_response(
                {
                    "success": False,
                    "error": f"{kind} not found or no CivitAI metadata available",
                },
                status=404,
            )
        return web.json_response({"success": True, "metadata": metadata})
    except Exception as exc:
        self._logger.error(
            "Error getting %s metadata: %s",
            self._service.model_type,
            exc,
            exc_info=True,
        )
        return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
async def get_model_description(self, request: web.Request) -> web.Response:
    """Return the description text for the model at ``file_path``.

    400 when ``file_path`` is missing or empty, 404 when the service returns
    ``None``, 500 on unexpected errors (logged with traceback).
    """
    try:
        path = request.query.get("file_path")
        if not path:
            return web.Response(text="File path is required", status=400)

        description = await self._service.get_model_description(path)
        if description is None:
            kind = self._service.model_type.capitalize()
            return web.json_response(
                {
                    "success": False,
                    "error": f"{kind} not found or no description available",
                },
                status=404,
            )
        return web.json_response({"success": True, "description": description})
    except Exception as exc:
        self._logger.error(
            "Error getting %s description: %s",
            self._service.model_type,
            exc,
            exc_info=True,
        )
        return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
async def get_relative_paths(self, request: web.Request) -> web.Response:
    """Autocomplete relative model paths matching ``search``.

    Query parameters:
        search: text to match. Whitespace-trimmed; may be empty.
        limit: page size, clamped to 1-100. Default 15.
        offset: number of matches to skip, clamped to >= 0. Default 0.

    Non-integer ``limit``/``offset`` values fall back to their defaults
    instead of failing the request with a 500.

    Returns ``{"success": True, "relative_paths": [...]}`` or a 500 on
    service errors.
    """
    try:
        search = request.query.get("search", "").strip()

        try:
            limit = int(request.query.get("limit", "15"))
        except ValueError:
            limit = 15
        try:
            offset = int(request.query.get("offset", "0"))
        except ValueError:
            offset = 0

        # A zero or negative page size was previously passed straight
        # through; clamp it into [1, 100].
        limit = max(1, min(limit, 100))
        offset = max(0, offset)

        matching_paths = await self._service.search_relative_paths(
            search, limit, offset
        )
        return web.json_response({"success": True, "relative_paths": matching_paths})
    except Exception as exc:
        self._logger.error(
            "Error getting relative paths for autocomplete: %s", exc, exc_info=True
        )
        return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
|
|
class ModelDownloadHandler:
    """Coordinate downloads and progress reporting.

    HTTP layer over three collaborators:
    - ``DownloadModelUseCase``: starts downloads.
    - ``DownloadCoordinator``: cancels, pauses and resumes downloads.
    - ``WebSocketManager``: holds per-download progress snapshots.
    """

    def __init__(
        self,
        *,
        ws_manager: WebSocketManager,
        logger: logging.Logger,
        download_use_case: DownloadModelUseCase,
        download_coordinator: DownloadCoordinator,
    ) -> None:
        self._ws_manager = ws_manager
        self._logger = logger
        self._download_use_case = download_use_case
        self._download_coordinator = download_coordinator

    async def download_model(self, request: web.Request) -> web.Response:
        """Start a download described by the JSON request body.

        Status codes:
        - 200: the use case result reports success.
        - 500: the result has ``success`` falsy, or an unexpected error occurs.
        - 400: ``DownloadModelValidationError``.
        - 401: ``DownloadModelEarlyAccessError``.
        """
        try:
            payload = await request.json()
            result = await self._download_use_case.execute(payload)
            if not result.get("success", False):
                return web.json_response(result, status=500)
            return web.json_response(result)
        except DownloadModelValidationError as exc:
            return web.json_response({"success": False, "error": str(exc)}, status=400)
        except DownloadModelEarlyAccessError as exc:
            self._logger.warning("Early access error: %s", exc)
            return web.json_response({"success": False, "error": str(exc)}, status=401)
        except Exception as exc:
            error_message = str(exc)
            self._logger.error(
                "Error downloading model: %s", error_message, exc_info=True
            )
            return web.json_response(
                {"success": False, "error": error_message}, status=500
            )

    async def download_model_get(self, request: web.Request) -> web.Response:
        """Start a download from query-string parameters (GET variant).

        Query parameters:
            model_id: required. A missing or empty value yields a plain-text 400.
            use_default_paths: "true" (case-insensitive) enables default paths.
            model_version_id, download_id, source: forwarded when non-empty.
            file_params: JSON-encoded object. Invalid JSON is logged and ignored.

        Error mapping matches :meth:`download_model`, except unexpected errors,
        which are returned as a plain-text 500.
        """
        try:
            model_id = request.query.get("model_id")
            if not model_id:
                return web.Response(
                    status=400,
                    text="Missing required parameter: Please provide 'model_id'",
                )

            data: Dict[str, Any] = {
                "model_id": model_id,
                "use_default_paths": (
                    request.query.get("use_default_paths", "false").lower() == "true"
                ),
            }
            # Optional passthrough parameters; empty values are omitted.
            for key in ("model_version_id", "download_id", "source"):
                value = request.query.get(key)
                if value:
                    data[key] = value

            file_params_json = request.query.get("file_params")
            if file_params_json:
                try:
                    data["file_params"] = json.loads(file_params_json)
                except json.JSONDecodeError:
                    # Best effort: continue without file_params.
                    self._logger.warning(
                        "Invalid file_params JSON: %s", file_params_json
                    )

            # The use case consumes the dict directly.
            result = await self._download_use_case.execute(data)
            if not result.get("success", False):
                return web.json_response(result, status=500)
            return web.json_response(result)
        except DownloadModelValidationError as exc:
            return web.json_response({"success": False, "error": str(exc)}, status=400)
        except DownloadModelEarlyAccessError as exc:
            self._logger.warning("Early access error: %s", exc)
            return web.json_response({"success": False, "error": str(exc)}, status=401)
        except Exception as exc:
            self._logger.error(
                "Error downloading model via GET: %s", exc, exc_info=True
            )
            return web.Response(status=500, text=str(exc))

    async def cancel_download_get(self, request: web.Request) -> web.Response:
        """Cancel the download named by ``download_id``.

        Returns the coordinator result as-is; 400 if ``download_id`` is missing.
        """
        try:
            download_id = request.query.get("download_id")
            if not download_id:
                return web.json_response(
                    {"success": False, "error": "Download ID is required"}, status=400
                )
            result = await self._download_coordinator.cancel_download(download_id)
            return web.json_response(result)
        except Exception as exc:
            self._logger.error(
                "Error cancelling download via GET: %s", exc, exc_info=True
            )
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def pause_download_get(self, request: web.Request) -> web.Response:
        """Pause the download named by ``download_id``.

        Responds 200 when the coordinator reports success, otherwise 400.
        """
        try:
            download_id = request.query.get("download_id")
            if not download_id:
                return web.json_response(
                    {"success": False, "error": "Download ID is required"}, status=400
                )
            result = await self._download_coordinator.pause_download(download_id)
            status = 200 if result.get("success") else 400
            return web.json_response(result, status=status)
        except Exception as exc:
            self._logger.error("Error pausing download via GET: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def resume_download_get(self, request: web.Request) -> web.Response:
        """Resume the download named by ``download_id``.

        Responds 200 when the coordinator reports success, otherwise 400.
        """
        try:
            download_id = request.query.get("download_id")
            if not download_id:
                return web.json_response(
                    {"success": False, "error": "Download ID is required"}, status=400
                )
            result = await self._download_coordinator.resume_download(download_id)
            status = 200 if result.get("success") else 400
            return web.json_response(result, status=status)
        except Exception as exc:
            self._logger.error(
                "Error resuming download via GET: %s", exc, exc_info=True
            )
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def get_download_progress(self, request: web.Request) -> web.Response:
        """Return the progress snapshot for the ``download_id`` path segment.

        The payload always has progress, byte counters and speed. ``status``
        (and ``message`` if present) are included only for statuses other
        than "progress". When no status is recorded, ``message`` is included
        if present. 404 if the download is unknown.
        """
        try:
            download_id = request.match_info.get("download_id")
            if not download_id:
                return web.json_response(
                    {"success": False, "error": "Download ID is required"}, status=400
                )

            progress_data = self._ws_manager.get_download_progress(download_id)
            if progress_data is None:
                return web.json_response(
                    {"success": False, "error": "Download ID not found"}, status=404
                )

            response_payload = {
                "success": True,
                "progress": progress_data.get("progress", 0),
                "bytes_downloaded": progress_data.get("bytes_downloaded"),
                "total_bytes": progress_data.get("total_bytes"),
                "bytes_per_second": progress_data.get("bytes_per_second", 0.0),
            }

            status = progress_data.get("status")
            if status and status != "progress":
                response_payload["status"] = status
                if "message" in progress_data:
                    response_payload["message"] = progress_data["message"]
            elif status is None and "message" in progress_data:
                response_payload["message"] = progress_data["message"]

            return web.json_response(response_payload)
        except Exception as exc:
            self._logger.error(
                "Error getting download progress: %s", exc, exc_info=True
            )
            return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
|
|
class ModelCivitaiHandler:
    """CivitAI integration endpoints for one model type.

    Proxies metadata lookups through the configured metadata provider and
    annotates CivitAI version listings with local-cache information.
    """

    def __init__(
        self,
        *,
        service,
        settings_service: SettingsManager,
        ws_manager: WebSocketManager,
        logger: logging.Logger,
        metadata_provider_factory: Callable[[], Awaitable],
        validate_model_type: Callable[[str], bool],
        expected_model_types: Callable[[], str],
        find_model_file: Callable[
            [Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]
        ],
        metadata_sync: MetadataSyncService,
        metadata_refresh_use_case: BulkMetadataRefreshUseCase,
        metadata_progress_callback: MetadataRefreshProgressReporter,
    ) -> None:
        # Core collaborators.
        self._service = service
        self._settings = settings_service
        self._ws_manager = ws_manager
        self._logger = logger
        # Provider access and model-type validation hooks.
        self._metadata_provider_factory = metadata_provider_factory
        self._validate_model_type = validate_model_type
        self._expected_model_types = expected_model_types
        self._find_model_file = find_model_file
        # Bulk metadata refresh machinery.
        self._metadata_sync = metadata_sync
        self._metadata_refresh_use_case = metadata_refresh_use_case
        self._metadata_progress_callback = metadata_progress_callback

    async def fetch_all_civitai(self, request: web.Request) -> web.Response:
        """Run the bulk metadata refresh and return its result as JSON.

        Failures return a plain-text 500.
        """
        try:
            outcome = await self._metadata_refresh_use_case.execute_with_error_handling(
                progress_callback=self._metadata_progress_callback
            )
            return web.json_response(outcome)
        except Exception as exc:
            self._logger.error(
                "Error in fetch_all_civitai for %ss: %s", self._service.model_type, exc
            )
            return web.Response(text=str(exc), status=500)

    async def get_civitai_versions(self, request: web.Request) -> web.Response:
        """List CivitAI versions of ``model_id``, annotated with local state.

        Responses:
        - 404 (plain text): the provider raises ``ResourceNotFoundError`` or
          returns no versions.
        - 400: the CivitAI model type fails validation.
        - 500 (plain text): unexpected errors.
        - Otherwise: the version list, annotated by ``_decorate_version``.
        """
        try:
            model_id = request.match_info["model_id"]
            provider = await self._metadata_provider_factory()
            try:
                remote = await provider.get_model_versions(model_id)
            except ResourceNotFoundError:
                return web.Response(status=404, text="Model not found")

            if not (remote and remote.get("modelVersions")):
                return web.Response(status=404, text="Model not found")

            versions = remote.get("modelVersions", [])
            remote_type = remote.get("type", "")
            if not self._validate_model_type(remote_type):
                expected = self._expected_model_types()
                return web.json_response(
                    {
                        "error": f"Model type mismatch. Expected {expected}, got {remote_type}"
                    },
                    status=400,
                )

            cache = await self._service.scanner.get_cached_data()
            index = cache.version_index
            for version in versions:
                self._decorate_version(version, index)
            return web.json_response(versions)
        except Exception as exc:
            self._logger.error(
                "Error fetching %s model versions: %s", self._service.model_type, exc
            )
            return web.Response(status=500, text=str(exc))

    def _decorate_version(self, version, version_index) -> None:
        """Annotate one CivitAI version dict in place with local-cache info.

        Keys written to ``version``:
        - ``existsLocally``: whether the numeric id is in ``version_index``.
        - ``localPath``: set from a Mapping cache entry's ``file_path``, or
          removed when that path is empty.
        - ``modelSizeKB``: set when ``_find_model_file`` picks a file.
        """
        raw_id = version.get("id")
        numeric_id = None
        if raw_id is not None:
            try:
                numeric_id = int(str(raw_id))
            except (TypeError, ValueError):
                numeric_id = None

        local = None
        if numeric_id is not None and version_index:
            local = version_index.get(numeric_id)

        version["existsLocally"] = local is not None
        if local and isinstance(local, Mapping):
            on_disk = local.get("file_path")
            if on_disk:
                version["localPath"] = on_disk
            else:
                version.pop("localPath", None)

        files = version.get("files")
        chosen = self._find_model_file(files) if isinstance(files, Iterable) else None
        if chosen and isinstance(chosen, Mapping):
            version["modelSizeKB"] = chosen.get("sizeKB")

    async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
        """Fetch CivitAI model info for the ``modelVersionId`` path segment.

        On failure: 404 if the provider's error mentions "not found"
        (case-insensitive), otherwise 500.
        """
        try:
            version_id = request.match_info.get("modelVersionId")
            provider = await self._metadata_provider_factory()
            model, error_msg = await provider.get_model_version_info(version_id)
            if model:
                return web.json_response(model)

            self._logger.warning(
                "Failed to fetch model version %s: %s", version_id, error_msg
            )
            not_found = bool(error_msg) and "not found" in error_msg.lower()
            return web.json_response(
                {
                    "success": False,
                    "error": error_msg or "Failed to fetch model information",
                },
                status=404 if not_found else 500,
            )
        except Exception as exc:
            self._logger.error("Error fetching model details: %s", exc)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
        """Fetch CivitAI model info for the ``hash`` path segment.

        Any provider-reported error is returned as a 404.
        """
        try:
            digest = request.match_info.get("hash")
            provider = await self._metadata_provider_factory()
            model, error = await provider.get_model_by_hash(digest)
            if not error:
                return web.json_response(model)

            self._logger.warning("Error getting model by hash: %s", error)
            return web.json_response({"success": False, "error": error}, status=404)
        except Exception as exc:
            self._logger.error("Error fetching model details by hash: %s", exc)
            return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
|
|
class ModelMoveHandler:
    """HTTP endpoints that relocate model files through ``ModelMoveService``."""

    def __init__(
        self, *, move_service: ModelMoveService, logger: logging.Logger
    ) -> None:
        self._logger = logger
        self._move_service = move_service

    async def move_model(self, request: web.Request) -> web.Response:
        """Move one model file.

        The JSON body needs ``file_path`` and ``target_path``.
        ``use_default_paths`` is optional and defaults to False.

        Status codes:
        - 400 (plain text): either required field is missing.
        - 200 / 500: the service result's ``success`` flag.
        - 500 (plain text): exceptions.
        """
        try:
            body = await request.json()
            source = body.get("file_path")
            destination = body.get("target_path")
            use_defaults = body.get("use_default_paths", False)
            if not (source and destination):
                return web.Response(
                    text="File path and target path are required", status=400
                )

            outcome = await self._move_service.move_model(
                source, destination, use_default_paths=use_defaults
            )
            return web.json_response(
                outcome, status=200 if outcome.get("success") else 500
            )
        except Exception as exc:
            self._logger.error("Error moving model: %s", exc, exc_info=True)
            return web.Response(text=str(exc), status=500)

    async def move_models_bulk(self, request: web.Request) -> web.Response:
        """Move several model files to one target.

        The JSON body needs ``file_paths`` (non-empty) and ``target_path``;
        missing either yields a plain-text 400. The service result is returned
        as-is, and exceptions become a plain-text 500.
        """
        try:
            body = await request.json()
            sources = body.get("file_paths", [])
            destination = body.get("target_path")
            use_defaults = body.get("use_default_paths", False)
            if not (sources and destination):
                return web.Response(
                    text="File paths and target path are required", status=400
                )

            outcome = await self._move_service.move_models_bulk(
                sources, destination, use_default_paths=use_defaults
            )
            return web.json_response(outcome)
        except Exception as exc:
            self._logger.error("Error moving models in bulk: %s", exc, exc_info=True)
            return web.Response(text=str(exc), status=500)
|
|
|
|
|
|
class ModelAutoOrganizeHandler:
    """Endpoints that start auto-organize runs and report their progress."""

    def __init__(
        self,
        *,
        use_case: AutoOrganizeUseCase,
        progress_callback: WebSocketProgressCallback,
        ws_manager: WebSocketManager,
        logger: logging.Logger,
    ) -> None:
        self._logger = logger
        self._ws_manager = ws_manager
        self._progress_callback = progress_callback
        self._use_case = use_case

    async def _read_post_options(self, request: web.Request, settings_manager):
        """Best-effort extraction of (file_paths, exclusion_patterns) from a POST body.

        Any error while reading or normalizing is swallowed. Values assigned
        before the error are kept, and the rest stay ``None``.
        """
        file_paths = None
        exclusions = None
        try:
            body = await request.json()
            file_paths = body.get("file_paths")
            if "exclusion_patterns" in body:
                exclusions = settings_manager.normalize_auto_organize_exclusions(
                    body.get("exclusion_patterns")
                )
        except Exception:  # pragma: no cover - permissive path
            pass
        return file_paths, exclusions

    async def auto_organize_models(self, request: web.Request) -> web.Response:
        """Run auto-organize, optionally scoped by a POST JSON body.

        Status codes:
        - 409: a run is already in progress.
        - 500: other failures. These are also pushed to the progress
          callback, best effort.
        """
        try:
            settings_manager = get_settings_manager()
            file_paths, exclusion_patterns = None, None
            if request.method == "POST":
                file_paths, exclusion_patterns = await self._read_post_options(
                    request, settings_manager
                )

            outcome = await self._use_case.execute(
                file_paths=file_paths,
                progress_callback=self._progress_callback,
                exclusion_patterns=exclusion_patterns,
            )
            return web.json_response(outcome.to_dict())
        except AutoOrganizeInProgressError:
            return web.json_response(
                {
                    "success": False,
                    "error": "Auto-organize is already running. Please wait for it to complete.",
                },
                status=409,
            )
        except Exception as exc:
            self._logger.error("Error in auto_organize_models: %s", exc, exc_info=True)
            failure_event = {
                "type": "auto_organize_progress",
                "status": "error",
                "error": str(exc),
            }
            try:
                await self._progress_callback.on_progress(failure_event)
            except Exception:  # pragma: no cover - defensive reporting
                pass
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def get_auto_organize_progress(self, request: web.Request) -> web.Response:
        """Return the current auto-organize progress snapshot.

        Returns 404 when the progress store holds no snapshot.
        """
        try:
            snapshot = self._ws_manager.get_auto_organize_progress()
            if snapshot is not None:
                return web.json_response({"success": True, "progress": snapshot})
            return web.json_response(
                {
                    "success": False,
                    "error": "No auto-organize operation in progress",
                },
                status=404,
            )
        except Exception as exc:
            self._logger.error(
                "Error getting auto-organize progress: %s", exc, exc_info=True
            )
            return web.json_response({"success": False, "error": str(exc)}, status=500)
|
|
|
|
|
|
class ModelUpdateHandler:
|
|
"""Handle update tracking requests."""
|
|
|
|
def __init__(
    self,
    *,
    service,
    update_service,
    metadata_provider_selector,
    settings_service,
    logger: logging.Logger,
) -> None:
    """Store the collaborators used by the update-tracking endpoints.

    Args:
        service: model service; ``model_type`` and ``scanner`` are read from it.
        update_service: object whose ``refresh_for_model_type`` is awaited.
        metadata_provider_selector: stored as-is on the instance.
        settings_service: stored as ``self._settings``.
        logger: destination for handler diagnostics.
    """
    self._logger = logger
    self._settings = settings_service
    self._metadata_provider_selector = metadata_provider_selector
    self._update_service = update_service
    self._service = service
|
|
|
|
async def fetch_missing_civitai_license_data(
    self, request: web.Request
) -> web.Response:
    """Backfill CivitAI license fields for cached models that lack them.

    Steps:
    1. Collect candidate models, optionally limited to the model ids parsed
       from the JSON body.
    2. Fetch license info for their CivitAI model ids.
    3. Merge each resolved license payload into ``civitai.model`` of every
       candidate's metadata and save it.

    Returns:
        ``{"success": True, "updated": [...]}``, plus optional keys:
        - ``missingModelIds``: ids with no usable license payload.
        - ``errors``: per-file save failures.

    Status codes: 503 when no Civitai provider is available, 429 on rate
    limiting, 500 if the license fetch fails unexpectedly.
    """
    payload = await self._read_json(request)
    target_model_ids = self._extract_target_model_ids(payload)

    provider = await self._get_civitai_provider()
    if provider is None:
        return web.json_response(
            {"success": False, "error": "Civitai provider not available"},
            status=503,
        )

    try:
        cache = await self._service.scanner.get_cached_data()
    except Exception as exc:
        self._logger.error(
            "Failed to load cache for license refresh: %s", exc, exc_info=True
        )
        cache = None

    target_set = set(target_model_ids) if target_model_ids is not None else None
    candidates = await self._collect_models_missing_license(cache, target_set)
    if not candidates:
        return web.json_response({"success": True, "updated": []})

    model_ids = sorted(candidates.keys())
    try:
        license_map = await self._fetch_license_info(provider, model_ids)
    except RateLimitError as exc:
        return web.json_response(
            {"success": False, "error": str(exc) or "Rate limited"},
            status=429,
        )
    except Exception as exc:  # pragma: no cover - defensive log
        self._logger.error("Failed to fetch license info: %s", exc, exc_info=True)
        return web.json_response({"success": False, "error": str(exc)}, status=500)

    updated: List[Dict[str, str]] = []
    errors: List[Dict[str, str]] = []
    missing_model_ids = []
    for model_id in model_ids:
        license_payload = license_map.get(model_id)
        if not license_payload:
            # Empty payloads were previously skipped silently and not
            # reported; count them as missing too.
            missing_model_ids.append(model_id)
            continue

        resolved_payload = resolve_license_payload(license_payload)
        for context in candidates.get(model_id, []):
            metadata_path = context["file_path"]
            metadata_payload = context["metadata"]

            # setdefault() returned an existing None/non-dict section as-is
            # and the following .get()/.update() crashed. Keep dicts in place,
            # copy other mappings, and replace anything else.
            civitai_section = metadata_payload.get("civitai")
            if not isinstance(civitai_section, dict):
                civitai_section = (
                    dict(civitai_section) if isinstance(civitai_section, Mapping) else {}
                )
            model_section = civitai_section.get("model")
            if not isinstance(model_section, dict):
                model_section = (
                    dict(model_section) if isinstance(model_section, Mapping) else {}
                )

            model_section.update(resolved_payload)
            civitai_section["model"] = model_section
            metadata_payload["civitai"] = civitai_section

            try:
                await MetadataManager.save_metadata(metadata_path, metadata_payload)
                updated.append({"modelId": model_id, "filePath": metadata_path})
            except Exception as exc:
                self._logger.error(
                    "Failed to save metadata for %s: %s",
                    metadata_path,
                    exc,
                    exc_info=True,
                )
                errors.append({"filePath": metadata_path, "error": str(exc)})

    response_payload = {"success": True, "updated": updated}
    if missing_model_ids:
        response_payload["missingModelIds"] = missing_model_ids
    if errors:
        response_payload["errors"] = errors
    return web.json_response(response_payload)
|
|
|
|
async def refresh_model_updates(self, request: web.Request) -> web.Response:
    """Refresh update-tracking records for this handler's model type.

    ``force`` may be supplied in the query string or the JSON body; either
    one being truthy forces a refresh. An optional ``modelIds`` (or
    ``model_ids``) list in the body restricts the refresh to those model ids;
    when absent, or when no id in it is valid, every model is refreshed.

    Responses:
        * 503 when no CivitAI provider can be acquired.
        * ``{"success": False, "status": "cancelled", ...}`` when the scanner
          reports cancellation after the refresh.
        * 429 on ``RateLimitError``; 500 on any other exception.
        * Otherwise ``{"success": True, "records": [...]}`` containing only
          the records whose ``has_update()`` returns truthy.
    """
    payload = await self._read_json(request)
    if not isinstance(payload, Mapping):
        # A JSON body that is not an object (e.g. a list) carries no options;
        # treat it as empty instead of failing on ``.get`` below.
        payload = {}

    force_refresh = self._parse_bool(
        request.query.get("force")
    ) or self._parse_bool(payload.get("force"))

    # Shared parser: reads modelIds/model_ids, drops invalid ids, dedupes and
    # sorts. None means "no restriction".
    target_model_ids = self._extract_target_model_ids(payload)

    provider = await self._get_civitai_provider()
    if provider is None:
        return web.json_response(
            {"success": False, "error": "Civitai provider not available"},
            status=503,
        )

    try:
        records = await self._update_service.refresh_for_model_type(
            self._service.model_type,
            self._service.scanner,
            provider,
            force_refresh=force_refresh,
            target_model_ids=target_model_ids,
        )
        if self._service.scanner.is_cancelled():
            return web.json_response(
                {
                    "success": False,
                    "status": "cancelled",
                    "message": "Update refresh cancelled",
                }
            )
    except RateLimitError as exc:
        return web.json_response(
            {"success": False, "error": str(exc) or "Rate limited"}, status=429
        )
    except Exception as exc:  # pragma: no cover - defensive logging
        self._logger.error(
            "Failed to refresh model updates: %s", exc, exc_info=True
        )
        return web.json_response({"success": False, "error": str(exc)}, status=500)

    serialized_records = []
    for record in records.values():
        # Records without a callable has_update() are never reported.
        has_update_fn = getattr(record, "has_update", None)
        if callable(has_update_fn) and has_update_fn():
            serialized_records.append(self._serialize_record(record))

    return web.json_response(
        {
            "success": True,
            "records": serialized_records,
        }
    )
|
|
|
|
async def set_model_update_ignore(self, request: web.Request) -> web.Response:
    """Set or clear the model-wide "ignore updates" flag.

    Reads ``modelId`` (required, must coerce to int) and ``shouldIgnore``
    (parsed with ``_parse_bool``) from the JSON body. Responds 400 when the
    id is missing or invalid; otherwise returns the serialized record that
    the update service hands back.
    """
    body = await self._read_json(request)
    target_id = self._normalize_model_id(body.get("modelId"))
    if target_id is None:
        return web.json_response(
            {"success": False, "error": "modelId is required"}, status=400
        )

    ignore_flag = self._parse_bool(body.get("shouldIgnore"))
    updated_record = await self._update_service.set_should_ignore(
        self._service.model_type, target_id, ignore_flag
    )
    response_body = {"success": True, "record": self._serialize_record(updated_record)}
    return web.json_response(response_body)
|
|
|
|
async def set_version_update_ignore(self, request: web.Request) -> web.Response:
    """Set or clear the "ignore updates" flag for one specific version.

    Both ``modelId`` and ``versionId`` must be present in the JSON body and
    coerce to int, otherwise the response is 400. ``shouldIgnore`` is parsed
    with ``_parse_bool``. The returned record is serialized with local
    file/preview context from ``_build_version_context``.
    """
    body = await self._read_json(request)
    target_model_id = self._normalize_model_id(body.get("modelId"))
    target_version_id = self._normalize_model_id(body.get("versionId"))
    if target_model_id is None or target_version_id is None:
        return web.json_response(
            {"success": False, "error": "modelId and versionId are required"},
            status=400,
        )

    ignore_flag = self._parse_bool(body.get("shouldIgnore"))
    updated_record = await self._update_service.set_version_should_ignore(
        self._service.model_type,
        target_model_id,
        target_version_id,
        ignore_flag,
    )
    version_context = await self._build_version_context(updated_record)
    serialized = self._serialize_record(updated_record, version_context=version_context)
    return web.json_response({"success": True, "record": serialized})
|
|
|
|
async def get_model_update_status(self, request: web.Request) -> web.Response:
    """Return the update-tracking record for a single model.

    The ``model_id`` path parameter must coerce to int (400 otherwise).
    Query flags ``refresh`` and ``force`` ask for a CivitAI refresh before
    reading. Responds 429 when rate limited and 404 when the model is not
    tracked.
    """
    target_id = self._normalize_model_id(request.match_info.get("model_id"))
    if target_id is None:
        return web.json_response(
            {"success": False, "error": "model_id must be an integer"}, status=400
        )

    query = request.query
    wants_refresh = self._parse_bool(query.get("refresh"))
    wants_force = self._parse_bool(query.get("force"))

    try:
        tracked = await self._get_or_refresh_record(
            target_id, refresh=wants_refresh, force=wants_force
        )
    except RateLimitError as exc:
        return web.json_response(
            {"success": False, "error": str(exc) or "Rate limited"}, status=429
        )

    if tracked is None:
        return web.json_response(
            {"success": False, "error": "Model not tracked"}, status=404
        )
    return web.json_response({"success": True, "record": self._serialize_record(tracked)})
|
|
|
|
async def get_model_versions(self, request: web.Request) -> web.Response:
    """Return a tracked model's record including per-version local context.

    The ``model_id`` path parameter must coerce to int (400 otherwise).
    Query flags ``refresh``/``force`` request a CivitAI refresh first.
    Early-access versions lacking an exact end time are enriched before
    serialization, and each in-library version is annotated via
    ``_build_version_context``. Responds 429 when rate limited and 404 when
    the model is not tracked.
    """
    target_id = self._normalize_model_id(request.match_info.get("model_id"))
    if target_id is None:
        return web.json_response(
            {"success": False, "error": "model_id must be an integer"}, status=400
        )

    query = request.query
    wants_refresh = self._parse_bool(query.get("refresh"))
    wants_force = self._parse_bool(query.get("force"))

    try:
        tracked = await self._get_or_refresh_record(
            target_id, refresh=wants_refresh, force=wants_force
        )
    except RateLimitError as exc:
        return web.json_response(
            {"success": False, "error": str(exc) or "Rate limited"}, status=429
        )

    if tracked is None:
        return web.json_response(
            {"success": False, "error": "Model not tracked"}, status=404
        )

    # Fill in exact early-access end times where only the EA flag is known.
    enriched = await self._enrich_early_access_details(tracked)
    version_context = await self._build_version_context(enriched)
    serialized = self._serialize_record(enriched, version_context=version_context)
    return web.json_response({"success": True, "record": serialized})
|
|
|
|
async def _get_or_refresh_record(
    self, model_id: int, *, refresh: bool, force: bool
) -> Optional[object]:
    """Return the stored update record, refreshing it from CivitAI if asked.

    A truthy stored record is returned untouched unless ``refresh`` or
    ``force`` is set. When no CivitAI provider can be obtained, the stored
    record (which may be None) is returned. Otherwise the result of
    ``refresh_single_model`` is returned, with ``force_refresh`` set to
    ``force or refresh``.
    """
    model_type = self._service.model_type
    cached = await self._update_service.get_record(model_type, model_id)
    if cached and not (refresh or force):
        return cached

    provider = await self._get_civitai_provider()
    if provider is None:
        return cached

    return await self._update_service.refresh_single_model(
        self._service.model_type,
        model_id,
        self._service.scanner,
        provider,
        force_refresh=force or refresh,
    )
|
|
|
|
async def _get_civitai_provider(self):
    """Resolve the ``civitai_api`` metadata provider.

    Any exception raised by the selector is logged at error level (with
    traceback) and converted into a ``None`` result.
    """
    provider = None
    try:
        provider = await self._metadata_provider_selector("civitai_api")
    except Exception as exc:  # pragma: no cover - defensive log
        self._logger.error(
            "Failed to acquire civitai provider: %s", exc, exc_info=True
        )
    return provider
|
|
|
|
async def _enrich_early_access_details(self, record):
    """Fill in exact early-access end times for versions that lack them.

    Versions with ``is_early_access`` set but no ``early_access_ends_at`` are
    looked up one by one with the CivitAI provider's
    ``get_model_version_info``. When a lookup returns a response with an
    ``earlyAccessEndsAt`` value, that version is replaced (via
    ``dataclasses.replace``) with the value filled in.

    The input record is not mutated. A new record (same version order and
    length) is returned when at least one version was enriched; otherwise
    the original object is returned. Lookup failures are logged at debug
    level and skipped. Nothing is persisted here.
    """
    from dataclasses import replace

    if not record or not record.versions:
        return record

    pending = [
        version
        for version in record.versions
        if version.is_early_access and not version.early_access_ends_at
    ]
    if not pending:
        return record

    provider = await self._get_civitai_provider()
    if not provider:
        return record

    enriched = {}
    for version in pending:
        try:
            version_info, error = await provider.get_model_version_info(
                str(version.version_id)
            )
            if not version_info or error:
                continue
            ea_ends_at = version_info.get("earlyAccessEndsAt")
            if not ea_ends_at:
                continue
            enriched[version.version_id] = replace(
                version, early_access_ends_at=ea_ends_at
            )
        except Exception as exc:
            self._logger.debug(
                "Failed to fetch EA details for version %s: %s",
                version.version_id,
                exc,
            )
            continue
        self._logger.debug(
            "Enriched EA info for version %s: %s",
            version.version_id,
            ea_ends_at,
        )

    if not enriched:
        return record

    # Swap enriched copies in place so order (and any duplicates) survive.
    merged_versions = [
        enriched.get(version.version_id, version) for version in record.versions
    ]
    # Note: not persisted to avoid side effects; the next bulk refresh will
    # pick the data up again if it is still needed.
    return replace(record, versions=merged_versions)
|
|
|
|
async def _collect_models_missing_license(
    self,
    cache,
    target_model_ids: Optional[set[int]],
) -> Dict[int, List[Dict[str, Any]]]:
    """Find cached model files whose CivitAI model metadata lacks license fields.

    Walks ``cache.raw_data`` and handles each distinct ``file_path`` once.
    An entry is kept only if its cached ``civitai.modelId`` normalizes to an
    int and, when ``target_model_ids`` is not None, is one of those ids. The
    on-disk metadata is then loaded with ``MetadataManager.load_metadata``.
    The file is collected when its ``civitai.model`` section is missing any
    key from ``LICENSE_FIELDS``.

    Returns:
        Mapping of model id -> list of ``{"file_path", "metadata"}`` dicts.
        Each ``metadata`` is a plain dict whose ``civitai`` and
        ``civitai.model`` sections are fresh dicts, so the caller can
        update them in place before saving.
    """
    entries: Dict[int, List[Dict[str, Any]]] = {}
    if cache is None:
        return entries

    raw_data = getattr(cache, "raw_data", None) or []
    seen_paths: set[str] = set()
    # Build the filter once; a list may be passed, and set lookups are O(1).
    target_set = None if target_model_ids is None else set(target_model_ids)

    for item in raw_data:
        if not isinstance(item, Mapping):
            continue
        file_path = item.get("file_path")
        if not isinstance(file_path, str) or not file_path or file_path in seen_paths:
            continue
        seen_paths.add(file_path)

        civitai_entry = item.get("civitai")
        if not isinstance(civitai_entry, Mapping):
            continue

        model_id = self._normalize_model_id(civitai_entry.get("modelId"))
        if model_id is None:
            continue
        if target_set is not None and model_id not in target_set:
            continue

        try:
            metadata_obj, should_skip = await MetadataManager.load_metadata(
                file_path
            )
        except Exception as exc:
            self._logger.debug("Failed to load metadata for %s: %s", file_path, exc)
            continue
        if metadata_obj is None or should_skip:
            continue

        metadata_payload = self._convert_metadata_to_dict(metadata_obj)
        if not isinstance(metadata_payload, dict) or not metadata_payload:
            # Nothing usable was recovered. The caller saves this payload
            # back to disk, so rebuilding it from an empty dict could discard
            # the file's existing metadata.
            self._logger.debug(
                "Skipping %s: metadata could not be converted to a dict",
                file_path,
            )
            continue

        civitai_payload = metadata_payload.get("civitai")
        civitai_payload = (
            dict(civitai_payload) if isinstance(civitai_payload, Mapping) else {}
        )
        model_payload = civitai_payload.get("model")
        model_payload = (
            dict(model_payload) if isinstance(model_payload, Mapping) else {}
        )

        if all(key in model_payload for key in LICENSE_FIELDS):
            continue

        civitai_payload["model"] = model_payload
        metadata_payload["civitai"] = civitai_payload
        entries.setdefault(model_id, []).append(
            {"file_path": file_path, "metadata": metadata_payload}
        )

    return entries
|
|
|
|
async def _fetch_license_info(
    self,
    provider,
    model_ids: List[int],
    *,
    batch_size: int = 100,
) -> Dict[int, Dict[str, Any]]:
    """Fetch license fields for ``model_ids`` through the provider's bulk API.

    Ids are sent to ``provider.get_model_versions_bulk`` in chunks of at most
    ``batch_size`` (default 100). Non-mapping responses are ignored. Entries
    whose key does not normalize to an int id, or whose payload is not a
    mapping, are also ignored. For every remaining model, each key in
    ``LICENSE_FIELDS`` is copied; absent keys become None.

    Args:
        provider: object exposing an async ``get_model_versions_bulk(ids)``.
        model_ids: ids to look up (any iterable is accepted).
        batch_size: maximum number of ids per bulk request.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    ids = list(model_ids or ())
    if not ids:
        return {}
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    aggregated: Dict[int, Dict[str, Any]] = {}
    for start in range(0, len(ids), batch_size):
        chunk = ids[start : start + batch_size]
        response = await provider.get_model_versions_bulk(chunk)
        if not isinstance(response, Mapping):
            continue

        for raw_id, payload in response.items():
            normalized_id = self._normalize_model_id(raw_id)
            if normalized_id is None or not isinstance(payload, Mapping):
                continue
            aggregated[normalized_id] = {
                field: payload.get(field) for field in LICENSE_FIELDS
            }

    return aggregated
|
|
|
|
def _extract_target_model_ids(self, payload: Dict) -> Optional[List[int]]:
    """Pull a sorted, de-duplicated list of model ids out of a request body.

    Looks for ``modelIds`` first and falls back to ``model_ids``. The value
    must be a list, tuple or set; entries that do not normalize to an int are
    dropped. Returns None when the payload is not a mapping, the key is
    absent or of the wrong type, or no valid id remains.
    """
    if not isinstance(payload, Mapping):
        return None

    candidates = payload.get("modelIds")
    if candidates is None:
        candidates = payload.get("model_ids")
    if not isinstance(candidates, (list, tuple, set)):
        return None

    unique_ids = {
        model_id
        for model_id in map(self._normalize_model_id, candidates)
        if model_id is not None
    }
    return sorted(unique_ids) if unique_ids else None
|
|
|
|
@staticmethod
def _convert_metadata_to_dict(metadata: Any) -> Dict[str, Any]:
    """Coerce a metadata object into a plain ``dict``.

    Objects with a callable ``to_dict()`` are converted with it. If that call
    raises, or returns something that is not a mapping, the function falls
    back to treating ``metadata`` itself as a mapping. ``None`` and
    unconvertible values yield an empty dict, so the result always supports
    dict operations such as ``.get``.
    """
    if metadata is None:
        return {}

    to_dict = getattr(metadata, "to_dict", None)
    if callable(to_dict):
        try:
            converted = to_dict()
        except Exception:
            # Best effort: fall through to the Mapping handling below.
            converted = None
        if isinstance(converted, dict):
            return converted
        if isinstance(converted, Mapping):
            return dict(converted)

    if isinstance(metadata, Mapping):
        return dict(metadata)

    return {}
|
|
|
|
async def _read_json(self, request: web.Request) -> Dict:
    """Parse the request body as a JSON object.

    Returns an empty dict when the request has no readable body, when the
    body fails to parse, or when it decodes to something other than a JSON
    object (e.g. a list or string). Handlers call ``.get`` on the result, so
    a non-object body would otherwise raise ``AttributeError``.
    """
    if not request.can_read_body:
        return {}
    try:
        payload = await request.json()
    except Exception:
        return {}
    return payload if isinstance(payload, dict) else {}
|
|
|
|
@staticmethod
def _parse_bool(value) -> bool:
    """Interpret a query-string or JSON value as a boolean flag.

    Real booleans pass through unchanged. Strings are true only for "1",
    "true" or "yes" (case-insensitive). Ints and floats use ordinary
    truthiness. Anything else, including None, is False.
    """
    if isinstance(value, bool):
        return value
    truthy_strings = {"1", "true", "yes"}
    if isinstance(value, str):
        return value.lower() in truthy_strings
    is_number = isinstance(value, (int, float))
    return is_number and bool(value)
|
|
|
|
@staticmethod
def _normalize_model_id(value) -> Optional[int]:
    """Coerce ``value`` to an int id, or return None if that is impossible.

    None, and values ``int()`` rejects with TypeError/ValueError, map to None.
    """
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError):
        return None
|
|
|
|
def _serialize_record(
    self,
    record,
    *,
    version_context: Optional[Dict[int, Dict[str, Optional[str]]]] = None,
) -> Dict:
    """Convert an update record into the camelCase JSON shape used by the API.

    ``version_context`` maps version ids to per-version context passed on to
    ``_serialize_version``. ``hasUpdate`` honours the
    ``hide_early_access_updates`` setting. A missing settings object, or one
    whose ``get`` raises, counts as "do not hide".
    """
    per_version = version_context or {}

    hide_ea = False
    settings = self._settings
    if settings is not None:
        try:
            hide_ea = bool(settings.get("hide_early_access_updates", False))
        except Exception:
            hide_ea = False

    payload = {
        "modelType": record.model_type,
        "modelId": record.model_id,
        "largestVersionId": record.largest_version_id,
        "versionIds": record.version_ids,
        "inLibraryVersionIds": record.in_library_version_ids,
        "lastCheckedAt": record.last_checked_at,
        "shouldIgnore": record.should_ignore_model,
        "hasUpdate": record.has_update(hide_early_access=hide_ea),
    }
    payload["versions"] = [
        self._serialize_version(version, per_version.get(version.version_id))
        for version in record.versions
    ]
    return payload
|
|
|
|
@staticmethod
def _serialize_version(
    version, context: Optional[Dict[str, Optional[str]]]
) -> Dict:
    """Serialize one version entry into the camelCase JSON shape.

    ``context`` (as built by ``_build_version_context``) may provide a local
    ``preview_override`` plus ``file_path``/``file_name``. A None override
    keeps the version's own ``preview_url``.

    ``isEarlyAccess`` is determined in two phases. When an
    ``early_access_ends_at`` timestamp is present, it is compared against the
    current UTC time; a timestamp without an offset is treated as UTC, and
    an unparseable value counts as still in early access. Otherwise the
    version's basic ``is_early_access`` flag (if present) is used.
    """
    from datetime import datetime, timezone

    context = context or {}
    preview_override = context.get("preview_override")
    preview_url = (
        preview_override if preview_override is not None else version.preview_url
    )

    is_early_access = False
    if version.early_access_ends_at:
        try:
            ea_date = datetime.fromisoformat(
                version.early_access_ends_at.replace("Z", "+00:00")
            )
            if ea_date.tzinfo is None:
                # Comparing a naive datetime with an aware one raises
                # TypeError; interpret offset-less timestamps as UTC.
                ea_date = ea_date.replace(tzinfo=timezone.utc)
            is_early_access = ea_date > datetime.now(timezone.utc)
        except (ValueError, AttributeError, TypeError):
            # If the timestamp cannot be interpreted, treat EA as active
            # (conservative).
            is_early_access = True
    elif getattr(version, "is_early_access", False):
        # Fallback to the basic EA flag from the bulk API.
        is_early_access = True

    return {
        "versionId": version.version_id,
        "name": version.name,
        "baseModel": version.base_model,
        "releasedAt": version.released_at,
        "sizeBytes": version.size_bytes,
        "previewUrl": preview_url,
        "isInLibrary": version.is_in_library,
        "shouldIgnore": version.should_ignore,
        "earlyAccessEndsAt": version.early_access_ends_at,
        "isEarlyAccess": is_early_access,
        "filePath": context.get("file_path"),
        "fileName": context.get("file_name"),
    }
|
|
|
|
async def _build_version_context(
    self, record
) -> Dict[int, Dict[str, Optional[str]]]:
    """Map in-library version ids to local file info and preview overrides.

    Entries come from ``cache.version_index`` of the scanner's cached data.
    Versions not marked ``is_in_library``, or whose index entry is not a
    mapping, are omitted. A non-empty string ``preview_url`` is converted
    with ``config.get_preview_static_url``; otherwise ``preview_override``
    is None. If loading the cache raises, the error is logged at debug level
    and an empty dict is returned.
    """
    context: Dict[int, Dict[str, Optional[str]]] = {}
    try:
        cache = await self._service.scanner.get_cached_data()
    except Exception as exc:  # pragma: no cover - defensive logging
        self._logger.debug(
            "Failed to load cache while building preview overrides: %s", exc
        )
        return context

    version_index = getattr(cache, "version_index", None)
    if not version_index:
        return context

    for version in record.versions:
        if not version.is_in_library:
            continue
        cache_entry = version_index.get(version.version_id)
        if not isinstance(cache_entry, Mapping):
            continue
        preview = cache_entry.get("preview_url")
        has_preview = isinstance(preview, str) and bool(preview)
        context[version.version_id] = {
            "file_path": cache_entry.get("file_path"),
            "file_name": cache_entry.get("file_name"),
            "preview_override": (
                config.get_preview_static_url(preview) if has_preview else None
            ),
        }
    return context
|
|
|
|
|
|
@dataclass
class ModelHandlerSet:
    """Bundle the per-concern model handlers and expose one flat route table."""

    page_view: ModelPageView
    listing: ModelListingHandler
    management: ModelManagementHandler
    query: ModelQueryHandler
    download: ModelDownloadHandler
    civitai: ModelCivitaiHandler
    move: ModelMoveHandler
    auto_organize: ModelAutoOrganizeHandler
    updates: ModelUpdateHandler

    def to_route_mapping(
        self,
    ) -> Dict[str, Callable[[web.Request], Awaitable[web.Response]]]:
        """Map route handler names to the bound coroutines that serve them.

        Every key equals the method name on its owning handler, except
        ``handle_models_page``, which maps to ``page_view.handle``.
        """
        page = self.page_view
        listing = self.listing
        manage = self.management
        query = self.query
        download = self.download
        civitai = self.civitai
        mover = self.move
        organizer = self.auto_organize
        updates = self.updates

        return {
            "handle_models_page": page.handle,
            "get_models": listing.get_models,
            "delete_model": manage.delete_model,
            "exclude_model": manage.exclude_model,
            "fetch_civitai": manage.fetch_civitai,
            "fetch_all_civitai": civitai.fetch_all_civitai,
            "relink_civitai": manage.relink_civitai,
            "replace_preview": manage.replace_preview,
            "set_preview_from_url": manage.set_preview_from_url,
            "save_metadata": manage.save_metadata,
            "add_tags": manage.add_tags,
            "rename_model": manage.rename_model,
            "bulk_delete_models": manage.bulk_delete_models,
            "verify_duplicates": manage.verify_duplicates,
            "get_top_tags": query.get_top_tags,
            "get_base_models": query.get_base_models,
            "get_model_types": query.get_model_types,
            "scan_models": query.scan_models,
            "get_model_roots": query.get_model_roots,
            "get_folders": query.get_folders,
            "get_folder_tree": query.get_folder_tree,
            "get_unified_folder_tree": query.get_unified_folder_tree,
            "find_duplicate_models": query.find_duplicate_models,
            "find_filename_conflicts": query.find_filename_conflicts,
            "download_model": download.download_model,
            "download_model_get": download.download_model_get,
            "cancel_download_get": download.cancel_download_get,
            "pause_download_get": download.pause_download_get,
            "resume_download_get": download.resume_download_get,
            "get_download_progress": download.get_download_progress,
            "get_civitai_versions": civitai.get_civitai_versions,
            "get_civitai_model_by_version": civitai.get_civitai_model_by_version,
            "get_civitai_model_by_hash": civitai.get_civitai_model_by_hash,
            "move_model": mover.move_model,
            "move_models_bulk": mover.move_models_bulk,
            "auto_organize_models": organizer.auto_organize_models,
            "get_auto_organize_progress": organizer.get_auto_organize_progress,
            "get_model_notes": query.get_model_notes,
            "get_model_preview_url": query.get_model_preview_url,
            "get_model_civitai_url": query.get_model_civitai_url,
            "get_model_metadata": query.get_model_metadata,
            "get_model_description": query.get_model_description,
            "get_relative_paths": query.get_relative_paths,
            "refresh_model_updates": updates.refresh_model_updates,
            "fetch_missing_civitai_license_data": updates.fetch_missing_civitai_license_data,
            "set_model_update_ignore": updates.set_model_update_ignore,
            "set_version_update_ignore": updates.set_version_update_ignore,
            "get_model_update_status": updates.get_model_update_status,
            "get_model_versions": updates.get_model_versions,
            "cancel_task": query.cancel_task,
        }
|