mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
feat(misc): add civitai user model lookup
This commit is contained in:
@@ -27,7 +27,13 @@ from ...services.service_registry import ServiceRegistry
|
|||||||
from ...services.settings_manager import get_settings_manager
|
from ...services.settings_manager import get_settings_manager
|
||||||
from ...services.websocket_manager import ws_manager
|
from ...services.websocket_manager import ws_manager
|
||||||
from ...services.downloader import get_downloader
|
from ...services.downloader import get_downloader
|
||||||
from ...utils.constants import DEFAULT_NODE_COLOR, NODE_TYPES, SUPPORTED_MEDIA_EXTENSIONS
|
from ...utils.constants import (
|
||||||
|
CIVITAI_USER_MODEL_TYPES,
|
||||||
|
DEFAULT_NODE_COLOR,
|
||||||
|
NODE_TYPES,
|
||||||
|
SUPPORTED_MEDIA_EXTENSIONS,
|
||||||
|
VALID_LORA_TYPES,
|
||||||
|
)
|
||||||
from ...utils.example_images_paths import is_valid_example_images_root
|
from ...utils.example_images_paths import is_valid_example_images_root
|
||||||
from ...utils.lora_metadata import extract_trained_words
|
from ...utils.lora_metadata import extract_trained_words
|
||||||
from ...utils.usage_stats import UsageStats
|
from ...utils.usage_stats import UsageStats
|
||||||
@@ -611,6 +617,104 @@ class ModelLibraryHandler:
|
|||||||
logger.error("Failed to get model versions status: %s", exc, exc_info=True)
|
logger.error("Failed to get model versions status: %s", exc, exc_info=True)
|
||||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
async def get_civitai_user_models(self, request: web.Request) -> web.Response:
|
||||||
|
try:
|
||||||
|
username = request.query.get("username")
|
||||||
|
if not username:
|
||||||
|
return web.json_response({"success": False, "error": "Missing required parameter: username"}, status=400)
|
||||||
|
|
||||||
|
metadata_provider = await self._metadata_provider_factory()
|
||||||
|
if not metadata_provider:
|
||||||
|
return web.json_response({"success": False, "error": "Metadata provider not available"}, status=503)
|
||||||
|
|
||||||
|
try:
|
||||||
|
models = await metadata_provider.get_user_models(username)
|
||||||
|
except NotImplementedError:
|
||||||
|
return web.json_response({"success": False, "error": "Metadata provider does not support user model queries"}, status=501)
|
||||||
|
|
||||||
|
if models is None:
|
||||||
|
return web.json_response({"success": False, "error": "Failed to fetch user models"}, status=502)
|
||||||
|
|
||||||
|
if not isinstance(models, list):
|
||||||
|
models = []
|
||||||
|
|
||||||
|
lora_scanner = await self._service_registry.get_lora_scanner()
|
||||||
|
checkpoint_scanner = await self._service_registry.get_checkpoint_scanner()
|
||||||
|
embedding_scanner = await self._service_registry.get_embedding_scanner()
|
||||||
|
|
||||||
|
normalized_allowed_types = {model_type.lower() for model_type in CIVITAI_USER_MODEL_TYPES}
|
||||||
|
lora_type_aliases = {model_type.lower() for model_type in VALID_LORA_TYPES}
|
||||||
|
|
||||||
|
type_scanner_map: Dict[str, object | None] = {
|
||||||
|
**{alias: lora_scanner for alias in lora_type_aliases},
|
||||||
|
"checkpoint": checkpoint_scanner,
|
||||||
|
"textualinversion": embedding_scanner,
|
||||||
|
}
|
||||||
|
|
||||||
|
versions: list[dict] = []
|
||||||
|
for model in models:
|
||||||
|
if not isinstance(model, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
model_type = str(model.get("type", "")).lower()
|
||||||
|
if model_type not in normalized_allowed_types:
|
||||||
|
continue
|
||||||
|
|
||||||
|
scanner = type_scanner_map.get(model_type)
|
||||||
|
if scanner is None:
|
||||||
|
return web.json_response({"success": False, "error": f'Scanner for type "{model_type}" is not available'}, status=503)
|
||||||
|
|
||||||
|
tags_value = model.get("tags")
|
||||||
|
tags = tags_value if isinstance(tags_value, list) else []
|
||||||
|
model_id = model.get("id")
|
||||||
|
try:
|
||||||
|
model_id_int = int(model_id)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
model_name = model.get("name", "")
|
||||||
|
|
||||||
|
versions_data = model.get("modelVersions")
|
||||||
|
if not isinstance(versions_data, list):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for version in versions_data:
|
||||||
|
if not isinstance(version, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
version_id = version.get("id")
|
||||||
|
try:
|
||||||
|
version_id_int = int(version_id)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if await scanner.check_model_version_exists(version_id_int):
|
||||||
|
continue
|
||||||
|
|
||||||
|
images = version.get("images") or []
|
||||||
|
thumbnail_url = None
|
||||||
|
if images and isinstance(images, list):
|
||||||
|
first_image = images[0]
|
||||||
|
if isinstance(first_image, dict):
|
||||||
|
thumbnail_url = first_image.get("url")
|
||||||
|
|
||||||
|
versions.append(
|
||||||
|
{
|
||||||
|
"modelId": model_id_int,
|
||||||
|
"versionId": version_id_int,
|
||||||
|
"modelName": model_name,
|
||||||
|
"versionName": version.get("name", ""),
|
||||||
|
"type": model.get("type"),
|
||||||
|
"tags": tags,
|
||||||
|
"baseModel": version.get("baseModel"),
|
||||||
|
"thumbnailUrl": thumbnail_url,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return web.json_response({"success": True, "username": username, "versions": versions})
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.error("Failed to get Civitai user models: %s", exc, exc_info=True)
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
|
||||||
class MetadataArchiveHandler:
|
class MetadataArchiveHandler:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -844,6 +948,7 @@ class MiscHandlerSet:
|
|||||||
"register_nodes": self.node_registry.register_nodes,
|
"register_nodes": self.node_registry.register_nodes,
|
||||||
"get_registry": self.node_registry.get_registry,
|
"get_registry": self.node_registry.get_registry,
|
||||||
"check_model_exists": self.model_library.check_model_exists,
|
"check_model_exists": self.model_library.check_model_exists,
|
||||||
|
"get_civitai_user_models": self.model_library.get_civitai_user_models,
|
||||||
"download_metadata_archive": self.metadata_archive.download_metadata_archive,
|
"download_metadata_archive": self.metadata_archive.download_metadata_archive,
|
||||||
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
|
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
|
||||||
"get_metadata_archive_status": self.metadata_archive.get_metadata_archive_status,
|
"get_metadata_archive_status": self.metadata_archive.get_metadata_archive_status,
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
|||||||
RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"),
|
RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"),
|
||||||
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
|
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
|
||||||
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
|
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
|
||||||
|
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),
|
||||||
RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"),
|
RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"),
|
||||||
RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"),
|
RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"),
|
||||||
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),
|
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),
|
||||||
|
|||||||
@@ -354,7 +354,7 @@ class CivitaiClient:
|
|||||||
|
|
||||||
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
||||||
"""Fetch image information from Civitai API
|
"""Fetch image information from Civitai API
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_id: The Civitai image ID
|
image_id: The Civitai image ID
|
||||||
|
|
||||||
@@ -385,3 +385,37 @@ class CivitaiClient:
|
|||||||
error_msg = f"Error fetching image info: {e}"
|
error_msg = f"Error fetching image info: {e}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||||
|
"""Fetch all models for a specific Civitai user."""
|
||||||
|
if not username:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
downloader = await get_downloader()
|
||||||
|
url = f"{self.base_url}/models?username={username}"
|
||||||
|
success, result = await downloader.make_request(
|
||||||
|
'GET',
|
||||||
|
url,
|
||||||
|
use_auth=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
logger.error("Failed to fetch models for %s: %s", username, result)
|
||||||
|
return None
|
||||||
|
|
||||||
|
items = result.get("items") if isinstance(result, dict) else None
|
||||||
|
if not isinstance(items, list):
|
||||||
|
return []
|
||||||
|
|
||||||
|
for model in items:
|
||||||
|
versions = model.get("modelVersions")
|
||||||
|
if not isinstance(versions, list):
|
||||||
|
continue
|
||||||
|
for version in versions:
|
||||||
|
self._remove_comfy_metadata(version)
|
||||||
|
|
||||||
|
return items
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.error("Error fetching models for %s: %s", username, exc)
|
||||||
|
return None
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Dict, Tuple, Any
|
from typing import Optional, Dict, Tuple, Any, List
|
||||||
from .downloader import get_downloader
|
from .downloader import get_downloader
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -61,6 +61,11 @@ class ModelMetadataProvider(ABC):
|
|||||||
"""Fetch model version metadata"""
|
"""Fetch model version metadata"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||||
|
"""Fetch models owned by the specified user"""
|
||||||
|
pass
|
||||||
|
|
||||||
class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||||
"""Provider that uses Civitai API for metadata"""
|
"""Provider that uses Civitai API for metadata"""
|
||||||
|
|
||||||
@@ -79,6 +84,9 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
|||||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
return await self.client.get_model_version_info(version_id)
|
return await self.client.get_model_version_info(version_id)
|
||||||
|
|
||||||
|
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||||
|
return await self.client.get_user_models(username)
|
||||||
|
|
||||||
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||||
"""Provider that uses CivArchive HTML page parsing for metadata"""
|
"""Provider that uses CivArchive HTML page parsing for metadata"""
|
||||||
|
|
||||||
@@ -197,6 +205,10 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
|||||||
"""Not supported by CivArchive provider - requires both model_id and version_id"""
|
"""Not supported by CivArchive provider - requires both model_id and version_id"""
|
||||||
return None, "CivArchive provider requires both model_id and version_id"
|
return None, "CivArchive provider requires both model_id and version_id"
|
||||||
|
|
||||||
|
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||||
|
"""Not supported by CivArchive provider"""
|
||||||
|
return None
|
||||||
|
|
||||||
class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||||
"""Provider that uses SQLite database for metadata"""
|
"""Provider that uses SQLite database for metadata"""
|
||||||
|
|
||||||
@@ -329,20 +341,24 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
|||||||
"""Fetch model version metadata from SQLite database"""
|
"""Fetch model version metadata from SQLite database"""
|
||||||
async with self._aiosqlite.connect(self.db_path) as db:
|
async with self._aiosqlite.connect(self.db_path) as db:
|
||||||
db.row_factory = self._aiosqlite.Row
|
db.row_factory = self._aiosqlite.Row
|
||||||
|
|
||||||
# Get version details
|
# Get version details
|
||||||
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
||||||
cursor = await db.execute(version_query, (version_id,))
|
cursor = await db.execute(version_query, (version_id,))
|
||||||
version_row = await cursor.fetchone()
|
version_row = await cursor.fetchone()
|
||||||
|
|
||||||
if not version_row:
|
if not version_row:
|
||||||
return None, "Model version not found"
|
return None, "Model version not found"
|
||||||
|
|
||||||
model_id = version_row['model_id']
|
model_id = version_row['model_id']
|
||||||
|
|
||||||
# Build complete version data with model info
|
# Build complete version data with model info
|
||||||
version_data = await self._get_version_with_model_data(db, model_id, version_id)
|
version_data = await self._get_version_with_model_data(db, model_id, version_id)
|
||||||
return version_data, None
|
return version_data, None
|
||||||
|
|
||||||
|
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||||
|
"""Listing models by username is not supported for archive database"""
|
||||||
|
return None
|
||||||
|
|
||||||
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
|
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
|
||||||
"""Helper to build version data with model information"""
|
"""Helper to build version data with model information"""
|
||||||
@@ -481,6 +497,17 @@ class FallbackMetadataProvider(ModelMetadataProvider):
|
|||||||
continue
|
continue
|
||||||
return None, "No provider could retrieve the data"
|
return None, "No provider could retrieve the data"
|
||||||
|
|
||||||
|
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||||
|
for provider in self.providers:
|
||||||
|
try:
|
||||||
|
result = await provider.get_user_models(username)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Provider failed for get_user_models: {e}")
|
||||||
|
continue
|
||||||
|
return None
|
||||||
|
|
||||||
class ModelMetadataProviderManager:
|
class ModelMetadataProviderManager:
|
||||||
"""Manager for selecting and using model metadata providers"""
|
"""Manager for selecting and using model metadata providers"""
|
||||||
|
|
||||||
@@ -522,6 +549,11 @@ class ModelMetadataProviderManager:
|
|||||||
"""Fetch model version info using specified or default provider"""
|
"""Fetch model version info using specified or default provider"""
|
||||||
provider = self._get_provider(provider_name)
|
provider = self._get_provider(provider_name)
|
||||||
return await provider.get_model_version_info(version_id)
|
return await provider.get_model_version_info(version_id)
|
||||||
|
|
||||||
|
async def get_user_models(self, username: str, provider_name: str = None) -> Optional[List[Dict]]:
|
||||||
|
"""Fetch models owned by the specified user"""
|
||||||
|
provider = self._get_provider(provider_name)
|
||||||
|
return await provider.get_user_models(username)
|
||||||
|
|
||||||
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
|
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
|
||||||
"""Get provider by name or default provider"""
|
"""Get provider by name or default provider"""
|
||||||
|
|||||||
@@ -48,6 +48,13 @@ SUPPORTED_MEDIA_EXTENSIONS = {
|
|||||||
# Valid Lora types
|
# Valid Lora types
|
||||||
VALID_LORA_TYPES = ['lora', 'locon', 'dora']
|
VALID_LORA_TYPES = ['lora', 'locon', 'dora']
|
||||||
|
|
||||||
|
# Supported Civitai model types for user model queries (case-insensitive)
|
||||||
|
CIVITAI_USER_MODEL_TYPES = [
|
||||||
|
*VALID_LORA_TYPES,
|
||||||
|
'textualinversion',
|
||||||
|
'checkpoint',
|
||||||
|
]
|
||||||
|
|
||||||
# Auto-organize settings
|
# Auto-organize settings
|
||||||
AUTO_ORGANIZE_BATCH_SIZE = 50 # Process models in batches to avoid overwhelming the system
|
AUTO_ORGANIZE_BATCH_SIZE = 50 # Process models in batches to avoid overwhelming the system
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from aiohttp import web
|
|||||||
|
|
||||||
from py.routes.handlers.misc_handlers import (
|
from py.routes.handlers.misc_handlers import (
|
||||||
LoraCodeHandler,
|
LoraCodeHandler,
|
||||||
|
ModelLibraryHandler,
|
||||||
NodeRegistry,
|
NodeRegistry,
|
||||||
NodeRegistryHandler,
|
NodeRegistryHandler,
|
||||||
ServiceRegistryAdapter,
|
ServiceRegistryAdapter,
|
||||||
@@ -266,10 +267,34 @@ async def fake_scanner_factory():
|
|||||||
return FakeScanner()
|
return FakeScanner()
|
||||||
|
|
||||||
|
|
||||||
|
class FakeExistenceScanner:
|
||||||
|
def __init__(self, existing=None):
|
||||||
|
self._existing = set(existing or [])
|
||||||
|
|
||||||
|
async def check_model_version_exists(self, version_id):
|
||||||
|
return version_id in self._existing
|
||||||
|
|
||||||
|
async def get_model_versions_by_id(self, _model_id):
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class FakeMetadataProvider:
|
class FakeMetadataProvider:
|
||||||
async def get_model_versions(self, _model_id):
|
async def get_model_versions(self, _model_id):
|
||||||
return {"modelVersions": [], "name": "", "type": "lora"}
|
return {"modelVersions": [], "name": "", "type": "lora"}
|
||||||
|
|
||||||
|
async def get_user_models(self, _username):
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class FakeUserModelsProvider(FakeMetadataProvider):
|
||||||
|
def __init__(self, models):
|
||||||
|
self.models = models
|
||||||
|
self.received_usernames: list[str] = []
|
||||||
|
|
||||||
|
async def get_user_models(self, username):
|
||||||
|
self.received_usernames.append(username)
|
||||||
|
return self.models
|
||||||
|
|
||||||
|
|
||||||
async def fake_metadata_provider_factory():
|
async def fake_metadata_provider_factory():
|
||||||
return FakeMetadataProvider()
|
return FakeMetadataProvider()
|
||||||
@@ -339,6 +364,167 @@ async def test_misc_routes_bind_produces_expected_handlers():
|
|||||||
assert set(mapping.keys()) == expected_names
|
assert set(mapping.keys()) == expected_names
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_civitai_user_models_filters_versions():
|
||||||
|
models = [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"name": "Model A",
|
||||||
|
"type": "LORA",
|
||||||
|
"tags": ["style"],
|
||||||
|
"modelVersions": [
|
||||||
|
{
|
||||||
|
"id": 100,
|
||||||
|
"name": "v1",
|
||||||
|
"baseModel": "Flux.1",
|
||||||
|
"images": [{"url": "http://example.com/a1.jpg"}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 101,
|
||||||
|
"name": "v2",
|
||||||
|
"baseModel": "Flux.1",
|
||||||
|
"images": [{"url": "http://example.com/a2.jpg"}],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"name": "Embedding",
|
||||||
|
"type": "TextualInversion",
|
||||||
|
"tags": ["embedding"],
|
||||||
|
"modelVersions": [
|
||||||
|
{
|
||||||
|
"id": 200,
|
||||||
|
"name": "v1",
|
||||||
|
"baseModel": None,
|
||||||
|
"images": [{"url": "http://example.com/e1.jpg"}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 202,
|
||||||
|
"name": "v2",
|
||||||
|
"baseModel": None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"name": "Checkpoint",
|
||||||
|
"type": "Checkpoint",
|
||||||
|
"tags": ["checkpoint"],
|
||||||
|
"modelVersions": [
|
||||||
|
{
|
||||||
|
"id": 300,
|
||||||
|
"name": "v1",
|
||||||
|
"baseModel": "SDXL",
|
||||||
|
"images": [],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4,
|
||||||
|
"name": "Unsupported",
|
||||||
|
"type": "Other",
|
||||||
|
"modelVersions": [
|
||||||
|
{
|
||||||
|
"id": 400,
|
||||||
|
"name": "v1",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
provider = FakeUserModelsProvider(models)
|
||||||
|
|
||||||
|
async def provider_factory():
|
||||||
|
return provider
|
||||||
|
|
||||||
|
lora_scanner = FakeExistenceScanner({101})
|
||||||
|
checkpoint_scanner = FakeExistenceScanner()
|
||||||
|
embedding_scanner = FakeExistenceScanner({202})
|
||||||
|
|
||||||
|
async def lora_factory():
|
||||||
|
return lora_scanner
|
||||||
|
|
||||||
|
async def checkpoint_factory():
|
||||||
|
return checkpoint_scanner
|
||||||
|
|
||||||
|
async def embedding_factory():
|
||||||
|
return embedding_scanner
|
||||||
|
|
||||||
|
handler = ModelLibraryHandler(
|
||||||
|
ServiceRegistryAdapter(
|
||||||
|
get_lora_scanner=lora_factory,
|
||||||
|
get_checkpoint_scanner=checkpoint_factory,
|
||||||
|
get_embedding_scanner=embedding_factory,
|
||||||
|
),
|
||||||
|
metadata_provider_factory=provider_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"}))
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert payload["success"] is True
|
||||||
|
assert payload["username"] == "pixel"
|
||||||
|
assert payload["versions"] == [
|
||||||
|
{
|
||||||
|
"modelId": 1,
|
||||||
|
"versionId": 100,
|
||||||
|
"modelName": "Model A",
|
||||||
|
"versionName": "v1",
|
||||||
|
"type": "LORA",
|
||||||
|
"tags": ["style"],
|
||||||
|
"baseModel": "Flux.1",
|
||||||
|
"thumbnailUrl": "http://example.com/a1.jpg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"modelId": 2,
|
||||||
|
"versionId": 200,
|
||||||
|
"modelName": "Embedding",
|
||||||
|
"versionName": "v1",
|
||||||
|
"type": "TextualInversion",
|
||||||
|
"tags": ["embedding"],
|
||||||
|
"baseModel": None,
|
||||||
|
"thumbnailUrl": "http://example.com/e1.jpg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"modelId": 3,
|
||||||
|
"versionId": 300,
|
||||||
|
"modelName": "Checkpoint",
|
||||||
|
"versionName": "v1",
|
||||||
|
"type": "Checkpoint",
|
||||||
|
"tags": ["checkpoint"],
|
||||||
|
"baseModel": "SDXL",
|
||||||
|
"thumbnailUrl": None,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
assert provider.received_usernames == ["pixel"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_civitai_user_models_requires_username():
|
||||||
|
provider = FakeUserModelsProvider([])
|
||||||
|
|
||||||
|
async def provider_factory():
|
||||||
|
return provider
|
||||||
|
|
||||||
|
handler = ModelLibraryHandler(
|
||||||
|
ServiceRegistryAdapter(
|
||||||
|
get_lora_scanner=fake_scanner_factory,
|
||||||
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
),
|
||||||
|
metadata_provider_factory=provider_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handler.get_civitai_user_models(FakeRequest())
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert response.status == 400
|
||||||
|
assert payload["success"] is False
|
||||||
|
assert "username" in payload["error"].lower()
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_handler_mapping_caches_result():
|
def test_ensure_handler_mapping_caches_result():
|
||||||
call_records = []
|
call_records = []
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user