Merge pull request #544 from willmiao/codex/add-endpoint-to-fetch-civitai-user-models

Add endpoint to fetch Civitai user models
This commit is contained in:
pixelpaws
2025-10-09 11:56:57 +08:00
committed by GitHub
6 changed files with 372 additions and 7 deletions

View File

@@ -27,7 +27,13 @@ from ...services.service_registry import ServiceRegistry
from ...services.settings_manager import get_settings_manager from ...services.settings_manager import get_settings_manager
from ...services.websocket_manager import ws_manager from ...services.websocket_manager import ws_manager
from ...services.downloader import get_downloader from ...services.downloader import get_downloader
from ...utils.constants import DEFAULT_NODE_COLOR, NODE_TYPES, SUPPORTED_MEDIA_EXTENSIONS from ...utils.constants import (
CIVITAI_USER_MODEL_TYPES,
DEFAULT_NODE_COLOR,
NODE_TYPES,
SUPPORTED_MEDIA_EXTENSIONS,
VALID_LORA_TYPES,
)
from ...utils.example_images_paths import is_valid_example_images_root from ...utils.example_images_paths import is_valid_example_images_root
from ...utils.lora_metadata import extract_trained_words from ...utils.lora_metadata import extract_trained_words
from ...utils.usage_stats import UsageStats from ...utils.usage_stats import UsageStats
@@ -611,6 +617,104 @@ class ModelLibraryHandler:
logger.error("Failed to get model versions status: %s", exc, exc_info=True) logger.error("Failed to get model versions status: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500) return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_civitai_user_models(self, request: web.Request) -> web.Response:
try:
username = request.query.get("username")
if not username:
return web.json_response({"success": False, "error": "Missing required parameter: username"}, status=400)
metadata_provider = await self._metadata_provider_factory()
if not metadata_provider:
return web.json_response({"success": False, "error": "Metadata provider not available"}, status=503)
try:
models = await metadata_provider.get_user_models(username)
except NotImplementedError:
return web.json_response({"success": False, "error": "Metadata provider does not support user model queries"}, status=501)
if models is None:
return web.json_response({"success": False, "error": "Failed to fetch user models"}, status=502)
if not isinstance(models, list):
models = []
lora_scanner = await self._service_registry.get_lora_scanner()
checkpoint_scanner = await self._service_registry.get_checkpoint_scanner()
embedding_scanner = await self._service_registry.get_embedding_scanner()
normalized_allowed_types = {model_type.lower() for model_type in CIVITAI_USER_MODEL_TYPES}
lora_type_aliases = {model_type.lower() for model_type in VALID_LORA_TYPES}
type_scanner_map: Dict[str, object | None] = {
**{alias: lora_scanner for alias in lora_type_aliases},
"checkpoint": checkpoint_scanner,
"textualinversion": embedding_scanner,
}
versions: list[dict] = []
for model in models:
if not isinstance(model, dict):
continue
model_type = str(model.get("type", "")).lower()
if model_type not in normalized_allowed_types:
continue
scanner = type_scanner_map.get(model_type)
if scanner is None:
return web.json_response({"success": False, "error": f'Scanner for type "{model_type}" is not available'}, status=503)
tags_value = model.get("tags")
tags = tags_value if isinstance(tags_value, list) else []
model_id = model.get("id")
try:
model_id_int = int(model_id)
except (TypeError, ValueError):
continue
model_name = model.get("name", "")
versions_data = model.get("modelVersions")
if not isinstance(versions_data, list):
continue
for version in versions_data:
if not isinstance(version, dict):
continue
version_id = version.get("id")
try:
version_id_int = int(version_id)
except (TypeError, ValueError):
continue
if await scanner.check_model_version_exists(version_id_int):
continue
images = version.get("images") or []
thumbnail_url = None
if images and isinstance(images, list):
first_image = images[0]
if isinstance(first_image, dict):
thumbnail_url = first_image.get("url")
versions.append(
{
"modelId": model_id_int,
"versionId": version_id_int,
"modelName": model_name,
"versionName": version.get("name", ""),
"type": model.get("type"),
"tags": tags,
"baseModel": version.get("baseModel"),
"thumbnailUrl": thumbnail_url,
}
)
return web.json_response({"success": True, "username": username, "versions": versions})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Failed to get Civitai user models: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class MetadataArchiveHandler: class MetadataArchiveHandler:
def __init__( def __init__(
@@ -844,6 +948,7 @@ class MiscHandlerSet:
"register_nodes": self.node_registry.register_nodes, "register_nodes": self.node_registry.register_nodes,
"get_registry": self.node_registry.get_registry, "get_registry": self.node_registry.get_registry,
"check_model_exists": self.model_library.check_model_exists, "check_model_exists": self.model_library.check_model_exists,
"get_civitai_user_models": self.model_library.get_civitai_user_models,
"download_metadata_archive": self.metadata_archive.download_metadata_archive, "download_metadata_archive": self.metadata_archive.download_metadata_archive,
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive, "remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
"get_metadata_archive_status": self.metadata_archive.get_metadata_archive_status, "get_metadata_archive_status": self.metadata_archive.get_metadata_archive_status,

View File

@@ -34,6 +34,7 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"), RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"),
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"), RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"), RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),
RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"), RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"),
RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"), RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"),
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"), RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),

