ComfyUI-Lora-Manager/py/routes/handlers/model_handlers.py
Will Miao 76c15105e6 feat(lora-pool): add regex include/exclude name pattern filtering (#839)
Add name pattern filtering to the LoRA Pool node, allowing users to filter
LoRAs by file name or model name using either plain-text or regex patterns
(a sketch of the matching behavior follows the feature list).

Features:
- Include patterns: only show LoRAs matching at least one pattern
- Exclude patterns: exclude LoRAs matching any pattern
- Regex toggle: switch between substring and regex matching
- Case-insensitive matching for both modes
- Invalid regex automatically falls back to substring matching
- Filters apply to both file_name and model_name fields
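
A sketch of the matching behavior described above (hypothetical helper;
the actual implementation lives in the LoRA Pool filter code):

    import re

    def matches_pattern(value: str, pattern: str, use_regex: bool) -> bool:
        """Case-insensitive match; an invalid regex falls back to substring."""
        if use_regex:
            try:
                return re.search(pattern, value, re.IGNORECASE) is not None
            except re.error:
                pass  # invalid pattern: degrade to plain substring matching
        return pattern.lower() in value.lower()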

Backend:
- Update LoraPoolLM._default_config() with namePatterns structure
- Add name pattern filtering to _apply_pool_filters() and _apply_specific_filters()
- Add API parameter parsing for name_pattern_include/exclude/use_regex
- Update LoraPoolConfig type with namePatterns field

Frontend:
- Add NamePatternsSection.vue component with pattern input UI
- Update useLoraPoolState to manage pattern state and API integration
- Update LoraPoolSummaryView to display NamePatternsSection
- Increase LORA_POOL_WIDGET_MIN_HEIGHT to accommodate new UI

Tests:
- Add 7 test cases covering text/regex include, exclude, combined
  filtering, model name fallback, and invalid regex handling

Closes #839
2026-03-19 17:15:05 +08:00

"""Handlers for base model routes."""
from __future__ import annotations
import asyncio
import json
import logging
import os
import re
import time
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional
from aiohttp import web
import jinja2
from ...config import config
from ...services.download_coordinator import DownloadCoordinator
from ...services.metadata_sync_service import MetadataSyncService
from ...services.model_file_service import ModelMoveService
from ...services.preview_asset_service import PreviewAssetService
from ...services.settings_manager import SettingsManager, get_settings_manager
from ...services.tag_update_service import TagUpdateService
from ...services.use_cases import (
AutoOrganizeInProgressError,
AutoOrganizeUseCase,
BulkMetadataRefreshUseCase,
DownloadModelEarlyAccessError,
DownloadModelUseCase,
DownloadModelValidationError,
MetadataRefreshProgressReporter,
)
from ...services.websocket_manager import WebSocketManager
from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...services.errors import RateLimitError, ResourceNotFoundError
from ...utils.civitai_utils import resolve_license_payload
from ...utils.file_utils import calculate_sha256
from ...utils.metadata_manager import MetadataManager
LICENSE_FIELDS = (
"allowNoCredit",
"allowCommercialUse",
"allowDerivatives",
"allowDifferentLicense",
)
class ModelPageView:
"""Render the HTML view for model listings."""
def __init__(
self,
*,
template_env: jinja2.Environment,
template_name: str,
service,
settings_service: SettingsManager,
server_i18n,
logger: logging.Logger,
) -> None:
self._template_env = template_env
self._template_name = template_name
self._service = service
self._settings = settings_service
self._server_i18n = server_i18n
self._logger = logger
self._app_version = self._get_app_version()
def _load_supporters(self) -> dict:
"""Load supporters data from JSON file."""
try:
current_file = os.path.abspath(__file__)
root_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(current_file)))
)
supporters_path = os.path.join(root_dir, "data", "supporters.json")
if os.path.exists(supporters_path):
with open(supporters_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
self._logger.debug(f"Failed to load supporters data: {e}")
return {"specialThanks": [], "allSupporters": [], "totalCount": 0}
def _get_app_version(self) -> str:
version = "1.0.0"
short_hash = "stable"
try:
import toml
current_file = os.path.abspath(__file__)
# Navigate up from py/routes/handlers/model_handlers.py to project root
root_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(current_file)))
)
pyproject_path = os.path.join(root_dir, "pyproject.toml")
if os.path.exists(pyproject_path):
with open(pyproject_path, "r", encoding="utf-8") as f:
data = toml.load(f)
version = (
data.get("project", {}).get("version", "1.0.0").replace("v", "")
)
# Try to get git info for granular cache busting
git_dir = os.path.join(root_dir, ".git")
if os.path.exists(git_dir):
try:
import git
repo = git.Repo(root_dir)
short_hash = repo.head.commit.hexsha[:7]
except Exception:
# Fallback if git is not available or not a repo
pass
except Exception as e:
self._logger.debug(f"Failed to read version info for cache busting: {e}")
return f"{version}-{short_hash}"
async def handle(self, request: web.Request) -> web.Response:
try:
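            # The scanner counts as initializing when its cache is missing or
            # when either the is_initializing() probe or the _is_initializing
            # flag reports an in-progress scan.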
is_initializing = (
self._service.scanner._cache is None
or (
hasattr(self._service.scanner, "is_initializing")
and callable(self._service.scanner.is_initializing)
and self._service.scanner.is_initializing()
)
or (
hasattr(self._service.scanner, "_is_initializing")
and self._service.scanner._is_initializing
)
)
if not self._template_env or not self._template_name:
return web.Response(
text="Template environment or template name not set",
status=500,
)
user_language = self._settings.get("language", "en")
self._server_i18n.set_locale(user_language)
if not hasattr(self._template_env, "_i18n_filter_added"):
self._template_env.filters["t"] = (
self._server_i18n.create_template_filter()
)
self._template_env._i18n_filter_added = True # type: ignore[attr-defined]
template_context = {
"is_initializing": is_initializing,
"settings": self._settings,
"request": request,
"folders": [],
"t": self._server_i18n.get_translation,
"version": self._app_version,
}
if not is_initializing:
try:
cache = await self._service.scanner.get_cached_data(
force_refresh=False
)
template_context["folders"] = getattr(cache, "folders", [])
except Exception as cache_error: # pragma: no cover - logging path
self._logger.error("Error loading cache data: %s", cache_error)
template_context["is_initializing"] = True
rendered = self._template_env.get_template(self._template_name).render(
**template_context
)
return web.Response(text=rendered, content_type="text/html")
except Exception as exc: # pragma: no cover - logging path
self._logger.error("Error handling models page: %s", exc, exc_info=True)
return web.Response(text="Error loading models page", status=500)
class ModelListingHandler:
"""Provide paginated model listings."""
def __init__(
self,
*,
service,
parse_specific_params: Callable[[web.Request], Dict],
logger: logging.Logger,
) -> None:
self._service = service
self._parse_specific_params = parse_specific_params
self._logger = logger
async def get_models(self, request: web.Request) -> web.Response:
start_time = time.perf_counter()
try:
params = self._parse_common_params(request)
result = await self._service.get_paginated_data(**params)
format_start = time.perf_counter()
formatted_result = {
"items": [
await self._service.format_response(item)
for item in result["items"]
],
"total": result["total"],
"page": result["page"],
"page_size": result["page_size"],
"total_pages": result["total_pages"],
}
format_duration = time.perf_counter() - format_start
duration = time.perf_counter() - start_time
self._logger.debug(
"Request for %s/list took %.3fs (formatting: %.3fs)",
self._service.model_type,
duration,
format_duration,
)
return web.json_response(formatted_result)
except Exception as exc:
self._logger.error(
"Error retrieving %ss: %s", self._service.model_type, exc, exc_info=True
)
return web.json_response({"error": str(exc)}, status=500)
def _parse_common_params(self, request: web.Request) -> Dict:
page = int(request.query.get("page", "1"))
page_size = min(int(request.query.get("page_size", "20")), 100)
sort_by = request.query.get("sort_by", "name")
folder = request.query.get("folder")
folder_include = list(request.query.getall("folder_include", []))
search = request.query.get("search")
fuzzy_search = request.query.get("fuzzy_search", "false").lower() == "true"
base_models = request.query.getall("base_model", [])
folder_exclude = list(request.query.getall("folder_exclude", []))
model_types = list(request.query.getall("model_type", []))
model_types.extend(request.query.getall("civitai_model_type", []))
        # Support legacy ?tag=foo (or ?tags=a,b) plus the newer
        # ?tag_include=foo and ?tag_exclude=bar parameters
legacy_tags = request.query.getall("tag", [])
if not legacy_tags:
legacy_csv = request.query.get("tags")
if legacy_csv:
legacy_tags = [
tag.strip() for tag in legacy_csv.split(",") if tag.strip()
]
include_tags = request.query.getall("tag_include", [])
exclude_tags = request.query.getall("tag_exclude", [])
tag_filters: Dict[str, str] = {}
for tag in legacy_tags:
if tag:
tag_filters[tag] = "include"
for tag in include_tags:
if tag:
tag_filters[tag] = "include"
for tag in exclude_tags:
if tag:
tag_filters[tag] = "exclude"
favorites_only = request.query.get("favorites_only", "false").lower() == "true"
search_options = {
"filename": request.query.get("search_filename", "true").lower() == "true",
"modelname": request.query.get("search_modelname", "true").lower()
== "true",
"tags": request.query.get("search_tags", "false").lower() == "true",
"creator": request.query.get("search_creator", "false").lower() == "true",
"recursive": request.query.get("recursive", "true").lower() == "true",
}
hash_filters: Dict[str, object] = {}
if "hash" in request.query:
hash_filters["single_hash"] = request.query["hash"]
elif "hashes" in request.query:
try:
hash_list = json.loads(request.query["hashes"])
if isinstance(hash_list, list):
hash_filters["multiple_hashes"] = hash_list
except (json.JSONDecodeError, TypeError):
pass
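        # Example: ?hash=abc123 sets {"single_hash": "abc123"}, while
        # ?hashes=["abc123","def456"] (URL-encoded JSON array) sets
        # {"multiple_hashes": [...]}; malformed JSON is silently ignored.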
update_available_only = (
request.query.get("update_available_only", "false").lower() == "true"
)
# Tag logic: "any" (OR) or "all" (AND) for include tags
tag_logic = request.query.get("tag_logic", "any").lower()
if tag_logic not in ("any", "all"):
tag_logic = "any"
# New license-based query filters
credit_required = request.query.get("credit_required")
if credit_required is not None:
credit_required = credit_required.lower() not in ("false", "0", "")
else:
credit_required = None # None means no filter applied
allow_selling_generated_content = request.query.get(
"allow_selling_generated_content"
)
if allow_selling_generated_content is not None:
allow_selling_generated_content = (
allow_selling_generated_content.lower() not in ("false", "0", "")
)
else:
allow_selling_generated_content = None # None means no filter applied
# Name pattern filters for LoRA Pool
name_pattern_include = request.query.getall("name_pattern_include", [])
name_pattern_exclude = request.query.getall("name_pattern_exclude", [])
name_pattern_use_regex = (
request.query.get("name_pattern_use_regex", "false").lower() == "true"
)
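        # Example: ?name_pattern_include=^sdxl_&name_pattern_use_regex=true
        # forwards the raw patterns unchanged; the regex/substring matching
        # itself is applied downstream by the LoRA Pool filters.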
return {
"page": page,
"page_size": page_size,
"sort_by": sort_by,
"folder": folder,
"folder_include": folder_include,
"folder_exclude": folder_exclude,
"search": search,
"fuzzy_search": fuzzy_search,
"base_models": base_models,
"tags": tag_filters,
"tag_logic": tag_logic,
"search_options": search_options,
"hash_filters": hash_filters,
"favorites_only": favorites_only,
"update_available_only": update_available_only,
"credit_required": credit_required,
"allow_selling_generated_content": allow_selling_generated_content,
"model_types": model_types,
"name_pattern_include": name_pattern_include,
"name_pattern_exclude": name_pattern_exclude,
"name_pattern_use_regex": name_pattern_use_regex,
**self._parse_specific_params(request),
}
class ModelManagementHandler:
"""Handle mutation operations on models."""
def __init__(
self,
*,
service,
logger: logging.Logger,
metadata_sync: MetadataSyncService,
preview_service: PreviewAssetService,
tag_update_service: TagUpdateService,
lifecycle_service,
) -> None:
self._service = service
self._logger = logger
self._metadata_sync = metadata_sync
self._preview_service = preview_service
self._tag_update_service = tag_update_service
self._lifecycle_service = lifecycle_service
async def delete_model(self, request: web.Request) -> web.Response:
try:
data = await request.json()
file_path = data.get("file_path")
if not file_path:
return web.Response(text="Model path is required", status=400)
result = await self._lifecycle_service.delete_model(file_path)
return web.json_response(result)
except ValueError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except Exception as exc:
self._logger.error("Error deleting model: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
async def exclude_model(self, request: web.Request) -> web.Response:
try:
data = await request.json()
file_path = data.get("file_path")
if not file_path:
return web.Response(text="Model path is required", status=400)
result = await self._lifecycle_service.exclude_model(file_path)
return web.json_response(result)
except ValueError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except Exception as exc:
self._logger.error("Error excluding model: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
async def fetch_civitai(self, request: web.Request) -> web.Response:
try:
data = await request.json()
file_path = data.get("file_path")
if not file_path:
return web.json_response(
{"success": False, "error": "File path is required"}, status=400
)
cache = await self._service.scanner.get_cached_data()
model_data = next(
(item for item in cache.raw_data if item["file_path"] == file_path),
None,
)
if not model_data:
return web.json_response(
{"success": False, "error": "Model not found in cache"}, status=404
)
# Check if hash needs to be calculated (lazy hash for checkpoints)
sha256 = model_data.get("sha256")
hash_status = model_data.get("hash_status", "completed")
if not sha256 or hash_status != "completed":
# For checkpoints, calculate hash on-demand
scanner = self._service.scanner
if hasattr(scanner, "calculate_hash_for_model"):
self._logger.info(
f"Lazy hash calculation triggered for {file_path}"
)
sha256 = await scanner.calculate_hash_for_model(file_path)
if not sha256:
return web.json_response(
{
"success": False,
"error": "Failed to calculate SHA256 hash",
},
status=500,
)
# Update model_data with new hash
model_data["sha256"] = sha256
model_data["hash_status"] = "completed"
else:
return web.json_response(
{"success": False, "error": "No SHA256 hash found"}, status=400
)
await MetadataManager.hydrate_model_data(model_data)
success, error = await self._metadata_sync.fetch_and_update_model(
sha256=model_data["sha256"],
file_path=file_path,
model_data=model_data,
update_cache_func=self._service.scanner.update_single_model_cache,
)
if not success:
return web.json_response({"success": False, "error": error})
formatted_metadata = await self._service.format_response(model_data)
return web.json_response({"success": True, "metadata": formatted_metadata})
except Exception as exc:
self._logger.error("Error fetching from CivitAI: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def relink_civitai(self, request: web.Request) -> web.Response:
try:
data = await request.json()
file_path = data.get("file_path")
model_id = data.get("model_id")
model_version_id = data.get("model_version_id")
if not file_path or model_id is None:
return web.json_response(
{
"success": False,
"error": "Both file_path and model_id are required",
},
status=400,
)
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
local_metadata = await self._metadata_sync.load_local_metadata(
metadata_path
)
updated_metadata = await self._metadata_sync.relink_metadata(
file_path=file_path,
metadata=local_metadata,
model_id=int(model_id),
model_version_id=int(model_version_id) if model_version_id else None,
)
await self._service.scanner.update_single_model_cache(
file_path, file_path, updated_metadata
)
message = f"Model successfully re-linked to Civitai model {model_id}" + (
f" version {model_version_id}" if model_version_id else ""
)
return web.json_response(
{
"success": True,
"message": message,
"hash": updated_metadata.get("sha256", ""),
}
)
except Exception as exc:
self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def replace_preview(self, request: web.Request) -> web.Response:
try:
reader = await request.multipart()
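            # Multipart fields are read positionally: 'preview_file' first,
            # then 'model_path', then an optional 'nsfw_level'.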
field = await reader.next()
if field is None or field.name != "preview_file":
raise ValueError("Expected 'preview_file' field")
content_type = field.headers.get("Content-Type", "image/png")
content_disposition = field.headers.get("Content-Disposition", "")
original_filename = None
match = re.search(r'filename="(.*?)"', content_disposition)
if match:
original_filename = match.group(1)
preview_data = await field.read()
field = await reader.next()
if field is None or field.name != "model_path":
raise ValueError("Expected 'model_path' field")
model_path = (await field.read()).decode()
nsfw_level = 0
field = await reader.next()
if field and field.name == "nsfw_level":
try:
nsfw_level = int((await field.read()).decode())
except (ValueError, TypeError):
self._logger.warning("Invalid NSFW level format, using default 0")
result = await self._preview_service.replace_preview(
model_path=model_path,
preview_data=preview_data,
content_type=content_type,
original_filename=original_filename,
nsfw_level=nsfw_level,
update_preview_in_cache=self._service.scanner.update_preview_in_cache,
metadata_loader=self._metadata_sync.load_local_metadata,
)
return web.json_response(
{
"success": True,
"preview_url": config.get_preview_static_url(
result["preview_path"]
),
"preview_nsfw_level": result["preview_nsfw_level"],
}
)
except Exception as exc:
self._logger.error("Error replacing preview: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
async def set_preview_from_url(self, request: web.Request) -> web.Response:
"""Set a preview image from a remote URL (e.g., CivitAI)."""
try:
from ...utils.civitai_utils import rewrite_preview_url
from ...services.downloader import get_downloader
data = await request.json()
model_path = data.get("model_path")
image_url = data.get("image_url")
nsfw_level = data.get("nsfw_level", 0)
if not model_path:
return web.json_response(
{"success": False, "error": "Model path is required"}, status=400
)
if not image_url:
return web.json_response(
{"success": False, "error": "Image URL is required"}, status=400
)
# Rewrite URL to use optimized rendition if it's a Civitai URL
optimized_url, was_rewritten = rewrite_preview_url(
image_url, media_type="image"
)
if was_rewritten and optimized_url:
self._logger.info(
f"Rewritten preview URL to optimized version: {optimized_url}"
)
else:
optimized_url = image_url
# Download the image using the Downloader service
self._logger.info(
f"Downloading preview from {optimized_url} for {model_path}"
)
downloader = await get_downloader()
success, preview_data, headers = await downloader.download_to_memory(
optimized_url, use_auth=False, return_headers=True
)
if not success:
return web.json_response(
{
"success": False,
"error": f"Failed to download image: {preview_data}",
},
status=502,
)
# preview_data is bytes when success is True
preview_bytes = (
preview_data
if isinstance(preview_data, bytes)
else preview_data.encode("utf-8")
)
# Determine content type from response headers
content_type = (
headers.get("Content-Type", "image/jpeg") if headers else "image/jpeg"
)
# Extract original filename from URL
original_filename = None
if "?" in image_url:
url_path = image_url.split("?")[0]
else:
url_path = image_url
original_filename = url_path.split("/")[-1] if "/" in url_path else None
result = await self._preview_service.replace_preview(
model_path=model_path,
                preview_data=preview_bytes,
content_type=content_type,
original_filename=original_filename,
nsfw_level=nsfw_level,
update_preview_in_cache=self._service.scanner.update_preview_in_cache,
metadata_loader=self._metadata_sync.load_local_metadata,
)
return web.json_response(
{
"success": True,
"preview_url": config.get_preview_static_url(
result["preview_path"]
),
"preview_nsfw_level": result["preview_nsfw_level"],
}
)
except Exception as exc:
self._logger.error("Error setting preview from URL: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def save_metadata(self, request: web.Request) -> web.Response:
try:
data = await request.json()
file_path = data.get("file_path")
if not file_path:
return web.Response(text="File path is required", status=400)
metadata_updates = {k: v for k, v in data.items() if k != "file_path"}
await self._metadata_sync.save_metadata_updates(
file_path=file_path,
updates=metadata_updates,
metadata_loader=self._metadata_sync.load_local_metadata,
update_cache=self._service.scanner.update_single_model_cache,
)
if "model_name" in metadata_updates:
cache = await self._service.scanner.get_cached_data()
await cache.resort()
return web.json_response({"success": True})
except Exception as exc:
self._logger.error("Error saving metadata: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
async def add_tags(self, request: web.Request) -> web.Response:
try:
data = await request.json()
file_path = data.get("file_path")
new_tags = data.get("tags", [])
if not file_path:
return web.Response(text="File path is required", status=400)
if not isinstance(new_tags, list):
return web.Response(text="Tags must be a list", status=400)
tags = await self._tag_update_service.add_tags(
file_path=file_path,
new_tags=new_tags,
metadata_loader=self._metadata_sync.load_local_metadata,
update_cache=self._service.scanner.update_single_model_cache,
)
return web.json_response({"success": True, "tags": tags})
except Exception as exc:
self._logger.error("Error adding tags: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
async def rename_model(self, request: web.Request) -> web.Response:
try:
data = await request.json()
file_path = data.get("file_path")
new_file_name = data.get("new_file_name")
if not file_path or not new_file_name:
return web.json_response(
{
"success": False,
"error": "File path and new file name are required",
},
status=400,
)
result = await self._lifecycle_service.rename_model(
file_path=file_path, new_file_name=new_file_name
)
return web.json_response(
{
**result,
"new_preview_path": config.get_preview_static_url(
result.get("new_preview_path")
),
}
)
except ValueError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except Exception as exc:
self._logger.error("Error renaming model: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def bulk_delete_models(self, request: web.Request) -> web.Response:
try:
data = await request.json()
file_paths = data.get("file_paths", [])
if not file_paths:
return web.json_response(
{
"success": False,
"error": "No file paths provided for deletion",
},
status=400,
)
result = await self._lifecycle_service.bulk_delete_models(file_paths)
return web.json_response(result)
except ValueError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except Exception as exc:
self._logger.error("Error in bulk delete: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def verify_duplicates(self, request: web.Request) -> web.Response:
try:
data = await request.json()
file_paths = data.get("file_paths", [])
if not file_paths:
return web.json_response(
{
"success": False,
"error": "No file paths provided for verification",
},
status=400,
)
results = await self._metadata_sync.verify_duplicate_hashes(
file_paths=file_paths,
metadata_loader=self._metadata_sync.load_local_metadata,
hash_calculator=calculate_sha256,
update_cache=self._service.scanner.update_single_model_cache,
)
return web.json_response({"success": True, **results})
except Exception as exc:
self._logger.error(
"Error verifying duplicate models: %s", exc, exc_info=True
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class ModelQueryHandler:
"""Serve read-only model queries."""
def __init__(self, *, service, logger: logging.Logger) -> None:
self._service = service
self._logger = logger
async def get_top_tags(self, request: web.Request) -> web.Response:
try:
limit = int(request.query.get("limit", "20"))
if limit < 0:
limit = 20
top_tags = await self._service.get_top_tags(limit)
return web.json_response({"success": True, "tags": top_tags})
except Exception as exc:
self._logger.error("Error getting top tags: %s", exc, exc_info=True)
return web.json_response(
{"success": False, "error": "Internal server error"}, status=500
)
async def get_base_models(self, request: web.Request) -> web.Response:
try:
limit = int(request.query.get("limit", "20"))
if limit < 1 or limit > 100:
limit = 20
base_models = await self._service.get_base_models(limit)
return web.json_response({"success": True, "base_models": base_models})
except Exception as exc:
self._logger.error("Error retrieving base models: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_model_types(self, request: web.Request) -> web.Response:
try:
limit = int(request.query.get("limit", "20"))
if limit < 1 or limit > 100:
limit = 20
model_types = await self._service.get_model_types(limit)
return web.json_response({"success": True, "model_types": model_types})
except Exception as exc:
self._logger.error("Error retrieving model types: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def scan_models(self, request: web.Request) -> web.Response:
try:
full_rebuild = request.query.get("full_rebuild", "false").lower() == "true"
await self._service.scan_models(
force_refresh=True, rebuild_cache=full_rebuild
)
if self._service.scanner.is_cancelled():
return web.json_response(
{
"status": "cancelled",
"message": f"{self._service.model_type.capitalize()} scan cancelled",
}
)
return web.json_response(
{
"status": "success",
"message": f"{self._service.model_type.capitalize()} scan completed",
}
)
except Exception as exc:
self._logger.error(
"Error scanning %ss: %s", self._service.model_type, exc, exc_info=True
)
return web.json_response({"error": str(exc)}, status=500)
async def get_model_roots(self, request: web.Request) -> web.Response:
try:
roots = self._service.get_model_roots()
return web.json_response({"success": True, "roots": roots})
except Exception as exc:
self._logger.error(
"Error getting %s roots: %s",
self._service.model_type,
exc,
exc_info=True,
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_folders(self, request: web.Request) -> web.Response:
try:
cache = await self._service.scanner.get_cached_data()
return web.json_response({"folders": cache.folders})
except Exception as exc:
self._logger.error("Error getting folders: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def cancel_task(self, request: web.Request) -> web.Response:
try:
self._service.scanner.cancel_task()
return web.json_response(
{"status": "success", "message": "Cancellation requested"}
)
except Exception as exc:
self._logger.error(
"Error cancelling task for %s: %s", self._service.model_type, exc
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_folder_tree(self, request: web.Request) -> web.Response:
try:
model_root = request.query.get("model_root")
if not model_root:
return web.json_response(
{"success": False, "error": "model_root parameter is required"},
status=400,
)
folder_tree = await self._service.get_folder_tree(model_root)
return web.json_response({"success": True, "tree": folder_tree})
except Exception as exc:
self._logger.error("Error getting folder tree: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_unified_folder_tree(self, request: web.Request) -> web.Response:
try:
unified_tree = await self._service.get_unified_folder_tree()
return web.json_response({"success": True, "tree": unified_tree})
except Exception as exc:
self._logger.error("Error getting unified folder tree: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def find_duplicate_models(self, request: web.Request) -> web.Response:
try:
filters = self._parse_duplicate_filters(request)
duplicates = self._service.find_duplicate_hashes()
result = []
cache = await self._service.scanner.get_cached_data()
for sha256, paths in duplicates.items():
# Collect all models in this group
all_models = []
for path in paths:
model = next(
(m for m in cache.raw_data if m["file_path"] == path), None
)
if model:
all_models.append(model)
# Include primary if not already in paths
primary_path = self._service.get_path_by_hash(sha256)
if primary_path and primary_path not in paths:
primary_model = next(
(m for m in cache.raw_data if m["file_path"] == primary_path),
None,
)
if primary_model:
all_models.insert(0, primary_model)
# Apply filters
filtered = self._apply_duplicate_filters(all_models, filters)
# Sort: originals first, copies last
sorted_models = self._sort_duplicate_group(filtered)
# Format response
group = {"hash": sha256, "models": []}
for model in sorted_models:
group["models"].append(await self._service.format_response(model))
# Only include groups with 2+ models after filtering
if len(group["models"]) > 1:
result.append(group)
return web.json_response(
{"success": True, "duplicates": result, "count": len(result)}
)
except Exception as exc:
self._logger.error(
"Error finding duplicate %ss: %s",
self._service.model_type,
exc,
exc_info=True,
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
def _parse_duplicate_filters(self, request: web.Request) -> Dict[str, Any]:
"""Parse filter parameters from the request for duplicate finding."""
return {
"base_models": request.query.getall("base_model", []),
"tag_include": request.query.getall("tag_include", []),
"tag_exclude": request.query.getall("tag_exclude", []),
"model_types": request.query.getall("model_type", []),
"folder": request.query.get("folder"),
"favorites_only": request.query.get("favorites_only", "").lower() == "true",
}
def _apply_duplicate_filters(
self, models: List[Dict[str, Any]], filters: Dict[str, Any]
) -> List[Dict[str, Any]]:
"""Apply filters to a list of models within a duplicate group."""
result = models
# Apply base model filter
if filters.get("base_models"):
base_set = set(filters["base_models"])
result = [m for m in result if m.get("base_model") in base_set]
# Apply tag filters (include)
for tag in filters.get("tag_include", []):
if tag == "__no_tags__":
result = [m for m in result if not m.get("tags")]
else:
result = [m for m in result if tag in (m.get("tags") or [])]
# Apply tag filters (exclude)
for tag in filters.get("tag_exclude", []):
if tag == "__no_tags__":
result = [m for m in result if m.get("tags")]
else:
result = [m for m in result if tag not in (m.get("tags") or [])]
# Apply model type filter
if filters.get("model_types"):
type_set = {t.lower() for t in filters["model_types"]}
result = [
m for m in result if (m.get("model_type") or "").lower() in type_set
]
# Apply folder filter
if filters.get("folder"):
folder = filters["folder"]
result = [m for m in result if m.get("folder", "").startswith(folder)]
# Apply favorites filter
if filters.get("favorites_only"):
result = [m for m in result if m.get("favorite", False)]
return result
def _sort_duplicate_group(
self, models: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Sort models: originals first (left), copies (with -????. pattern) last (right)."""
if len(models) <= 1:
return models
min_len = min(len(m.get("file_name", "")) for m in models)
def copy_score(m):
fn = m.get("file_name", "")
score = 0
# Match -0001.safetensors, -1234.safetensors etc.
if re.search(r"-\d{4}\.", fn):
score += 100
# Match (1), (2) etc.
if re.search(r"\(\d+\)", fn):
score += 50
# Match 'copy' in filename
if "copy" in fn.lower():
score += 50
# Longer filenames are more likely copies
score += len(fn) - min_len
return (score, fn.lower())
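        # Example (two-model group): "foo.safetensors" scores 0, while
        # "foo-0001.safetensors" gains +100 for the numeric suffix plus a
        # length penalty, so the original sorts first and the copy last.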
return sorted(models, key=copy_score)
async def find_filename_conflicts(self, request: web.Request) -> web.Response:
try:
duplicates = self._service.find_duplicate_filenames()
result = []
cache = await self._service.scanner.get_cached_data()
for filename, paths in duplicates.items():
group = {"filename": filename, "models": []}
for path in paths:
model = next(
(m for m in cache.raw_data if m["file_path"] == path), None
)
if model:
group["models"].append(
await self._service.format_response(model)
)
hash_val = self._service.scanner.get_hash_by_filename(filename)
if hash_val:
main_path = self._service.get_path_by_hash(hash_val)
if main_path and main_path not in paths:
main_model = next(
(m for m in cache.raw_data if m["file_path"] == main_path),
None,
)
if main_model:
group["models"].insert(
0, await self._service.format_response(main_model)
)
if group["models"]:
result.append(group)
return web.json_response(
{"success": True, "conflicts": result, "count": len(result)}
)
except Exception as exc:
self._logger.error(
"Error finding filename conflicts for %ss: %s",
self._service.model_type,
exc,
exc_info=True,
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_model_notes(self, request: web.Request) -> web.Response:
try:
model_name = request.query.get("name")
if not model_name:
return web.Response(
text=f"{self._service.model_type.capitalize()} file name is required",
status=400,
)
notes = await self._service.get_model_notes(model_name)
if notes is not None:
return web.json_response({"success": True, "notes": notes})
return web.json_response(
{
"success": False,
"error": f"{self._service.model_type.capitalize()} not found in cache",
},
status=404,
)
except Exception as exc:
self._logger.error(
"Error getting %s notes: %s",
self._service.model_type,
exc,
exc_info=True,
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_model_preview_url(self, request: web.Request) -> web.Response:
try:
model_name = request.query.get("name")
if not model_name:
return web.Response(
text=f"{self._service.model_type.capitalize()} file name is required",
status=400,
)
include_license_flags = request.query.get(
"license_flags", ""
).strip().lower() in {"1", "true", "yes", "on"}
preview_url = await self._service.get_model_preview_url(model_name)
if preview_url:
response_payload: dict[str, object] = {
"success": True,
"preview_url": preview_url,
}
if include_license_flags:
model_data = await self._service.get_model_info_by_name(model_name)
license_flags = (model_data or {}).get("license_flags")
if license_flags is not None:
response_payload["license_flags"] = int(license_flags)
return web.json_response(response_payload)
return web.json_response(
{
"success": False,
"error": f"No preview URL found for the specified {self._service.model_type}",
},
status=404,
)
except Exception as exc:
self._logger.error(
"Error getting %s preview URL: %s",
self._service.model_type,
exc,
exc_info=True,
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_model_civitai_url(self, request: web.Request) -> web.Response:
try:
model_name = request.query.get("name")
if not model_name:
return web.Response(
text=f"{self._service.model_type.capitalize()} file name is required",
status=400,
)
result = await self._service.get_model_civitai_url(model_name)
if result["civitai_url"]:
return web.json_response({"success": True, **result})
return web.json_response(
{
"success": False,
"error": f"No Civitai data found for the specified {self._service.model_type}",
},
status=404,
)
except Exception as exc:
self._logger.error(
"Error getting %s Civitai URL: %s",
self._service.model_type,
exc,
exc_info=True,
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_model_metadata(self, request: web.Request) -> web.Response:
try:
file_path = request.query.get("file_path")
if not file_path:
return web.Response(text="File path is required", status=400)
metadata = await self._service.get_model_metadata(file_path)
if metadata is not None:
return web.json_response({"success": True, "metadata": metadata})
return web.json_response(
{
"success": False,
"error": f"{self._service.model_type.capitalize()} not found or no CivitAI metadata available",
},
status=404,
)
except Exception as exc:
self._logger.error(
"Error getting %s metadata: %s",
self._service.model_type,
exc,
exc_info=True,
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_model_description(self, request: web.Request) -> web.Response:
try:
file_path = request.query.get("file_path")
if not file_path:
return web.Response(text="File path is required", status=400)
description = await self._service.get_model_description(file_path)
if description is not None:
return web.json_response({"success": True, "description": description})
return web.json_response(
{
"success": False,
"error": f"{self._service.model_type.capitalize()} not found or no description available",
},
status=404,
)
except Exception as exc:
self._logger.error(
"Error getting %s description: %s",
self._service.model_type,
exc,
exc_info=True,
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_relative_paths(self, request: web.Request) -> web.Response:
try:
search = request.query.get("search", "").strip()
limit = min(int(request.query.get("limit", "15")), 100)
offset = max(0, int(request.query.get("offset", "0")))
matching_paths = await self._service.search_relative_paths(
search, limit, offset
)
return web.json_response(
{"success": True, "relative_paths": matching_paths}
)
except Exception as exc:
self._logger.error(
"Error getting relative paths for autocomplete: %s", exc, exc_info=True
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class ModelDownloadHandler:
"""Coordinate downloads and progress reporting."""
def __init__(
self,
*,
ws_manager: WebSocketManager,
logger: logging.Logger,
download_use_case: DownloadModelUseCase,
download_coordinator: DownloadCoordinator,
) -> None:
self._ws_manager = ws_manager
self._logger = logger
self._download_use_case = download_use_case
self._download_coordinator = download_coordinator
async def download_model(self, request: web.Request) -> web.Response:
try:
payload = await request.json()
result = await self._download_use_case.execute(payload)
if not result.get("success", False):
return web.json_response(result, status=500)
return web.json_response(result)
except DownloadModelValidationError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except DownloadModelEarlyAccessError as exc:
self._logger.warning("Early access error: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=401)
except Exception as exc:
error_message = str(exc)
self._logger.error(
"Error downloading model: %s", error_message, exc_info=True
)
return web.json_response(
{"success": False, "error": error_message}, status=500
)
async def download_model_get(self, request: web.Request) -> web.Response:
try:
model_id = request.query.get("model_id")
if not model_id:
return web.Response(
status=400,
text="Missing required parameter: Please provide 'model_id'",
)
model_version_id = request.query.get("model_version_id")
download_id = request.query.get("download_id")
use_default_paths = (
request.query.get("use_default_paths", "false").lower() == "true"
)
source = request.query.get("source")
file_params_json = request.query.get("file_params")
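            # Rebuild the JSON payload the POST endpoint accepts from query
            # parameters, e.g. ?model_id=12345&use_default_paths=true.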
data = {"model_id": model_id, "use_default_paths": use_default_paths}
if model_version_id:
data["model_version_id"] = model_version_id
if download_id:
data["download_id"] = download_id
if source:
data["source"] = source
if file_params_json:
try:
data["file_params"] = json.loads(file_params_json)
except json.JSONDecodeError:
self._logger.warning(
"Invalid file_params JSON: %s", file_params_json
)
result = await self._download_use_case.execute(data)
if not result.get("success", False):
return web.json_response(result, status=500)
return web.json_response(result)
except DownloadModelValidationError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except DownloadModelEarlyAccessError as exc:
self._logger.warning("Early access error: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=401)
except Exception as exc:
self._logger.error(
"Error downloading model via GET: %s", exc, exc_info=True
)
return web.Response(status=500, text=str(exc))
async def cancel_download_get(self, request: web.Request) -> web.Response:
try:
download_id = request.query.get("download_id")
if not download_id:
return web.json_response(
{"success": False, "error": "Download ID is required"}, status=400
)
result = await self._download_coordinator.cancel_download(download_id)
return web.json_response(result)
except Exception as exc:
self._logger.error(
"Error cancelling download via GET: %s", exc, exc_info=True
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def pause_download_get(self, request: web.Request) -> web.Response:
try:
download_id = request.query.get("download_id")
if not download_id:
return web.json_response(
{"success": False, "error": "Download ID is required"}, status=400
)
result = await self._download_coordinator.pause_download(download_id)
status = 200 if result.get("success") else 400
return web.json_response(result, status=status)
except Exception as exc:
self._logger.error("Error pausing download via GET: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def resume_download_get(self, request: web.Request) -> web.Response:
try:
download_id = request.query.get("download_id")
if not download_id:
return web.json_response(
{"success": False, "error": "Download ID is required"}, status=400
)
result = await self._download_coordinator.resume_download(download_id)
status = 200 if result.get("success") else 400
return web.json_response(result, status=status)
except Exception as exc:
self._logger.error(
"Error resuming download via GET: %s", exc, exc_info=True
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_download_progress(self, request: web.Request) -> web.Response:
try:
download_id = request.match_info.get("download_id")
if not download_id:
return web.json_response(
{"success": False, "error": "Download ID is required"}, status=400
)
progress_data = self._ws_manager.get_download_progress(download_id)
if progress_data is None:
return web.json_response(
{"success": False, "error": "Download ID not found"}, status=404
)
response_payload = {
"success": True,
"progress": progress_data.get("progress", 0),
"bytes_downloaded": progress_data.get("bytes_downloaded"),
"total_bytes": progress_data.get("total_bytes"),
"bytes_per_second": progress_data.get("bytes_per_second", 0.0),
}
status = progress_data.get("status")
if status and status != "progress":
response_payload["status"] = status
if "message" in progress_data:
response_payload["message"] = progress_data["message"]
elif status is None and "message" in progress_data:
response_payload["message"] = progress_data["message"]
return web.json_response(response_payload)
except Exception as exc:
self._logger.error(
"Error getting download progress: %s", exc, exc_info=True
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class ModelCivitaiHandler:
"""CivitAI integration endpoints."""
def __init__(
self,
*,
service,
settings_service: SettingsManager,
ws_manager: WebSocketManager,
logger: logging.Logger,
metadata_provider_factory: Callable[[], Awaitable],
validate_model_type: Callable[[str], bool],
expected_model_types: Callable[[], str],
find_model_file: Callable[
[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]
],
metadata_sync: MetadataSyncService,
metadata_refresh_use_case: BulkMetadataRefreshUseCase,
metadata_progress_callback: MetadataRefreshProgressReporter,
) -> None:
self._service = service
self._settings = settings_service
self._ws_manager = ws_manager
self._logger = logger
self._metadata_provider_factory = metadata_provider_factory
self._validate_model_type = validate_model_type
self._expected_model_types = expected_model_types
self._find_model_file = find_model_file
self._metadata_sync = metadata_sync
self._metadata_refresh_use_case = metadata_refresh_use_case
self._metadata_progress_callback = metadata_progress_callback
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
try:
result = await self._metadata_refresh_use_case.execute_with_error_handling(
progress_callback=self._metadata_progress_callback
)
return web.json_response(result)
except Exception as exc:
self._logger.error(
"Error in fetch_all_civitai for %ss: %s", self._service.model_type, exc
)
return web.Response(text=str(exc), status=500)
async def get_civitai_versions(self, request: web.Request) -> web.Response:
try:
model_id = request.match_info["model_id"]
metadata_provider = await self._metadata_provider_factory()
try:
response = await metadata_provider.get_model_versions(model_id)
except ResourceNotFoundError:
return web.Response(status=404, text="Model not found")
if not response or not response.get("modelVersions"):
return web.Response(status=404, text="Model not found")
versions = response.get("modelVersions", [])
model_type = response.get("type", "")
if not self._validate_model_type(model_type):
return web.json_response(
{
"error": f"Model type mismatch. Expected {self._expected_model_types()}, got {model_type}"
},
status=400,
)
cache = await self._service.scanner.get_cached_data()
version_index = cache.version_index
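            # version_index maps a Civitai version id to the locally cached
            # model entry, letting us flag versions that already exist on disk.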
for version in versions:
version_id = None
version_id_raw = version.get("id")
if version_id_raw is not None:
try:
version_id = int(str(version_id_raw))
except (TypeError, ValueError):
version_id = None
cache_entry = (
version_index.get(version_id)
if (version_id is not None and version_index)
else None
)
version["existsLocally"] = cache_entry is not None
if cache_entry and isinstance(cache_entry, Mapping):
local_path = cache_entry.get("file_path")
if local_path:
version["localPath"] = local_path
else:
version.pop("localPath", None)
model_file = (
self._find_model_file(version.get("files", []))
if isinstance(version.get("files"), Iterable)
else None
)
if model_file and isinstance(model_file, Mapping):
version["modelSizeKB"] = model_file.get("sizeKB")
return web.json_response(versions)
except Exception as exc:
self._logger.error(
"Error fetching %s model versions: %s", self._service.model_type, exc
)
return web.Response(status=500, text=str(exc))
async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
try:
model_version_id = request.match_info.get("modelVersionId")
metadata_provider = await self._metadata_provider_factory()
model, error_msg = await metadata_provider.get_model_version_info(
model_version_id
)
if not model:
self._logger.warning(
"Failed to fetch model version %s: %s", model_version_id, error_msg
)
status_code = (
404 if error_msg and "not found" in error_msg.lower() else 500
)
return web.json_response(
{
"success": False,
"error": error_msg or "Failed to fetch model information",
},
status=status_code,
)
return web.json_response(model)
except Exception as exc:
self._logger.error("Error fetching model details: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
try:
hash_value = request.match_info.get("hash")
metadata_provider = await self._metadata_provider_factory()
model, error = await metadata_provider.get_model_by_hash(hash_value)
if error:
self._logger.warning("Error getting model by hash: %s", error)
return web.json_response({"success": False, "error": error}, status=404)
return web.json_response(model)
except Exception as exc:
self._logger.error("Error fetching model details by hash: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class ModelMoveHandler:
"""Move model files between folders."""
def __init__(
self, *, move_service: ModelMoveService, logger: logging.Logger
) -> None:
self._move_service = move_service
self._logger = logger
async def move_model(self, request: web.Request) -> web.Response:
try:
data = await request.json()
file_path = data.get("file_path")
target_path = data.get("target_path")
use_default_paths = data.get("use_default_paths", False)
if not file_path or not target_path:
return web.Response(
text="File path and target path are required", status=400
)
result = await self._move_service.move_model(
file_path, target_path, use_default_paths=use_default_paths
)
status = 200 if result.get("success") else 500
return web.json_response(result, status=status)
except Exception as exc:
self._logger.error("Error moving model: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
async def move_models_bulk(self, request: web.Request) -> web.Response:
try:
data = await request.json()
file_paths = data.get("file_paths", [])
target_path = data.get("target_path")
use_default_paths = data.get("use_default_paths", False)
if not file_paths or not target_path:
return web.Response(
text="File paths and target path are required", status=400
)
result = await self._move_service.move_models_bulk(
file_paths, target_path, use_default_paths=use_default_paths
)
return web.json_response(result)
except Exception as exc:
self._logger.error("Error moving models in bulk: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
class ModelAutoOrganizeHandler:
"""Manage auto-organize operations."""
def __init__(
self,
*,
use_case: AutoOrganizeUseCase,
progress_callback: WebSocketProgressCallback,
ws_manager: WebSocketManager,
logger: logging.Logger,
) -> None:
self._use_case = use_case
self._progress_callback = progress_callback
self._ws_manager = ws_manager
self._logger = logger
async def auto_organize_models(self, request: web.Request) -> web.Response:
try:
file_paths = None
exclusion_patterns = None
settings_manager = get_settings_manager()
if request.method == "POST":
try:
data = await request.json()
file_paths = data.get("file_paths")
if "exclusion_patterns" in data:
exclusion_patterns = (
settings_manager.normalize_auto_organize_exclusions(
data.get("exclusion_patterns")
)
)
except Exception: # pragma: no cover - permissive path
pass
result = await self._use_case.execute(
file_paths=file_paths,
progress_callback=self._progress_callback,
exclusion_patterns=exclusion_patterns,
)
return web.json_response(result.to_dict())
except AutoOrganizeInProgressError:
return web.json_response(
{
"success": False,
"error": "Auto-organize is already running. Please wait for it to complete.",
},
status=409,
)
except Exception as exc:
self._logger.error("Error in auto_organize_models: %s", exc, exc_info=True)
try:
await self._progress_callback.on_progress(
{
"type": "auto_organize_progress",
"status": "error",
"error": str(exc),
}
)
except Exception: # pragma: no cover - defensive reporting
pass
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_auto_organize_progress(self, request: web.Request) -> web.Response:
try:
progress_data = self._ws_manager.get_auto_organize_progress()
if progress_data is None:
return web.json_response(
{
"success": False,
"error": "No auto-organize operation in progress",
},
status=404,
)
return web.json_response({"success": True, "progress": progress_data})
except Exception as exc:
self._logger.error(
"Error getting auto-organize progress: %s", exc, exc_info=True
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class ModelUpdateHandler:
"""Handle update tracking requests."""
def __init__(
self,
*,
service,
update_service,
metadata_provider_selector,
settings_service,
logger: logging.Logger,
) -> None:
self._service = service
self._update_service = update_service
self._metadata_provider_selector = metadata_provider_selector
self._settings = settings_service
self._logger = logger
async def fetch_missing_civitai_license_data(
self, request: web.Request
) -> web.Response:
payload = await self._read_json(request)
target_model_ids = self._extract_target_model_ids(payload)
provider = await self._get_civitai_provider()
if provider is None:
return web.json_response(
{"success": False, "error": "Civitai provider not available"},
status=503,
)
try:
cache = await self._service.scanner.get_cached_data()
except Exception as exc:
self._logger.error(
"Failed to load cache for license refresh: %s", exc, exc_info=True
)
cache = None
target_set = set(target_model_ids) if target_model_ids is not None else None
candidates = await self._collect_models_missing_license(cache, target_set)
if not candidates:
return web.json_response({"success": True, "updated": []})
model_ids = sorted(candidates.keys())
try:
license_map = await self._fetch_license_info(provider, model_ids)
except RateLimitError as exc:
return web.json_response(
{"success": False, "error": str(exc) or "Rate limited"},
status=429,
)
except Exception as exc: # pragma: no cover - defensive log
self._logger.error("Failed to fetch license info: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
updated: List[Dict[str, str]] = []
errors: List[Dict[str, str]] = []
for model_id in model_ids:
license_payload = license_map.get(model_id)
if not license_payload:
continue
resolved_payload = resolve_license_payload(license_payload)
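            # Merge the resolved license fields into each affected file's
            # cached civitai.model section and persist the metadata.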
for context in candidates.get(model_id, []):
metadata_path = context["file_path"]
metadata_payload = context["metadata"]
civitai_section = metadata_payload.setdefault("civitai", {})
model_section = civitai_section.get("model")
if not isinstance(model_section, Mapping):
model_section = {}
model_section.update(resolved_payload)
civitai_section["model"] = model_section
metadata_payload["civitai"] = civitai_section
try:
await MetadataManager.save_metadata(metadata_path, metadata_payload)
updated.append({"modelId": model_id, "filePath": metadata_path})
except Exception as exc:
self._logger.error(
"Failed to save metadata for %s: %s",
metadata_path,
exc,
exc_info=True,
)
errors.append({"filePath": metadata_path, "error": str(exc)})
response_payload = {"success": True, "updated": updated}
missing_model_ids = [mid for mid in model_ids if mid not in license_map]
if missing_model_ids:
response_payload["missingModelIds"] = missing_model_ids
if errors:
response_payload["errors"] = errors
return web.json_response(response_payload)
async def refresh_model_updates(self, request: web.Request) -> web.Response:
payload = await self._read_json(request)
force_refresh = self._parse_bool(
request.query.get("force")
) or self._parse_bool(payload.get("force"))
raw_model_ids = payload.get("modelIds")
if raw_model_ids is None:
raw_model_ids = payload.get("model_ids")
target_model_ids: list[int] = []
if isinstance(raw_model_ids, (list, tuple, set)):
for value in raw_model_ids:
normalized = self._normalize_model_id(value)
if normalized is not None:
target_model_ids.append(normalized)
if target_model_ids:
target_model_ids = sorted(set(target_model_ids))
provider = await self._get_civitai_provider()
if provider is None:
return web.json_response(
{"success": False, "error": "Civitai provider not available"},
status=503,
)
try:
records = await self._update_service.refresh_for_model_type(
self._service.model_type,
self._service.scanner,
provider,
force_refresh=force_refresh,
target_model_ids=target_model_ids or None,
)
if self._service.scanner.is_cancelled():
return web.json_response(
{
"success": False,
"status": "cancelled",
"message": "Update refresh cancelled",
}
)
except RateLimitError as exc:
return web.json_response(
{"success": False, "error": str(exc) or "Rate limited"}, status=429
)
except Exception as exc: # pragma: no cover - defensive logging
self._logger.error(
"Failed to refresh model updates: %s", exc, exc_info=True
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
serialized_records = []
for record in records.values():
has_update_fn = getattr(record, "has_update", None)
if callable(has_update_fn) and has_update_fn():
serialized_records.append(self._serialize_record(record))
return web.json_response(
{
"success": True,
"records": serialized_records,
}
        )

    async def set_model_update_ignore(self, request: web.Request) -> web.Response:
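        """Set or clear the ignore flag for a tracked model's updates."""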
payload = await self._read_json(request)
model_id = self._normalize_model_id(payload.get("modelId"))
if model_id is None:
return web.json_response(
{"success": False, "error": "modelId is required"}, status=400
)
should_ignore = self._parse_bool(payload.get("shouldIgnore"))
record = await self._update_service.set_should_ignore(
self._service.model_type, model_id, should_ignore
)
return web.json_response(
{"success": True, "record": self._serialize_record(record)}
        )

    async def set_version_update_ignore(self, request: web.Request) -> web.Response:
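        """Set or clear the ignore flag for a single version of a tracked model."""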
payload = await self._read_json(request)
model_id = self._normalize_model_id(payload.get("modelId"))
version_id = self._normalize_model_id(payload.get("versionId"))
if model_id is None or version_id is None:
return web.json_response(
{"success": False, "error": "modelId and versionId are required"},
status=400,
)
should_ignore = self._parse_bool(payload.get("shouldIgnore"))
record = await self._update_service.set_version_should_ignore(
self._service.model_type,
model_id,
version_id,
should_ignore,
)
overrides = await self._build_version_context(record)
return web.json_response(
{
"success": True,
"record": self._serialize_record(record, version_context=overrides),
}
        )

    async def get_model_update_status(self, request: web.Request) -> web.Response:
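        """Return the update record for one model, optionally refreshing it first."""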
model_id = self._normalize_model_id(request.match_info.get("model_id"))
if model_id is None:
return web.json_response(
{"success": False, "error": "model_id must be an integer"}, status=400
)
refresh = self._parse_bool(request.query.get("refresh"))
force = self._parse_bool(request.query.get("force"))
try:
record = await self._get_or_refresh_record(
model_id, refresh=refresh, force=force
)
except RateLimitError as exc:
return web.json_response(
{"success": False, "error": str(exc) or "Rate limited"}, status=429
)
if record is None:
return web.json_response(
{"success": False, "error": "Model not tracked"}, status=404
)
return web.json_response(
{"success": True, "record": self._serialize_record(record)}
        )

    async def get_model_versions(self, request: web.Request) -> web.Response:
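        """Return version details for one model, enriching early-access data."""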
model_id = self._normalize_model_id(request.match_info.get("model_id"))
if model_id is None:
return web.json_response(
{"success": False, "error": "model_id must be an integer"}, status=400
)
refresh = self._parse_bool(request.query.get("refresh"))
force = self._parse_bool(request.query.get("force"))
try:
record = await self._get_or_refresh_record(
model_id, refresh=refresh, force=force
)
except RateLimitError as exc:
return web.json_response(
{"success": False, "error": str(exc) or "Rate limited"}, status=429
)
if record is None:
return web.json_response(
{"success": False, "error": "Model not tracked"}, status=404
)
# Enrich EA versions with detailed info if needed
record = await self._enrich_early_access_details(record)
overrides = await self._build_version_context(record)
return web.json_response(
{
"success": True,
"record": self._serialize_record(record, version_context=overrides),
}
        )

    async def _get_or_refresh_record(
self, model_id: int, *, refresh: bool, force: bool
) -> Optional[object]:
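        """Return the tracked record, refreshing it via the provider when asked.

        Falls back to the stale record when no Civitai provider is available.
        """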
record = await self._update_service.get_record(
self._service.model_type, model_id
)
if record and not refresh and not force:
return record
provider = await self._get_civitai_provider()
if provider is None:
return record
return await self._update_service.refresh_single_model(
self._service.model_type,
model_id,
self._service.scanner,
provider,
force_refresh=force or refresh,
        )

    async def _get_civitai_provider(self):
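        """Resolve the ``civitai_api`` metadata provider, or ``None`` on failure."""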
try:
return await self._metadata_provider_selector("civitai_api")
except Exception as exc: # pragma: no cover - defensive log
self._logger.error(
"Failed to acquire civitai provider: %s", exc, exc_info=True
)
            return None

    async def _enrich_early_access_details(self, record):
"""Fetch detailed EA info for versions missing exact end time.
Identifies versions with is_early_access=True but no early_access_ends_at,
then fetches detailed info from CivitAI to get the exact end time.
"""
if not record or not record.versions:
return record
# Find versions that need enrichment
versions_needing_update = []
for version in record.versions:
if version.is_early_access and not version.early_access_ends_at:
versions_needing_update.append(version)
if not versions_needing_update:
return record
provider = await self._get_civitai_provider()
if not provider:
return record
# Fetch detailed info for each version needing update
updated_versions = []
for version in versions_needing_update:
try:
version_info, error = await provider.get_model_version_info(
str(version.version_id)
)
if version_info and not error:
ea_ends_at = version_info.get("earlyAccessEndsAt")
if ea_ends_at:
                        # Create an updated copy carrying the EA end time
                        updated_version = replace(
version, early_access_ends_at=ea_ends_at
)
updated_versions.append(updated_version)
self._logger.debug(
"Enriched EA info for version %s: %s",
version.version_id,
ea_ends_at,
)
except Exception as exc:
self._logger.debug(
"Failed to fetch EA details for version %s: %s",
version.version_id,
exc,
)
if not updated_versions:
return record
# Update record with enriched versions
version_map = {v.version_id: v for v in record.versions}
for updated in updated_versions:
version_map[updated.version_id] = updated
        # Create a new record with the enriched versions
        new_record = replace(
record,
versions=list(version_map.values()),
)
        # Deliberately not persisted to the database: that avoids side effects
        # here, and the data will be refreshed on the next bulk update if
        # still needed.
        return new_record

    async def _collect_models_missing_license(
self,
cache,
target_model_ids: Optional[set[int]],
) -> Dict[int, List[Dict[str, Any]]]:
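        """Map model IDs to cached metadata entries missing license fields.

        Deduplicates entries by file path and, when ``target_model_ids`` is
        given, only considers models in that set.
        """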
entries: Dict[int, List[Dict[str, Any]]] = {}
if cache is None:
return entries
raw_data = getattr(cache, "raw_data", None) or []
seen_paths: set[str] = set()
target_set = target_model_ids
for item in raw_data:
if not isinstance(item, Mapping):
continue
file_path = item.get("file_path")
if (
not isinstance(file_path, str)
or not file_path
or file_path in seen_paths
):
continue
seen_paths.add(file_path)
civitai_entry = item.get("civitai")
if not isinstance(civitai_entry, Mapping):
continue
model_id = self._normalize_model_id(civitai_entry.get("modelId"))
if model_id is None:
continue
if target_set is not None and model_id not in target_set:
continue
try:
metadata_obj, should_skip = await MetadataManager.load_metadata(
file_path
)
except Exception as exc:
self._logger.debug("Failed to load metadata for %s: %s", file_path, exc)
continue
if metadata_obj is None or should_skip:
continue
metadata_payload = self._convert_metadata_to_dict(metadata_obj)
civitai_payload = metadata_payload.get("civitai")
if not isinstance(civitai_payload, Mapping):
civitai_payload = {}
model_payload = civitai_payload.get("model")
if not isinstance(model_payload, Mapping):
model_payload = {}
missing = [key for key in LICENSE_FIELDS if key not in model_payload]
if not missing:
continue
civitai_payload["model"] = model_payload
metadata_payload["civitai"] = civitai_payload
entries.setdefault(model_id, []).append(
{"file_path": file_path, "metadata": metadata_payload}
)
        return entries

    async def _fetch_license_info(
self,
provider,
model_ids: List[int],
) -> Dict[int, Dict[str, Any]]:
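        """Fetch license fields for ``model_ids`` in batches of ``BATCH_SIZE``."""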
if not model_ids:
return {}
BATCH_SIZE = 100
aggregated: Dict[int, Dict[str, Any]] = {}
for start in range(0, len(model_ids), BATCH_SIZE):
chunk = model_ids[start : start + BATCH_SIZE]
response = await provider.get_model_versions_bulk(chunk)
if not isinstance(response, Mapping):
continue
for raw_id, payload in response.items():
normalized_id = self._normalize_model_id(raw_id)
if normalized_id is None or not isinstance(payload, Mapping):
continue
license_data: Dict[str, Any] = {}
for field in LICENSE_FIELDS:
license_data[field] = payload.get(field)
aggregated[normalized_id] = license_data
        return aggregated

    def _extract_target_model_ids(self, payload: Dict) -> Optional[List[int]]:
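        """Parse ``modelIds``/``model_ids`` into a sorted list of unique ints.

        Returns ``None`` when the payload carries no usable IDs, e.g.::

            {"modelIds": ["3", 1, "x", 3]}  ->  [1, 3]
        """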
if not isinstance(payload, Mapping):
return None
raw_ids = payload.get("modelIds")
if raw_ids is None:
raw_ids = payload.get("model_ids")
if not isinstance(raw_ids, (list, tuple, set)):
return None
normalized: List[int] = []
for candidate in raw_ids:
model_id = self._normalize_model_id(candidate)
if model_id is not None:
normalized.append(model_id)
if not normalized:
return None
        return sorted(set(normalized))

    @staticmethod
def _convert_metadata_to_dict(metadata: Any) -> Dict[str, Any]:
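        """Coerce a metadata object into a plain ``dict``.

        Prefers ``to_dict()`` when available, copies mappings, and returns
        an empty dict for anything else.
        """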
if metadata is None:
return {}
to_dict = getattr(metadata, "to_dict", None)
if callable(to_dict):
try:
return to_dict()
except Exception:
pass
if isinstance(metadata, Mapping):
return dict(metadata)
        return {}

    async def _read_json(self, request: web.Request) -> Dict:
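        """Return the request's JSON body, or ``{}`` when absent or invalid."""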
if not request.can_read_body:
return {}
try:
return await request.json()
except Exception:
            return {}

    @staticmethod
def _parse_bool(value) -> bool:
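        """Interpret common boolean representations leniently, e.g.::

            _parse_bool("TRUE") -> True
            _parse_bool("0")    -> False
            _parse_bool(2)      -> True
        """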
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.lower() in {"1", "true", "yes"}
if isinstance(value, (int, float)):
return bool(value)
        return False

    @staticmethod
def _normalize_model_id(value) -> Optional[int]:
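        """Coerce ``value`` to ``int``, returning ``None`` when conversion fails."""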
try:
if value is None:
return None
return int(value)
except (TypeError, ValueError):
            return None

    def _serialize_record(
self,
record,
*,
version_context: Optional[Dict[int, Dict[str, Optional[str]]]] = None,
) -> Dict:
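        """Serialize an update record into the camelCase wire format.

        ``version_context`` supplies per-version file paths and preview
        overrides gathered from the local cache.
        """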
context = version_context or {}
# Check user setting for hiding early access versions
hide_early_access = False
if self._settings is not None:
try:
hide_early_access = bool(
self._settings.get("hide_early_access_updates", False)
)
except Exception:
pass
return {
"modelType": record.model_type,
"modelId": record.model_id,
"largestVersionId": record.largest_version_id,
"versionIds": record.version_ids,
"inLibraryVersionIds": record.in_library_version_ids,
"lastCheckedAt": record.last_checked_at,
"shouldIgnore": record.should_ignore_model,
"hasUpdate": record.has_update(hide_early_access=hide_early_access),
"versions": [
self._serialize_version(version, context.get(version.version_id))
for version in record.versions
],
        }

    @staticmethod
def _serialize_version(
version, context: Optional[Dict[str, Optional[str]]]
) -> Dict:
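        """Serialize one version, applying any preview override from ``context``."""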
context = context or {}
preview_override = context.get("preview_override")
preview_url = (
preview_override if preview_override is not None else version.preview_url
)
        # Determine whether the version is currently in early access.
        # Two-phase detection: use the exact end time when available,
        # otherwise fall back to the basic flag from the bulk API.
is_early_access = False
if version.early_access_ends_at:
try:
                from datetime import datetime, timezone

                ea_date = datetime.fromisoformat(
version.early_access_ends_at.replace("Z", "+00:00")
)
is_early_access = ea_date > datetime.now(timezone.utc)
except (ValueError, AttributeError):
# If date parsing fails, treat as active EA (conservative)
is_early_access = True
elif getattr(version, "is_early_access", False):
            # Fall back to the basic EA flag from the bulk API
is_early_access = True
return {
"versionId": version.version_id,
"name": version.name,
"baseModel": version.base_model,
"releasedAt": version.released_at,
"sizeBytes": version.size_bytes,
"previewUrl": preview_url,
"isInLibrary": version.is_in_library,
"shouldIgnore": version.should_ignore,
"earlyAccessEndsAt": version.early_access_ends_at,
"isEarlyAccess": is_early_access,
"filePath": context.get("file_path"),
"fileName": context.get("file_name"),
        }

    async def _build_version_context(
self, record
) -> Dict[int, Dict[str, Optional[str]]]:
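        """Build per-version file and preview context for in-library versions."""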
context: Dict[int, Dict[str, Optional[str]]] = {}
try:
cache = await self._service.scanner.get_cached_data()
except Exception as exc: # pragma: no cover - defensive logging
self._logger.debug(
"Failed to load cache while building preview overrides: %s", exc
)
return context
version_index = getattr(cache, "version_index", None)
if not version_index:
return context
for version in record.versions:
if not version.is_in_library:
continue
cache_entry = version_index.get(version.version_id)
if isinstance(cache_entry, Mapping):
preview = cache_entry.get("preview_url")
context_entry: Dict[str, Optional[str]] = {
"file_path": cache_entry.get("file_path"),
"file_name": cache_entry.get("file_name"),
"preview_override": None,
}
if isinstance(preview, str) and preview:
context_entry["preview_override"] = config.get_preview_static_url(
preview
)
context[version.version_id] = context_entry
        return context


@dataclass
class ModelHandlerSet:
    """Aggregate concrete handlers into a flat mapping."""

    page_view: ModelPageView
listing: ModelListingHandler
management: ModelManagementHandler
query: ModelQueryHandler
download: ModelDownloadHandler
civitai: ModelCivitaiHandler
move: ModelMoveHandler
auto_organize: ModelAutoOrganizeHandler
    updates: ModelUpdateHandler

    def to_route_mapping(
self,
) -> Dict[str, Callable[[web.Request], Awaitable[web.Response]]]:
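        """Return the route-name to handler mapping used for route registration."""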
return {
"handle_models_page": self.page_view.handle,
"get_models": self.listing.get_models,
"delete_model": self.management.delete_model,
"exclude_model": self.management.exclude_model,
"fetch_civitai": self.management.fetch_civitai,
"fetch_all_civitai": self.civitai.fetch_all_civitai,
"relink_civitai": self.management.relink_civitai,
"replace_preview": self.management.replace_preview,
"set_preview_from_url": self.management.set_preview_from_url,
"save_metadata": self.management.save_metadata,
"add_tags": self.management.add_tags,
"rename_model": self.management.rename_model,
"bulk_delete_models": self.management.bulk_delete_models,
"verify_duplicates": self.management.verify_duplicates,
"get_top_tags": self.query.get_top_tags,
"get_base_models": self.query.get_base_models,
"get_model_types": self.query.get_model_types,
"scan_models": self.query.scan_models,
"get_model_roots": self.query.get_model_roots,
"get_folders": self.query.get_folders,
"get_folder_tree": self.query.get_folder_tree,
"get_unified_folder_tree": self.query.get_unified_folder_tree,
"find_duplicate_models": self.query.find_duplicate_models,
"find_filename_conflicts": self.query.find_filename_conflicts,
"download_model": self.download.download_model,
"download_model_get": self.download.download_model_get,
"cancel_download_get": self.download.cancel_download_get,
"pause_download_get": self.download.pause_download_get,
"resume_download_get": self.download.resume_download_get,
"get_download_progress": self.download.get_download_progress,
"get_civitai_versions": self.civitai.get_civitai_versions,
"get_civitai_model_by_version": self.civitai.get_civitai_model_by_version,
"get_civitai_model_by_hash": self.civitai.get_civitai_model_by_hash,
"move_model": self.move.move_model,
"move_models_bulk": self.move.move_models_bulk,
"auto_organize_models": self.auto_organize.auto_organize_models,
"get_auto_organize_progress": self.auto_organize.get_auto_organize_progress,
"get_model_notes": self.query.get_model_notes,
"get_model_preview_url": self.query.get_model_preview_url,
"get_model_civitai_url": self.query.get_model_civitai_url,
"get_model_metadata": self.query.get_model_metadata,
"get_model_description": self.query.get_model_description,
"get_relative_paths": self.query.get_relative_paths,
"refresh_model_updates": self.updates.refresh_model_updates,
"fetch_missing_civitai_license_data": self.updates.fetch_missing_civitai_license_data,
"set_model_update_ignore": self.updates.set_model_update_ignore,
"set_version_update_ignore": self.updates.set_version_update_ignore,
"get_model_update_status": self.updates.get_model_update_status,
"get_model_versions": self.updates.get_model_versions,
"cancel_task": self.query.cancel_task,
}