mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(misc): add civitai user model lookup
This commit is contained in:
@@ -27,7 +27,13 @@ from ...services.service_registry import ServiceRegistry
|
||||
from ...services.settings_manager import get_settings_manager
|
||||
from ...services.websocket_manager import ws_manager
|
||||
from ...services.downloader import get_downloader
|
||||
from ...utils.constants import DEFAULT_NODE_COLOR, NODE_TYPES, SUPPORTED_MEDIA_EXTENSIONS
|
||||
from ...utils.constants import (
|
||||
CIVITAI_USER_MODEL_TYPES,
|
||||
DEFAULT_NODE_COLOR,
|
||||
NODE_TYPES,
|
||||
SUPPORTED_MEDIA_EXTENSIONS,
|
||||
VALID_LORA_TYPES,
|
||||
)
|
||||
from ...utils.example_images_paths import is_valid_example_images_root
|
||||
from ...utils.lora_metadata import extract_trained_words
|
||||
from ...utils.usage_stats import UsageStats
|
||||
@@ -611,6 +617,104 @@ class ModelLibraryHandler:
|
||||
logger.error("Failed to get model versions status: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_civitai_user_models(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
username = request.query.get("username")
|
||||
if not username:
|
||||
return web.json_response({"success": False, "error": "Missing required parameter: username"}, status=400)
|
||||
|
||||
metadata_provider = await self._metadata_provider_factory()
|
||||
if not metadata_provider:
|
||||
return web.json_response({"success": False, "error": "Metadata provider not available"}, status=503)
|
||||
|
||||
try:
|
||||
models = await metadata_provider.get_user_models(username)
|
||||
except NotImplementedError:
|
||||
return web.json_response({"success": False, "error": "Metadata provider does not support user model queries"}, status=501)
|
||||
|
||||
if models is None:
|
||||
return web.json_response({"success": False, "error": "Failed to fetch user models"}, status=502)
|
||||
|
||||
if not isinstance(models, list):
|
||||
models = []
|
||||
|
||||
lora_scanner = await self._service_registry.get_lora_scanner()
|
||||
checkpoint_scanner = await self._service_registry.get_checkpoint_scanner()
|
||||
embedding_scanner = await self._service_registry.get_embedding_scanner()
|
||||
|
||||
normalized_allowed_types = {model_type.lower() for model_type in CIVITAI_USER_MODEL_TYPES}
|
||||
lora_type_aliases = {model_type.lower() for model_type in VALID_LORA_TYPES}
|
||||
|
||||
type_scanner_map: Dict[str, object | None] = {
|
||||
**{alias: lora_scanner for alias in lora_type_aliases},
|
||||
"checkpoint": checkpoint_scanner,
|
||||
"textualinversion": embedding_scanner,
|
||||
}
|
||||
|
||||
versions: list[dict] = []
|
||||
for model in models:
|
||||
if not isinstance(model, dict):
|
||||
continue
|
||||
|
||||
model_type = str(model.get("type", "")).lower()
|
||||
if model_type not in normalized_allowed_types:
|
||||
continue
|
||||
|
||||
scanner = type_scanner_map.get(model_type)
|
||||
if scanner is None:
|
||||
return web.json_response({"success": False, "error": f'Scanner for type "{model_type}" is not available'}, status=503)
|
||||
|
||||
tags_value = model.get("tags")
|
||||
tags = tags_value if isinstance(tags_value, list) else []
|
||||
model_id = model.get("id")
|
||||
try:
|
||||
model_id_int = int(model_id)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
model_name = model.get("name", "")
|
||||
|
||||
versions_data = model.get("modelVersions")
|
||||
if not isinstance(versions_data, list):
|
||||
continue
|
||||
|
||||
for version in versions_data:
|
||||
if not isinstance(version, dict):
|
||||
continue
|
||||
|
||||
version_id = version.get("id")
|
||||
try:
|
||||
version_id_int = int(version_id)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
|
||||
if await scanner.check_model_version_exists(version_id_int):
|
||||
continue
|
||||
|
||||
images = version.get("images") or []
|
||||
thumbnail_url = None
|
||||
if images and isinstance(images, list):
|
||||
first_image = images[0]
|
||||
if isinstance(first_image, dict):
|
||||
thumbnail_url = first_image.get("url")
|
||||
|
||||
versions.append(
|
||||
{
|
||||
"modelId": model_id_int,
|
||||
"versionId": version_id_int,
|
||||
"modelName": model_name,
|
||||
"versionName": version.get("name", ""),
|
||||
"type": model.get("type"),
|
||||
"tags": tags,
|
||||
"baseModel": version.get("baseModel"),
|
||||
"thumbnailUrl": thumbnail_url,
|
||||
}
|
||||
)
|
||||
|
||||
return web.json_response({"success": True, "username": username, "versions": versions})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Failed to get Civitai user models: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class MetadataArchiveHandler:
|
||||
def __init__(
|
||||
@@ -844,6 +948,7 @@ class MiscHandlerSet:
|
||||
"register_nodes": self.node_registry.register_nodes,
|
||||
"get_registry": self.node_registry.get_registry,
|
||||
"check_model_exists": self.model_library.check_model_exists,
|
||||
"get_civitai_user_models": self.model_library.get_civitai_user_models,
|
||||
"download_metadata_archive": self.metadata_archive.download_metadata_archive,
|
||||
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
|
||||
"get_metadata_archive_status": self.metadata_archive.get_metadata_archive_status,
|
||||
|
||||
@@ -34,6 +34,7 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"),
|
||||
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
|
||||
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
|
||||
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),
|
||||
RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"),
|
||||
RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"),
|
||||
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),
|
||||
|
||||
@@ -354,7 +354,7 @@ class CivitaiClient:
|
||||
|
||||
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
||||
"""Fetch image information from Civitai API
|
||||
|
||||
|
||||
Args:
|
||||
image_id: The Civitai image ID
|
||||
|
||||
@@ -385,3 +385,37 @@ class CivitaiClient:
|
||||
error_msg = f"Error fetching image info: {e}"
|
||||
logger.error(error_msg)
|
||||
return None
|
||||
|
||||
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||
"""Fetch all models for a specific Civitai user."""
|
||||
if not username:
|
||||
return None
|
||||
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
url = f"{self.base_url}/models?username={username}"
|
||||
success, result = await downloader.make_request(
|
||||
'GET',
|
||||
url,
|
||||
use_auth=True
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error("Failed to fetch models for %s: %s", username, result)
|
||||
return None
|
||||
|
||||
items = result.get("items") if isinstance(result, dict) else None
|
||||
if not isinstance(items, list):
|
||||
return []
|
||||
|
||||
for model in items:
|
||||
versions = model.get("modelVersions")
|
||||
if not isinstance(versions, list):
|
||||
continue
|
||||
for version in versions:
|
||||
self._remove_comfy_metadata(version)
|
||||
|
||||
return items
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Error fetching models for %s: %s", username, exc)
|
||||
return None
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Tuple, Any
|
||||
from typing import Optional, Dict, Tuple, Any, List
|
||||
from .downloader import get_downloader
|
||||
|
||||
try:
|
||||
@@ -61,6 +61,11 @@ class ModelMetadataProvider(ABC):
|
||||
"""Fetch model version metadata"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||
"""Fetch models owned by the specified user"""
|
||||
pass
|
||||
|
||||
class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses Civitai API for metadata"""
|
||||
|
||||
@@ -79,6 +84,9 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
return await self.client.get_model_version_info(version_id)
|
||||
|
||||
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||
return await self.client.get_user_models(username)
|
||||
|
||||
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses CivArchive HTML page parsing for metadata"""
|
||||
|
||||
@@ -197,6 +205,10 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Not supported by CivArchive provider - requires both model_id and version_id"""
|
||||
return None, "CivArchive provider requires both model_id and version_id"
|
||||
|
||||
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||
"""Not supported by CivArchive provider"""
|
||||
return None
|
||||
|
||||
class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses SQLite database for metadata"""
|
||||
|
||||
@@ -329,20 +341,24 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Fetch model version metadata from SQLite database"""
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
|
||||
|
||||
# Get version details
|
||||
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
||||
cursor = await db.execute(version_query, (version_id,))
|
||||
version_row = await cursor.fetchone()
|
||||
|
||||
|
||||
if not version_row:
|
||||
return None, "Model version not found"
|
||||
|
||||
|
||||
model_id = version_row['model_id']
|
||||
|
||||
|
||||
# Build complete version data with model info
|
||||
version_data = await self._get_version_with_model_data(db, model_id, version_id)
|
||||
return version_data, None
|
||||
|
||||
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||
"""Listing models by username is not supported for archive database"""
|
||||
return None
|
||||
|
||||
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
|
||||
"""Helper to build version data with model information"""
|
||||
@@ -481,6 +497,17 @@ class FallbackMetadataProvider(ModelMetadataProvider):
|
||||
continue
|
||||
return None, "No provider could retrieve the data"
|
||||
|
||||
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||
for provider in self.providers:
|
||||
try:
|
||||
result = await provider.get_user_models(username)
|
||||
if result is not None:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.debug(f"Provider failed for get_user_models: {e}")
|
||||
continue
|
||||
return None
|
||||
|
||||
class ModelMetadataProviderManager:
|
||||
"""Manager for selecting and using model metadata providers"""
|
||||
|
||||
@@ -522,6 +549,11 @@ class ModelMetadataProviderManager:
|
||||
"""Fetch model version info using specified or default provider"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_version_info(version_id)
|
||||
|
||||
async def get_user_models(self, username: str, provider_name: str = None) -> Optional[List[Dict]]:
|
||||
"""Fetch models owned by the specified user"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_user_models(username)
|
||||
|
||||
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
|
||||
"""Get provider by name or default provider"""
|
||||
|
||||
@@ -48,6 +48,13 @@ SUPPORTED_MEDIA_EXTENSIONS = {
|
||||
# Valid Lora types
|
||||
VALID_LORA_TYPES = ['lora', 'locon', 'dora']
|
||||
|
||||
# Supported Civitai model types for user model queries (case-insensitive)
|
||||
CIVITAI_USER_MODEL_TYPES = [
|
||||
*VALID_LORA_TYPES,
|
||||
'textualinversion',
|
||||
'checkpoint',
|
||||
]
|
||||
|
||||
# Auto-organize settings
|
||||
AUTO_ORGANIZE_BATCH_SIZE = 50 # Process models in batches to avoid overwhelming the system
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from aiohttp import web
|
||||
|
||||
from py.routes.handlers.misc_handlers import (
|
||||
LoraCodeHandler,
|
||||
ModelLibraryHandler,
|
||||
NodeRegistry,
|
||||
NodeRegistryHandler,
|
||||
ServiceRegistryAdapter,
|
||||
@@ -266,10 +267,34 @@ async def fake_scanner_factory():
|
||||
return FakeScanner()
|
||||
|
||||
|
||||
class FakeExistenceScanner:
|
||||
def __init__(self, existing=None):
|
||||
self._existing = set(existing or [])
|
||||
|
||||
async def check_model_version_exists(self, version_id):
|
||||
return version_id in self._existing
|
||||
|
||||
async def get_model_versions_by_id(self, _model_id):
|
||||
return []
|
||||
|
||||
|
||||
class FakeMetadataProvider:
|
||||
async def get_model_versions(self, _model_id):
|
||||
return {"modelVersions": [], "name": "", "type": "lora"}
|
||||
|
||||
async def get_user_models(self, _username):
|
||||
return []
|
||||
|
||||
|
||||
class FakeUserModelsProvider(FakeMetadataProvider):
|
||||
def __init__(self, models):
|
||||
self.models = models
|
||||
self.received_usernames: list[str] = []
|
||||
|
||||
async def get_user_models(self, username):
|
||||
self.received_usernames.append(username)
|
||||
return self.models
|
||||
|
||||
|
||||
async def fake_metadata_provider_factory():
|
||||
return FakeMetadataProvider()
|
||||
@@ -339,6 +364,167 @@ async def test_misc_routes_bind_produces_expected_handlers():
|
||||
assert set(mapping.keys()) == expected_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_civitai_user_models_filters_versions():
|
||||
models = [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "Model A",
|
||||
"type": "LORA",
|
||||
"tags": ["style"],
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 100,
|
||||
"name": "v1",
|
||||
"baseModel": "Flux.1",
|
||||
"images": [{"url": "http://example.com/a1.jpg"}],
|
||||
},
|
||||
{
|
||||
"id": 101,
|
||||
"name": "v2",
|
||||
"baseModel": "Flux.1",
|
||||
"images": [{"url": "http://example.com/a2.jpg"}],
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "Embedding",
|
||||
"type": "TextualInversion",
|
||||
"tags": ["embedding"],
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 200,
|
||||
"name": "v1",
|
||||
"baseModel": None,
|
||||
"images": [{"url": "http://example.com/e1.jpg"}],
|
||||
},
|
||||
{
|
||||
"id": 202,
|
||||
"name": "v2",
|
||||
"baseModel": None,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"name": "Checkpoint",
|
||||
"type": "Checkpoint",
|
||||
"tags": ["checkpoint"],
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 300,
|
||||
"name": "v1",
|
||||
"baseModel": "SDXL",
|
||||
"images": [],
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": 4,
|
||||
"name": "Unsupported",
|
||||
"type": "Other",
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 400,
|
||||
"name": "v1",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
provider = FakeUserModelsProvider(models)
|
||||
|
||||
async def provider_factory():
|
||||
return provider
|
||||
|
||||
lora_scanner = FakeExistenceScanner({101})
|
||||
checkpoint_scanner = FakeExistenceScanner()
|
||||
embedding_scanner = FakeExistenceScanner({202})
|
||||
|
||||
async def lora_factory():
|
||||
return lora_scanner
|
||||
|
||||
async def checkpoint_factory():
|
||||
return checkpoint_scanner
|
||||
|
||||
async def embedding_factory():
|
||||
return embedding_scanner
|
||||
|
||||
handler = ModelLibraryHandler(
|
||||
ServiceRegistryAdapter(
|
||||
get_lora_scanner=lora_factory,
|
||||
get_checkpoint_scanner=checkpoint_factory,
|
||||
get_embedding_scanner=embedding_factory,
|
||||
),
|
||||
metadata_provider_factory=provider_factory,
|
||||
)
|
||||
|
||||
response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"}))
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is True
|
||||
assert payload["username"] == "pixel"
|
||||
assert payload["versions"] == [
|
||||
{
|
||||
"modelId": 1,
|
||||
"versionId": 100,
|
||||
"modelName": "Model A",
|
||||
"versionName": "v1",
|
||||
"type": "LORA",
|
||||
"tags": ["style"],
|
||||
"baseModel": "Flux.1",
|
||||
"thumbnailUrl": "http://example.com/a1.jpg",
|
||||
},
|
||||
{
|
||||
"modelId": 2,
|
||||
"versionId": 200,
|
||||
"modelName": "Embedding",
|
||||
"versionName": "v1",
|
||||
"type": "TextualInversion",
|
||||
"tags": ["embedding"],
|
||||
"baseModel": None,
|
||||
"thumbnailUrl": "http://example.com/e1.jpg",
|
||||
},
|
||||
{
|
||||
"modelId": 3,
|
||||
"versionId": 300,
|
||||
"modelName": "Checkpoint",
|
||||
"versionName": "v1",
|
||||
"type": "Checkpoint",
|
||||
"tags": ["checkpoint"],
|
||||
"baseModel": "SDXL",
|
||||
"thumbnailUrl": None,
|
||||
},
|
||||
]
|
||||
|
||||
assert provider.received_usernames == ["pixel"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_civitai_user_models_requires_username():
|
||||
provider = FakeUserModelsProvider([])
|
||||
|
||||
async def provider_factory():
|
||||
return provider
|
||||
|
||||
handler = ModelLibraryHandler(
|
||||
ServiceRegistryAdapter(
|
||||
get_lora_scanner=fake_scanner_factory,
|
||||
get_checkpoint_scanner=fake_scanner_factory,
|
||||
get_embedding_scanner=fake_scanner_factory,
|
||||
),
|
||||
metadata_provider_factory=provider_factory,
|
||||
)
|
||||
|
||||
response = await handler.get_civitai_user_models(FakeRequest())
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert response.status == 400
|
||||
assert payload["success"] is False
|
||||
assert "username" in payload["error"].lower()
|
||||
|
||||
|
||||
def test_ensure_handler_mapping_caches_result():
|
||||
call_records = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user