View File

@@ -354,7 +354,7 @@ class CivitaiClient:
async def get_image_info(self, image_id: str) -> Optional[Dict]: async def get_image_info(self, image_id: str) -> Optional[Dict]:
"""Fetch image information from Civitai API """Fetch image information from Civitai API
Args: Args:
image_id: The Civitai image ID image_id: The Civitai image ID
@@ -385,3 +385,37 @@ class CivitaiClient:
error_msg = f"Error fetching image info: {e}" error_msg = f"Error fetching image info: {e}"
logger.error(error_msg) logger.error(error_msg)
return None return None
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
"""Fetch all models for a specific Civitai user."""
if not username:
return None
try:
downloader = await get_downloader()
url = f"{self.base_url}/models?username={username}"
success, result = await downloader.make_request(
'GET',
url,
use_auth=True
)
if not success:
logger.error("Failed to fetch models for %s: %s", username, result)
return None
items = result.get("items") if isinstance(result, dict) else None
if not isinstance(items, list):
return []
for model in items:
versions = model.get("modelVersions")
if not isinstance(versions, list):
continue
for version in versions:
self._remove_comfy_metadata(version)
return items
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Error fetching models for %s: %s", username, exc)
return None

View File

@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import json import json
import logging import logging
from typing import Optional, Dict, Tuple, Any from typing import Optional, Dict, Tuple, Any, List
from .downloader import get_downloader from .downloader import get_downloader
try: try:
@@ -61,6 +61,11 @@ class ModelMetadataProvider(ABC):
"""Fetch model version metadata""" """Fetch model version metadata"""
pass pass
@abstractmethod
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
"""Fetch models owned by the specified user"""
pass
class CivitaiModelMetadataProvider(ModelMetadataProvider): class CivitaiModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses Civitai API for metadata""" """Provider that uses Civitai API for metadata"""
@@ -79,6 +84,9 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
return await self.client.get_model_version_info(version_id) return await self.client.get_model_version_info(version_id)
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
return await self.client.get_user_models(username)
class CivArchiveModelMetadataProvider(ModelMetadataProvider): class CivArchiveModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses CivArchive HTML page parsing for metadata""" """Provider that uses CivArchive HTML page parsing for metadata"""
@@ -197,6 +205,10 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
"""Not supported by CivArchive provider - requires both model_id and version_id""" """Not supported by CivArchive provider - requires both model_id and version_id"""
return None, "CivArchive provider requires both model_id and version_id" return None, "CivArchive provider requires both model_id and version_id"
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
"""Not supported by CivArchive provider"""
return None
class SQLiteModelMetadataProvider(ModelMetadataProvider): class SQLiteModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses SQLite database for metadata""" """Provider that uses SQLite database for metadata"""
@@ -329,20 +341,24 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
"""Fetch model version metadata from SQLite database""" """Fetch model version metadata from SQLite database"""
async with self._aiosqlite.connect(self.db_path) as db: async with self._aiosqlite.connect(self.db_path) as db:
db.row_factory = self._aiosqlite.Row db.row_factory = self._aiosqlite.Row
# Get version details # Get version details
version_query = "SELECT model_id FROM model_versions WHERE id = ?" version_query = "SELECT model_id FROM model_versions WHERE id = ?"
cursor = await db.execute(version_query, (version_id,)) cursor = await db.execute(version_query, (version_id,))
version_row = await cursor.fetchone() version_row = await cursor.fetchone()
if not version_row: if not version_row:
return None, "Model version not found" return None, "Model version not found"
model_id = version_row['model_id'] model_id = version_row['model_id']
# Build complete version data with model info # Build complete version data with model info
version_data = await self._get_version_with_model_data(db, model_id, version_id) version_data = await self._get_version_with_model_data(db, model_id, version_id)
return version_data, None return version_data, None
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
"""Listing models by username is not supported for archive database"""
return None
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]: async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
"""Helper to build version data with model information""" """Helper to build version data with model information"""
@@ -481,6 +497,17 @@ class FallbackMetadataProvider(ModelMetadataProvider):
continue continue
return None, "No provider could retrieve the data" return None, "No provider could retrieve the data"
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
for provider in self.providers:
try:
result = await provider.get_user_models(username)
if result is not None:
return result
except Exception as e:
logger.debug(f"Provider failed for get_user_models: {e}")
continue
return None
class ModelMetadataProviderManager: class ModelMetadataProviderManager:
"""Manager for selecting and using model metadata providers""" """Manager for selecting and using model metadata providers"""
@@ -522,6 +549,11 @@ class ModelMetadataProviderManager:
"""Fetch model version info using specified or default provider""" """Fetch model version info using specified or default provider"""
provider = self._get_provider(provider_name) provider = self._get_provider(provider_name)
return await provider.get_model_version_info(version_id) return await provider.get_model_version_info(version_id)
async def get_user_models(self, username: str, provider_name: str = None) -> Optional[List[Dict]]:
"""Fetch models owned by the specified user"""
provider = self._get_provider(provider_name)
return await provider.get_user_models(username)
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider: def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
"""Get provider by name or default provider""" """Get provider by name or default provider"""

View File

@@ -48,6 +48,13 @@ SUPPORTED_MEDIA_EXTENSIONS = {
# Valid Lora types # Valid Lora types
VALID_LORA_TYPES = ['lora', 'locon', 'dora'] VALID_LORA_TYPES = ['lora', 'locon', 'dora']
# Supported Civitai model types for user model queries (case-insensitive)
CIVITAI_USER_MODEL_TYPES = [
*VALID_LORA_TYPES,
'textualinversion',
'checkpoint',
]
# Auto-organize settings # Auto-organize settings
AUTO_ORGANIZE_BATCH_SIZE = 50 # Process models in batches to avoid overwhelming the system AUTO_ORGANIZE_BATCH_SIZE = 50 # Process models in batches to avoid overwhelming the system

View File

@@ -6,6 +6,7 @@ from aiohttp import web
from py.routes.handlers.misc_handlers import ( from py.routes.handlers.misc_handlers import (
LoraCodeHandler, LoraCodeHandler,
ModelLibraryHandler,
NodeRegistry, NodeRegistry,
NodeRegistryHandler, NodeRegistryHandler,
ServiceRegistryAdapter, ServiceRegistryAdapter,
@@ -266,10 +267,34 @@ async def fake_scanner_factory():
return FakeScanner() return FakeScanner()
class FakeExistenceScanner:
def __init__(self, existing=None):
self._existing = set(existing or [])
async def check_model_version_exists(self, version_id):
return version_id in self._existing
async def get_model_versions_by_id(self, _model_id):
return []
class FakeMetadataProvider: class FakeMetadataProvider:
async def get_model_versions(self, _model_id): async def get_model_versions(self, _model_id):
return {"modelVersions": [], "name": "", "type": "lora"} return {"modelVersions": [], "name": "", "type": "lora"}
async def get_user_models(self, _username):
return []
class FakeUserModelsProvider(FakeMetadataProvider):
def __init__(self, models):
self.models = models
self.received_usernames: list[str] = []
async def get_user_models(self, username):
self.received_usernames.append(username)
return self.models
async def fake_metadata_provider_factory(): async def fake_metadata_provider_factory():
return FakeMetadataProvider() return FakeMetadataProvider()
@@ -339,6 +364,167 @@ async def test_misc_routes_bind_produces_expected_handlers():
assert set(mapping.keys()) == expected_names assert set(mapping.keys()) == expected_names
@pytest.mark.asyncio
async def test_get_civitai_user_models_filters_versions():
models = [
{
"id": 1,
"name": "Model A",
"type": "LORA",
"tags": ["style"],
"modelVersions": [
{
"id": 100,
"name": "v1",
"baseModel": "Flux.1",
"images": [{"url": "http://example.com/a1.jpg"}],
},
{
"id": 101,
"name": "v2",
"baseModel": "Flux.1",
"images": [{"url": "http://example.com/a2.jpg"}],
},
],
},
{
"id": 2,
"name": "Embedding",
"type": "TextualInversion",
"tags": ["embedding"],
"modelVersions": [
{
"id": 200,
"name": "v1",
"baseModel": None,
"images": [{"url": "http://example.com/e1.jpg"}],
},
{
"id": 202,
"name": "v2",
"baseModel": None,
},
],
},
{
"id": 3,
"name": "Checkpoint",
"type": "Checkpoint",
"tags": ["checkpoint"],
"modelVersions": [
{
"id": 300,
"name": "v1",
"baseModel": "SDXL",
"images": [],
}
],
},
{
"id": 4,
"name": "Unsupported",
"type": "Other",
"modelVersions": [
{
"id": 400,
"name": "v1",
}
],
},
]
provider = FakeUserModelsProvider(models)
async def provider_factory():
return provider
lora_scanner = FakeExistenceScanner({101})
checkpoint_scanner = FakeExistenceScanner()
embedding_scanner = FakeExistenceScanner({202})
async def lora_factory():
return lora_scanner
async def checkpoint_factory():
return checkpoint_scanner
async def embedding_factory():
return embedding_scanner
handler = ModelLibraryHandler(
ServiceRegistryAdapter(
get_lora_scanner=lora_factory,
get_checkpoint_scanner=checkpoint_factory,
get_embedding_scanner=embedding_factory,
),
metadata_provider_factory=provider_factory,
)
response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"}))
payload = json.loads(response.text)
assert payload["success"] is True
assert payload["username"] == "pixel"
assert payload["versions"] == [
{
"modelId": 1,
"versionId": 100,
"modelName": "Model A",
"versionName": "v1",
"type": "LORA",
"tags": ["style"],
"baseModel": "Flux.1",
"thumbnailUrl": "http://example.com/a1.jpg",
},
{
"modelId": 2,
"versionId": 200,
"modelName": "Embedding",
"versionName": "v1",
"type": "TextualInversion",
"tags": ["embedding"],
"baseModel": None,
"thumbnailUrl": "http://example.com/e1.jpg",
},
{
"modelId": 3,
"versionId": 300,
"modelName": "Checkpoint",
"versionName": "v1",
"type": "Checkpoint",
"tags": ["checkpoint"],
"baseModel": "SDXL",
"thumbnailUrl": None,
},
]
assert provider.received_usernames == ["pixel"]
@pytest.mark.asyncio
async def test_get_civitai_user_models_requires_username():
provider = FakeUserModelsProvider([])
async def provider_factory():
return provider
handler = ModelLibraryHandler(
ServiceRegistryAdapter(
get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory,
),
metadata_provider_factory=provider_factory,
)
response = await handler.get_civitai_user_models(FakeRequest())
payload = json.loads(response.text)
assert response.status == 400
assert payload["success"] is False
assert "username" in payload["error"].lower()
def test_ensure_handler_mapping_caches_result(): def test_ensure_handler_mapping_caches_result():
call_records = [] call_records = []