From 8cf762ffd3500e2439b037665cef4691be8f610f Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Thu, 9 Oct 2025 11:49:41 +0800 Subject: [PATCH] feat(misc): add civitai user model lookup --- py/routes/handlers/misc_handlers.py | 107 +++++++++++++- py/routes/misc_route_registrar.py | 1 + py/services/civitai_client.py | 36 ++++- py/services/model_metadata_provider.py | 42 +++++- py/utils/constants.py | 7 + tests/routes/test_misc_routes.py | 186 +++++++++++++++++++++++++ 6 files changed, 372 insertions(+), 7 deletions(-) diff --git a/py/routes/handlers/misc_handlers.py b/py/routes/handlers/misc_handlers.py index 0d3ccacc..4c41a06b 100644 --- a/py/routes/handlers/misc_handlers.py +++ b/py/routes/handlers/misc_handlers.py @@ -27,7 +27,13 @@ from ...services.service_registry import ServiceRegistry from ...services.settings_manager import get_settings_manager from ...services.websocket_manager import ws_manager from ...services.downloader import get_downloader -from ...utils.constants import DEFAULT_NODE_COLOR, NODE_TYPES, SUPPORTED_MEDIA_EXTENSIONS +from ...utils.constants import ( + CIVITAI_USER_MODEL_TYPES, + DEFAULT_NODE_COLOR, + NODE_TYPES, + SUPPORTED_MEDIA_EXTENSIONS, + VALID_LORA_TYPES, +) from ...utils.example_images_paths import is_valid_example_images_root from ...utils.lora_metadata import extract_trained_words from ...utils.usage_stats import UsageStats @@ -611,6 +617,104 @@ class ModelLibraryHandler: logger.error("Failed to get model versions status: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) + async def get_civitai_user_models(self, request: web.Request) -> web.Response: + try: + username = request.query.get("username") + if not username: + return web.json_response({"success": False, "error": "Missing required parameter: username"}, status=400) + + metadata_provider = await self._metadata_provider_factory() + if not metadata_provider: + return web.json_response({"success": False, "error": "Metadata provider not available"}, status=503) + + try: + models = await metadata_provider.get_user_models(username) + except NotImplementedError: + return web.json_response({"success": False, "error": "Metadata provider does not support user model queries"}, status=501) + + if models is None: + return web.json_response({"success": False, "error": "Failed to fetch user models"}, status=502) + + if not isinstance(models, list): + models = [] + + lora_scanner = await self._service_registry.get_lora_scanner() + checkpoint_scanner = await self._service_registry.get_checkpoint_scanner() + embedding_scanner = await self._service_registry.get_embedding_scanner() + + normalized_allowed_types = {model_type.lower() for model_type in CIVITAI_USER_MODEL_TYPES} + lora_type_aliases = {model_type.lower() for model_type in VALID_LORA_TYPES} + + type_scanner_map: Dict[str, object | None] = { + **{alias: lora_scanner for alias in lora_type_aliases}, + "checkpoint": checkpoint_scanner, + "textualinversion": embedding_scanner, + } + + versions: list[dict] = [] + for model in models: + if not isinstance(model, dict): + continue + + model_type = str(model.get("type", "")).lower() + if model_type not in normalized_allowed_types: + continue + + scanner = type_scanner_map.get(model_type) + if scanner is None: + return web.json_response({"success": False, "error": f'Scanner for type "{model_type}" is not available'}, status=503) + + tags_value = model.get("tags") + tags = tags_value if isinstance(tags_value, list) else [] + model_id = model.get("id") + try: + model_id_int = int(model_id) + except (TypeError, ValueError): + continue + model_name = model.get("name", "") + + versions_data = model.get("modelVersions") + if not isinstance(versions_data, list): + continue + + for version in versions_data: + if not isinstance(version, dict): + continue + + version_id = version.get("id") + try: + version_id_int = int(version_id) + except (TypeError, ValueError): + continue + + if await scanner.check_model_version_exists(version_id_int): + continue + + images = version.get("images") or [] + thumbnail_url = None + if images and isinstance(images, list): + first_image = images[0] + if isinstance(first_image, dict): + thumbnail_url = first_image.get("url") + + versions.append( + { + "modelId": model_id_int, + "versionId": version_id_int, + "modelName": model_name, + "versionName": version.get("name", ""), + "type": model.get("type"), + "tags": tags, + "baseModel": version.get("baseModel"), + "thumbnailUrl": thumbnail_url, + } + ) + + return web.json_response({"success": True, "username": username, "versions": versions}) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Failed to get Civitai user models: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + class MetadataArchiveHandler: def __init__( @@ -844,6 +948,7 @@ class MiscHandlerSet: "register_nodes": self.node_registry.register_nodes, "get_registry": self.node_registry.get_registry, "check_model_exists": self.model_library.check_model_exists, + "get_civitai_user_models": self.model_library.get_civitai_user_models, "download_metadata_archive": self.metadata_archive.download_metadata_archive, "remove_metadata_archive": self.metadata_archive.remove_metadata_archive, "get_metadata_archive_status": self.metadata_archive.get_metadata_archive_status, diff --git a/py/routes/misc_route_registrar.py b/py/routes/misc_route_registrar.py index 06780a48..7cf5c62b 100644 --- a/py/routes/misc_route_registrar.py +++ b/py/routes/misc_route_registrar.py @@ -34,6 +34,7 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"), RouteDefinition("GET", "/api/lm/get-registry", "get_registry"), RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"), + RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"), RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"), RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"), RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"), diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index 5598d7d7..e860994b 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -354,7 +354,7 @@ class CivitaiClient: async def get_image_info(self, image_id: str) -> Optional[Dict]: """Fetch image information from Civitai API - + Args: image_id: The Civitai image ID @@ -385,3 +385,37 @@ class CivitaiClient: error_msg = f"Error fetching image info: {e}" logger.error(error_msg) return None + + async def get_user_models(self, username: str) -> Optional[List[Dict]]: + """Fetch all models for a specific Civitai user.""" + if not username: + return None + + try: + downloader = await get_downloader() + url = f"{self.base_url}/models?username={username}" + success, result = await downloader.make_request( + 'GET', + url, + use_auth=True + ) + + if not success: + logger.error("Failed to fetch models for %s: %s", username, result) + return None + + items = result.get("items") if isinstance(result, dict) else None + if not isinstance(items, list): + return [] + + for model in items: + versions = model.get("modelVersions") + if not isinstance(versions, list): + continue + for version in versions: + self._remove_comfy_metadata(version) + + return items + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Error fetching models for %s: %s", username, exc) + return None diff --git a/py/services/model_metadata_provider.py b/py/services/model_metadata_provider.py index 3099b5fc..99b3488c 100644 --- a/py/services/model_metadata_provider.py +++ b/py/services/model_metadata_provider.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod import json import logging -from typing import Optional, Dict, Tuple, Any +from typing import Optional, Dict, Tuple, Any, List from .downloader import get_downloader try: @@ -61,6 +61,11 @@ class ModelMetadataProvider(ABC): """Fetch model version metadata""" pass + @abstractmethod + async def get_user_models(self, username: str) -> Optional[List[Dict]]: + """Fetch models owned by the specified user""" + pass + class CivitaiModelMetadataProvider(ModelMetadataProvider): """Provider that uses Civitai API for metadata""" @@ -79,6 +84,9 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider): async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: return await self.client.get_model_version_info(version_id) + async def get_user_models(self, username: str) -> Optional[List[Dict]]: + return await self.client.get_user_models(username) + class CivArchiveModelMetadataProvider(ModelMetadataProvider): """Provider that uses CivArchive HTML page parsing for metadata""" @@ -197,6 +205,10 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider): """Not supported by CivArchive provider - requires both model_id and version_id""" return None, "CivArchive provider requires both model_id and version_id" + async def get_user_models(self, username: str) -> Optional[List[Dict]]: + """Not supported by CivArchive provider""" + return None + class SQLiteModelMetadataProvider(ModelMetadataProvider): """Provider that uses SQLite database for metadata""" @@ -329,20 +341,24 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider): """Fetch model version metadata from SQLite database""" async with self._aiosqlite.connect(self.db_path) as db: db.row_factory = self._aiosqlite.Row - + # Get version details version_query = "SELECT model_id FROM model_versions WHERE id = ?" cursor = await db.execute(version_query, (version_id,)) version_row = await cursor.fetchone() - + if not version_row: return None, "Model version not found" - + model_id = version_row['model_id'] - + # Build complete version data with model info version_data = await self._get_version_with_model_data(db, model_id, version_id) return version_data, None + + async def get_user_models(self, username: str) -> Optional[List[Dict]]: + """Listing models by username is not supported for archive database""" + return None async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]: """Helper to build version data with model information""" @@ -481,6 +497,17 @@ class FallbackMetadataProvider(ModelMetadataProvider): continue return None, "No provider could retrieve the data" + async def get_user_models(self, username: str) -> Optional[List[Dict]]: + for provider in self.providers: + try: + result = await provider.get_user_models(username) + if result is not None: + return result + except Exception as e: + logger.debug(f"Provider failed for get_user_models: {e}") + continue + return None + class ModelMetadataProviderManager: """Manager for selecting and using model metadata providers""" @@ -522,6 +549,11 @@ class ModelMetadataProviderManager: """Fetch model version info using specified or default provider""" provider = self._get_provider(provider_name) return await provider.get_model_version_info(version_id) + + async def get_user_models(self, username: str, provider_name: str = None) -> Optional[List[Dict]]: + """Fetch models owned by the specified user""" + provider = self._get_provider(provider_name) + return await provider.get_user_models(username) def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider: """Get provider by name or default provider""" diff --git a/py/utils/constants.py b/py/utils/constants.py index 243badff..91ac3c54 100644 --- a/py/utils/constants.py +++ b/py/utils/constants.py @@ -48,6 +48,13 @@ SUPPORTED_MEDIA_EXTENSIONS = { # Valid Lora types VALID_LORA_TYPES = ['lora', 'locon', 'dora'] +# Supported Civitai model types for user model queries (case-insensitive) +CIVITAI_USER_MODEL_TYPES = [ + *VALID_LORA_TYPES, + 'textualinversion', + 'checkpoint', +] + # Auto-organize settings AUTO_ORGANIZE_BATCH_SIZE = 50 # Process models in batches to avoid overwhelming the system diff --git a/tests/routes/test_misc_routes.py b/tests/routes/test_misc_routes.py index 5ce9b822..0e597d35 100644 --- a/tests/routes/test_misc_routes.py +++ b/tests/routes/test_misc_routes.py @@ -6,6 +6,7 @@ from aiohttp import web from py.routes.handlers.misc_handlers import ( LoraCodeHandler, + ModelLibraryHandler, NodeRegistry, NodeRegistryHandler, ServiceRegistryAdapter, @@ -266,10 +267,34 @@ async def fake_scanner_factory(): return FakeScanner() +class FakeExistenceScanner: + def __init__(self, existing=None): + self._existing = set(existing or []) + + async def check_model_version_exists(self, version_id): + return version_id in self._existing + + async def get_model_versions_by_id(self, _model_id): + return [] + + class FakeMetadataProvider: async def get_model_versions(self, _model_id): return {"modelVersions": [], "name": "", "type": "lora"} + async def get_user_models(self, _username): + return [] + + +class FakeUserModelsProvider(FakeMetadataProvider): + def __init__(self, models): + self.models = models + self.received_usernames: list[str] = [] + + async def get_user_models(self, username): + self.received_usernames.append(username) + return self.models + async def fake_metadata_provider_factory(): return FakeMetadataProvider() @@ -339,6 +364,167 @@ async def test_misc_routes_bind_produces_expected_handlers(): assert set(mapping.keys()) == expected_names +@pytest.mark.asyncio +async def test_get_civitai_user_models_filters_versions(): + models = [ + { + "id": 1, + "name": "Model A", + "type": "LORA", + "tags": ["style"], + "modelVersions": [ + { + "id": 100, + "name": "v1", + "baseModel": "Flux.1", + "images": [{"url": "http://example.com/a1.jpg"}], + }, + { + "id": 101, + "name": "v2", + "baseModel": "Flux.1", + "images": [{"url": "http://example.com/a2.jpg"}], + }, + ], + }, + { + "id": 2, + "name": "Embedding", + "type": "TextualInversion", + "tags": ["embedding"], + "modelVersions": [ + { + "id": 200, + "name": "v1", + "baseModel": None, + "images": [{"url": "http://example.com/e1.jpg"}], + }, + { + "id": 202, + "name": "v2", + "baseModel": None, + }, + ], + }, + { + "id": 3, + "name": "Checkpoint", + "type": "Checkpoint", + "tags": ["checkpoint"], + "modelVersions": [ + { + "id": 300, + "name": "v1", + "baseModel": "SDXL", + "images": [], + } + ], + }, + { + "id": 4, + "name": "Unsupported", + "type": "Other", + "modelVersions": [ + { + "id": 400, + "name": "v1", + } + ], + }, + ] + + provider = FakeUserModelsProvider(models) + + async def provider_factory(): + return provider + + lora_scanner = FakeExistenceScanner({101}) + checkpoint_scanner = FakeExistenceScanner() + embedding_scanner = FakeExistenceScanner({202}) + + async def lora_factory(): + return lora_scanner + + async def checkpoint_factory(): + return checkpoint_scanner + + async def embedding_factory(): + return embedding_scanner + + handler = ModelLibraryHandler( + ServiceRegistryAdapter( + get_lora_scanner=lora_factory, + get_checkpoint_scanner=checkpoint_factory, + get_embedding_scanner=embedding_factory, + ), + metadata_provider_factory=provider_factory, + ) + + response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"})) + payload = json.loads(response.text) + + assert payload["success"] is True + assert payload["username"] == "pixel" + assert payload["versions"] == [ + { + "modelId": 1, + "versionId": 100, + "modelName": "Model A", + "versionName": "v1", + "type": "LORA", + "tags": ["style"], + "baseModel": "Flux.1", + "thumbnailUrl": "http://example.com/a1.jpg", + }, + { + "modelId": 2, + "versionId": 200, + "modelName": "Embedding", + "versionName": "v1", + "type": "TextualInversion", + "tags": ["embedding"], + "baseModel": None, + "thumbnailUrl": "http://example.com/e1.jpg", + }, + { + "modelId": 3, + "versionId": 300, + "modelName": "Checkpoint", + "versionName": "v1", + "type": "Checkpoint", + "tags": ["checkpoint"], + "baseModel": "SDXL", + "thumbnailUrl": None, + }, + ] + + assert provider.received_usernames == ["pixel"] + + +@pytest.mark.asyncio +async def test_get_civitai_user_models_requires_username(): + provider = FakeUserModelsProvider([]) + + async def provider_factory(): + return provider + + handler = ModelLibraryHandler( + ServiceRegistryAdapter( + get_lora_scanner=fake_scanner_factory, + get_checkpoint_scanner=fake_scanner_factory, + get_embedding_scanner=fake_scanner_factory, + ), + metadata_provider_factory=provider_factory, + ) + + response = await handler.get_civitai_user_models(FakeRequest()) + payload = json.loads(response.text) + + assert response.status == 400 + assert payload["success"] is False + assert "username" in payload["error"].lower() + + def test_ensure_handler_mapping_caches_result(): call_records = []