mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-29 08:58:53 -03:00
feat: Dynamic base model fetching from Civitai API (#854)
Implement automatic fetching of base models from Civitai API to keep data up-to-date without manual updates. Backend: - Add CivitaiBaseModelService with 7-day TTL caching - Add /api/lm/base-models endpoints for fetching and refreshing - Merge hardcoded and remote models for backward compatibility - Smart abbreviation generation for unknown models Frontend: - Add civitaiBaseModelApi client for API communication - Dynamic base model loading on app initialization - Update SettingsManager to use merged model lists - Add support for 8 new models: Anima, CogVideoX, LTXV 2.3, Mochi, Pony V7, Wan Video 2.5 T2V/I2V API Endpoints: - GET /api/lm/base-models - Get merged models - POST /api/lm/base-models/refresh - Force refresh - GET /api/lm/base-models/categories - Get categories - GET /api/lm/base-models/cache-status - Check cache status Closes #854
This commit is contained in:
141
py/routes/handlers/base_model_handlers.py
Normal file
141
py/routes/handlers/base_model_handlers.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""Handlers for base model related endpoints."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Awaitable, Callable, Dict
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from ...services.civitai_base_model_service import get_civitai_base_model_service
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModelHandlerSet:
|
||||||
|
"""Collection of handlers for base model operations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_model_service_factory: Callable[[], Any] = get_civitai_base_model_service,
|
||||||
|
) -> None:
|
||||||
|
self._base_model_service_factory = base_model_service_factory
|
||||||
|
|
||||||
|
def to_route_mapping(
|
||||||
|
self,
|
||||||
|
) -> Dict[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
|
||||||
|
"""Return mapping of route names to handler methods."""
|
||||||
|
return {
|
||||||
|
"get_base_models": self.get_base_models,
|
||||||
|
"refresh_base_models": self.refresh_base_models,
|
||||||
|
"get_base_model_categories": self.get_base_model_categories,
|
||||||
|
"get_base_model_cache_status": self.get_base_model_cache_status,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get merged base models (hardcoded + remote from Civitai).
|
||||||
|
|
||||||
|
Query Parameters:
|
||||||
|
refresh: If 'true', force refresh from API
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON response with:
|
||||||
|
- models: List of base model names
|
||||||
|
- source: 'cache', 'api', or 'fallback'
|
||||||
|
- last_updated: ISO timestamp
|
||||||
|
- counts: hardcoded_count, remote_count, merged_count
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
service = await self._base_model_service_factory()
|
||||||
|
|
||||||
|
# Check for refresh parameter
|
||||||
|
force_refresh = request.query.get("refresh", "").lower() == "true"
|
||||||
|
|
||||||
|
result = await service.get_base_models(force_refresh=force_refresh)
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"data": result,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in get_base_models: {e}")
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": str(e)},
|
||||||
|
status=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def refresh_base_models(self, request: web.Request) -> web.Response:
|
||||||
|
"""Force refresh base models from Civitai API.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON response with refreshed data
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
service = await self._base_model_service_factory()
|
||||||
|
result = await service.refresh_cache()
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"data": result,
|
||||||
|
"message": "Base models cache refreshed successfully",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in refresh_base_models: {e}")
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": str(e)},
|
||||||
|
status=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_base_model_categories(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get categorized base models.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON response with categorized models
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
service = await self._base_model_service_factory()
|
||||||
|
categories = service.get_model_categories()
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"data": categories,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in get_base_model_categories: {e}")
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": str(e)},
|
||||||
|
status=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_base_model_cache_status(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get cache status for base models.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON response with cache status
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
service = await self._base_model_service_factory()
|
||||||
|
status = service.get_cache_status()
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"data": status,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in get_base_model_cache_status: {e}")
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": str(e)},
|
||||||
|
status=500,
|
||||||
|
)
|
||||||
@@ -40,6 +40,7 @@ from ...utils.civitai_utils import rewrite_preview_url
|
|||||||
from ...utils.example_images_paths import is_valid_example_images_root
|
from ...utils.example_images_paths import is_valid_example_images_root
|
||||||
from ...utils.lora_metadata import extract_trained_words
|
from ...utils.lora_metadata import extract_trained_words
|
||||||
from ...utils.usage_stats import UsageStats
|
from ...utils.usage_stats import UsageStats
|
||||||
|
from .base_model_handlers import BaseModelHandlerSet
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -1618,6 +1619,7 @@ class MiscHandlerSet:
|
|||||||
custom_words: CustomWordsHandler,
|
custom_words: CustomWordsHandler,
|
||||||
supporters: SupportersHandler,
|
supporters: SupportersHandler,
|
||||||
example_workflows: ExampleWorkflowsHandler,
|
example_workflows: ExampleWorkflowsHandler,
|
||||||
|
base_model: BaseModelHandlerSet,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.health = health
|
self.health = health
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
@@ -1632,6 +1634,7 @@ class MiscHandlerSet:
|
|||||||
self.custom_words = custom_words
|
self.custom_words = custom_words
|
||||||
self.supporters = supporters
|
self.supporters = supporters
|
||||||
self.example_workflows = example_workflows
|
self.example_workflows = example_workflows
|
||||||
|
self.base_model = base_model
|
||||||
|
|
||||||
def to_route_mapping(
|
def to_route_mapping(
|
||||||
self,
|
self,
|
||||||
@@ -1663,6 +1666,11 @@ class MiscHandlerSet:
|
|||||||
"get_supporters": self.supporters.get_supporters,
|
"get_supporters": self.supporters.get_supporters,
|
||||||
"get_example_workflows": self.example_workflows.get_example_workflows,
|
"get_example_workflows": self.example_workflows.get_example_workflows,
|
||||||
"get_example_workflow": self.example_workflows.get_example_workflow,
|
"get_example_workflow": self.example_workflows.get_example_workflow,
|
||||||
|
# Base model handlers
|
||||||
|
"get_base_models": self.base_model.get_base_models,
|
||||||
|
"refresh_base_models": self.base_model.refresh_base_models,
|
||||||
|
"get_base_model_categories": self.base_model.get_base_model_categories,
|
||||||
|
"get_base_model_cache_status": self.base_model.get_base_model_cache_status,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -56,6 +56,15 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
|||||||
RouteDefinition(
|
RouteDefinition(
|
||||||
"GET", "/api/lm/example-workflows/{filename}", "get_example_workflow"
|
"GET", "/api/lm/example-workflows/{filename}", "get_example_workflow"
|
||||||
),
|
),
|
||||||
|
# Base model management routes
|
||||||
|
RouteDefinition("GET", "/api/lm/base-models", "get_base_models"),
|
||||||
|
RouteDefinition("POST", "/api/lm/base-models/refresh", "refresh_base_models"),
|
||||||
|
RouteDefinition(
|
||||||
|
"GET", "/api/lm/base-models/categories", "get_base_model_categories"
|
||||||
|
),
|
||||||
|
RouteDefinition(
|
||||||
|
"GET", "/api/lm/base-models/cache-status", "get_base_model_cache_status"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from .handlers.misc_handlers import (
|
|||||||
UsageStatsHandler,
|
UsageStatsHandler,
|
||||||
build_service_registry_adapter,
|
build_service_registry_adapter,
|
||||||
)
|
)
|
||||||
|
from .handlers.base_model_handlers import BaseModelHandlerSet
|
||||||
from .misc_route_registrar import MiscRouteRegistrar
|
from .misc_route_registrar import MiscRouteRegistrar
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -128,6 +129,7 @@ class MiscRoutes:
|
|||||||
custom_words = CustomWordsHandler()
|
custom_words = CustomWordsHandler()
|
||||||
supporters = SupportersHandler()
|
supporters = SupportersHandler()
|
||||||
example_workflows = ExampleWorkflowsHandler()
|
example_workflows = ExampleWorkflowsHandler()
|
||||||
|
base_model = BaseModelHandlerSet()
|
||||||
|
|
||||||
return self._handler_set_factory(
|
return self._handler_set_factory(
|
||||||
health=health,
|
health=health,
|
||||||
@@ -143,6 +145,7 @@ class MiscRoutes:
|
|||||||
custom_words=custom_words,
|
custom_words=custom_words,
|
||||||
supporters=supporters,
|
supporters=supporters,
|
||||||
example_workflows=example_workflows,
|
example_workflows=example_workflows,
|
||||||
|
base_model=base_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
430
py/services/civitai_base_model_service.py
Normal file
430
py/services/civitai_base_model_service.py
Normal file
@@ -0,0 +1,430 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
from ..utils.constants import SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS
|
||||||
|
from .downloader import get_downloader
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CivitaiBaseModelService:
|
||||||
|
"""Service for fetching and managing Civitai base models.
|
||||||
|
|
||||||
|
This service provides:
|
||||||
|
- Fetching base models from Civitai API
|
||||||
|
- Caching with TTL (7 days default)
|
||||||
|
- Merging hardcoded and remote base models
|
||||||
|
- Generating abbreviations for new/unknown models
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance: Optional[CivitaiBaseModelService] = None
|
||||||
|
_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
# Default TTL for cache in seconds (7 days)
|
||||||
|
DEFAULT_CACHE_TTL = 7 * 24 * 60 * 60
|
||||||
|
|
||||||
|
# Civitai API endpoint for enums
|
||||||
|
CIVITAI_ENUMS_URL = "https://civitai.com/api/v1/enums"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_instance(cls) -> CivitaiBaseModelService:
|
||||||
|
"""Get singleton instance of the service."""
|
||||||
|
async with cls._lock:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = cls()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the service."""
|
||||||
|
if hasattr(self, "_initialized"):
|
||||||
|
return
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
# Cache storage
|
||||||
|
self._cache: Optional[Dict[str, Any]] = None
|
||||||
|
self._cache_timestamp: Optional[datetime] = None
|
||||||
|
self._cache_ttl = self.DEFAULT_CACHE_TTL
|
||||||
|
|
||||||
|
# Hardcoded models for fallback
|
||||||
|
self._hardcoded_models = set(SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS)
|
||||||
|
|
||||||
|
logger.info("CivitaiBaseModelService initialized")
|
||||||
|
|
||||||
|
async def get_base_models(self, force_refresh: bool = False) -> Dict[str, Any]:
|
||||||
|
"""Get merged base models (hardcoded + remote).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force_refresh: If True, fetch from API regardless of cache state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing:
|
||||||
|
- models: List of merged base model names
|
||||||
|
- source: 'cache', 'api', or 'fallback'
|
||||||
|
- last_updated: ISO timestamp of last successful API fetch
|
||||||
|
- hardcoded_count: Number of hardcoded models
|
||||||
|
- remote_count: Number of remote models
|
||||||
|
- merged_count: Total unique models
|
||||||
|
"""
|
||||||
|
# Check if cache is valid
|
||||||
|
if not force_refresh and self._is_cache_valid():
|
||||||
|
logger.debug("Returning cached base models")
|
||||||
|
return self._build_response("cache")
|
||||||
|
|
||||||
|
# Try to fetch from API
|
||||||
|
try:
|
||||||
|
remote_models = await self._fetch_from_civitai()
|
||||||
|
if remote_models:
|
||||||
|
self._update_cache(remote_models)
|
||||||
|
return self._build_response("api")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to fetch base models from Civitai: {e}")
|
||||||
|
|
||||||
|
# Fallback to hardcoded models
|
||||||
|
return self._build_response("fallback")
|
||||||
|
|
||||||
|
async def refresh_cache(self) -> Dict[str, Any]:
|
||||||
|
"""Force refresh the cache from Civitai API.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response dict same as get_base_models()
|
||||||
|
"""
|
||||||
|
return await self.get_base_models(force_refresh=True)
|
||||||
|
|
||||||
|
def get_cache_status(self) -> Dict[str, Any]:
|
||||||
|
"""Get current cache status.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing:
|
||||||
|
- has_cache: Whether cache exists
|
||||||
|
- last_updated: ISO timestamp or None
|
||||||
|
- is_expired: Whether cache is expired
|
||||||
|
- ttl_seconds: TTL in seconds
|
||||||
|
- age_seconds: Age of cache in seconds (if exists)
|
||||||
|
"""
|
||||||
|
if self._cache is None or self._cache_timestamp is None:
|
||||||
|
return {
|
||||||
|
"has_cache": False,
|
||||||
|
"last_updated": None,
|
||||||
|
"is_expired": True,
|
||||||
|
"ttl_seconds": self._cache_ttl,
|
||||||
|
"age_seconds": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
age = (datetime.now(timezone.utc) - self._cache_timestamp).total_seconds()
|
||||||
|
return {
|
||||||
|
"has_cache": True,
|
||||||
|
"last_updated": self._cache_timestamp.isoformat(),
|
||||||
|
"is_expired": age > self._cache_ttl,
|
||||||
|
"ttl_seconds": self._cache_ttl,
|
||||||
|
"age_seconds": int(age),
|
||||||
|
}
|
||||||
|
|
||||||
|
def generate_abbreviation(self, model_name: str) -> str:
|
||||||
|
"""Generate abbreviation for a base model name.
|
||||||
|
|
||||||
|
Algorithm:
|
||||||
|
1. Extract version patterns (e.g., "2.5" from "Wan Video 2.5")
|
||||||
|
2. Extract main acronym (e.g., "SD" from "SD 1.5")
|
||||||
|
3. Handle special cases (Flux, Wan, etc.)
|
||||||
|
4. Fallback to first letters of words (max 4 chars)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Full base model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Generated abbreviation (max 4 characters)
|
||||||
|
"""
|
||||||
|
if not model_name or not isinstance(model_name, str):
|
||||||
|
return "OTH"
|
||||||
|
|
||||||
|
name = model_name.strip()
|
||||||
|
if not name:
|
||||||
|
return "OTH"
|
||||||
|
|
||||||
|
# Check if it's already in hardcoded abbreviations
|
||||||
|
# This is a simplified check - in practice you'd have a mapping
|
||||||
|
lower_name = name.lower()
|
||||||
|
|
||||||
|
# Special cases
|
||||||
|
special_cases = {
|
||||||
|
"sd 1.4": "SD1",
|
||||||
|
"sd 1.5": "SD1",
|
||||||
|
"sd 1.5 lcm": "SD1",
|
||||||
|
"sd 1.5 hyper": "SD1",
|
||||||
|
"sd 2.0": "SD2",
|
||||||
|
"sd 2.1": "SD2",
|
||||||
|
"sd 3": "SD3",
|
||||||
|
"sd 3.5": "SD3",
|
||||||
|
"sd 3.5 medium": "SD3",
|
||||||
|
"sd 3.5 large": "SD3",
|
||||||
|
"sd 3.5 large turbo": "SD3",
|
||||||
|
"sdxl 1.0": "XL",
|
||||||
|
"sdxl lightning": "XL",
|
||||||
|
"sdxl hyper": "XL",
|
||||||
|
"flux.1 d": "F1D",
|
||||||
|
"flux.1 s": "F1S",
|
||||||
|
"flux.1 krea": "F1KR",
|
||||||
|
"flux.1 kontext": "F1KX",
|
||||||
|
"flux.2 d": "F2D",
|
||||||
|
"flux.2 klein 9b": "FK9",
|
||||||
|
"flux.2 klein 9b-base": "FK9B",
|
||||||
|
"flux.2 klein 4b": "FK4",
|
||||||
|
"flux.2 klein 4b-base": "FK4B",
|
||||||
|
"auraflow": "AF",
|
||||||
|
"chroma": "CHR",
|
||||||
|
"pixart a": "PXA",
|
||||||
|
"pixart e": "PXE",
|
||||||
|
"hunyuan 1": "HY",
|
||||||
|
"hunyuan video": "HYV",
|
||||||
|
"lumina": "L",
|
||||||
|
"kolors": "KLR",
|
||||||
|
"noobai": "NAI",
|
||||||
|
"illustrious": "IL",
|
||||||
|
"pony": "PONY",
|
||||||
|
"pony v7": "PNY7",
|
||||||
|
"hidream": "HID",
|
||||||
|
"qwen": "QWEN",
|
||||||
|
"zimageturbo": "ZIT",
|
||||||
|
"zimagebase": "ZIB",
|
||||||
|
"anima": "ANI",
|
||||||
|
"svd": "SVD",
|
||||||
|
"ltxv": "LTXV",
|
||||||
|
"ltxv2": "LTV2",
|
||||||
|
"ltxv 2.3": "LTX",
|
||||||
|
"cogvideox": "CVX",
|
||||||
|
"mochi": "MCHI",
|
||||||
|
"wan video": "WAN",
|
||||||
|
"wan video 1.3b t2v": "WAN",
|
||||||
|
"wan video 14b t2v": "WAN",
|
||||||
|
"wan video 14b i2v 480p": "WAN",
|
||||||
|
"wan video 14b i2v 720p": "WAN",
|
||||||
|
"wan video 2.2 ti2v-5b": "WAN",
|
||||||
|
"wan video 2.2 t2v-a14b": "WAN",
|
||||||
|
"wan video 2.2 i2v-a14b": "WAN",
|
||||||
|
"wan video 2.5 t2v": "WAN",
|
||||||
|
"wan video 2.5 i2v": "WAN",
|
||||||
|
}
|
||||||
|
|
||||||
|
if lower_name in special_cases:
|
||||||
|
return special_cases[lower_name]
|
||||||
|
|
||||||
|
# Try to extract acronym from version pattern
|
||||||
|
# e.g., "Model Name 2.5" -> "MN25"
|
||||||
|
version_match = re.search(r"(\d+(?:\.\d+)?)", name)
|
||||||
|
version = version_match.group(1) if version_match else ""
|
||||||
|
|
||||||
|
# Remove version and common words
|
||||||
|
words = re.sub(r"\d+(?:\.\d+)?", "", name)
|
||||||
|
words = re.sub(
|
||||||
|
r"\b(model|video|diffusion|checkpoint|textualinversion)\b",
|
||||||
|
"",
|
||||||
|
words,
|
||||||
|
flags=re.I,
|
||||||
|
)
|
||||||
|
words = words.strip()
|
||||||
|
|
||||||
|
# Get first letters of remaining words
|
||||||
|
tokens = re.findall(r"[A-Za-z]+", words)
|
||||||
|
if tokens:
|
||||||
|
# Build abbreviation from first letters
|
||||||
|
abbrev = "".join(token[0].upper() for token in tokens)
|
||||||
|
# Add version if present
|
||||||
|
if version:
|
||||||
|
# Clean version (remove dots for abbreviation)
|
||||||
|
version_clean = version.replace(".", "")
|
||||||
|
abbrev = abbrev[: 4 - len(version_clean)] + version_clean
|
||||||
|
return abbrev[:4]
|
||||||
|
|
||||||
|
# Final fallback: just take first 4 alphanumeric chars
|
||||||
|
alphanumeric = re.sub(r"[^A-Za-z0-9]", "", name)
|
||||||
|
if alphanumeric:
|
||||||
|
return alphanumeric[:4].upper()
|
||||||
|
|
||||||
|
return "OTH"
|
||||||
|
|
||||||
|
async def _fetch_from_civitai(self) -> Optional[Set[str]]:
|
||||||
|
"""Fetch base models from Civitai API.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set of base model names, or None if failed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
downloader = await get_downloader()
|
||||||
|
success, result = await downloader.make_request(
|
||||||
|
"GET",
|
||||||
|
self.CIVITAI_ENUMS_URL,
|
||||||
|
use_auth=False, # enums endpoint doesn't require auth
|
||||||
|
)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
logger.warning(f"Failed to fetch enums from Civitai: {result}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(result, str):
|
||||||
|
data = json.loads(result)
|
||||||
|
else:
|
||||||
|
data = result
|
||||||
|
|
||||||
|
# Extract base models from response
|
||||||
|
base_models = set()
|
||||||
|
|
||||||
|
# Use ActiveBaseModel if available (recommended active models)
|
||||||
|
if "ActiveBaseModel" in data:
|
||||||
|
base_models.update(data["ActiveBaseModel"])
|
||||||
|
logger.info(f"Fetched {len(base_models)} models from ActiveBaseModel")
|
||||||
|
# Fallback to full BaseModel list
|
||||||
|
elif "BaseModel" in data:
|
||||||
|
base_models.update(data["BaseModel"])
|
||||||
|
logger.info(f"Fetched {len(base_models)} models from BaseModel")
|
||||||
|
else:
|
||||||
|
logger.warning("No base model data found in Civitai response")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return base_models
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching from Civitai: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _update_cache(self, remote_models: Set[str]) -> None:
|
||||||
|
"""Update internal cache with fetched models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote_models: Set of base model names from API
|
||||||
|
"""
|
||||||
|
self._cache = {
|
||||||
|
"remote_models": sorted(remote_models),
|
||||||
|
"hardcoded_models": sorted(self._hardcoded_models),
|
||||||
|
}
|
||||||
|
self._cache_timestamp = datetime.now(timezone.utc)
|
||||||
|
logger.info(f"Cache updated with {len(remote_models)} remote models")
|
||||||
|
|
||||||
|
def _is_cache_valid(self) -> bool:
|
||||||
|
"""Check if current cache is valid (not expired).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if cache exists and is not expired
|
||||||
|
"""
|
||||||
|
if self._cache is None or self._cache_timestamp is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
age = (datetime.now(timezone.utc) - self._cache_timestamp).total_seconds()
|
||||||
|
return age <= self._cache_ttl
|
||||||
|
|
||||||
|
def _build_response(self, source: str) -> Dict[str, Any]:
|
||||||
|
"""Build response dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: 'cache', 'api', or 'fallback'
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response dictionary
|
||||||
|
"""
|
||||||
|
if source == "fallback" or self._cache is None:
|
||||||
|
# Use only hardcoded models
|
||||||
|
merged = sorted(self._hardcoded_models)
|
||||||
|
return {
|
||||||
|
"models": merged,
|
||||||
|
"source": source,
|
||||||
|
"last_updated": None,
|
||||||
|
"hardcoded_count": len(self._hardcoded_models),
|
||||||
|
"remote_count": 0,
|
||||||
|
"merged_count": len(merged),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Merge hardcoded and remote models
|
||||||
|
remote_set = set(self._cache.get("remote_models", []))
|
||||||
|
merged = sorted(self._hardcoded_models | remote_set)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"models": merged,
|
||||||
|
"source": source,
|
||||||
|
"last_updated": self._cache_timestamp.isoformat()
|
||||||
|
if self._cache_timestamp
|
||||||
|
else None,
|
||||||
|
"hardcoded_count": len(self._hardcoded_models),
|
||||||
|
"remote_count": len(remote_set),
|
||||||
|
"merged_count": len(merged),
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_model_categories(self) -> Dict[str, List[str]]:
|
||||||
|
"""Get categorized base models.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping category names to lists of model names
|
||||||
|
"""
|
||||||
|
# Define category patterns
|
||||||
|
categories = {
|
||||||
|
"Stable Diffusion 1.x": ["SD 1.4", "SD 1.5", "SD 1.5 LCM", "SD 1.5 Hyper"],
|
||||||
|
"Stable Diffusion 2.x": ["SD 2.0", "SD 2.1"],
|
||||||
|
"Stable Diffusion 3.x": [
|
||||||
|
"SD 3",
|
||||||
|
"SD 3.5",
|
||||||
|
"SD 3.5 Medium",
|
||||||
|
"SD 3.5 Large",
|
||||||
|
"SD 3.5 Large Turbo",
|
||||||
|
],
|
||||||
|
"SDXL": ["SDXL 1.0", "SDXL Lightning", "SDXL Hyper"],
|
||||||
|
"Flux Models": [
|
||||||
|
"Flux.1 D",
|
||||||
|
"Flux.1 S",
|
||||||
|
"Flux.1 Krea",
|
||||||
|
"Flux.1 Kontext",
|
||||||
|
"Flux.2 D",
|
||||||
|
"Flux.2 Klein 9B",
|
||||||
|
"Flux.2 Klein 9B-base",
|
||||||
|
"Flux.2 Klein 4B",
|
||||||
|
"Flux.2 Klein 4B-base",
|
||||||
|
],
|
||||||
|
"Video Models": [
|
||||||
|
"SVD",
|
||||||
|
"LTXV",
|
||||||
|
"LTXV2",
|
||||||
|
"LTXV 2.3",
|
||||||
|
"CogVideoX",
|
||||||
|
"Mochi",
|
||||||
|
"Hunyuan Video",
|
||||||
|
"Wan Video",
|
||||||
|
"Wan Video 1.3B t2v",
|
||||||
|
"Wan Video 14B t2v",
|
||||||
|
"Wan Video 14B i2v 480p",
|
||||||
|
"Wan Video 14B i2v 720p",
|
||||||
|
"Wan Video 2.2 TI2V-5B",
|
||||||
|
"Wan Video 2.2 T2V-A14B",
|
||||||
|
"Wan Video 2.2 I2V-A14B",
|
||||||
|
"Wan Video 2.5 T2V",
|
||||||
|
"Wan Video 2.5 I2V",
|
||||||
|
],
|
||||||
|
"Other Models": [
|
||||||
|
"Illustrious",
|
||||||
|
"Pony",
|
||||||
|
"Pony V7",
|
||||||
|
"HiDream",
|
||||||
|
"Qwen",
|
||||||
|
"AuraFlow",
|
||||||
|
"Chroma",
|
||||||
|
"ZImageTurbo",
|
||||||
|
"ZImageBase",
|
||||||
|
"PixArt a",
|
||||||
|
"PixArt E",
|
||||||
|
"Hunyuan 1",
|
||||||
|
"Lumina",
|
||||||
|
"Kolors",
|
||||||
|
"NoobAI",
|
||||||
|
"Anima",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
return categories
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience function for getting the singleton instance
|
||||||
|
async def get_civitai_base_model_service() -> CivitaiBaseModelService:
|
||||||
|
"""Get the singleton instance of CivitaiBaseModelService."""
|
||||||
|
return await CivitaiBaseModelService.get_instance()
|
||||||
@@ -110,6 +110,8 @@ DIFFUSION_MODEL_BASE_MODELS = frozenset(
|
|||||||
"Wan Video 2.2 T2V-A14B",
|
"Wan Video 2.2 T2V-A14B",
|
||||||
"Wan Video 2.5 T2V",
|
"Wan Video 2.5 T2V",
|
||||||
"Wan Video 2.5 I2V",
|
"Wan Video 2.5 I2V",
|
||||||
|
"CogVideoX",
|
||||||
|
"Mochi",
|
||||||
"Qwen",
|
"Qwen",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -151,6 +153,7 @@ SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS = frozenset(
|
|||||||
"NoobAI",
|
"NoobAI",
|
||||||
"Illustrious",
|
"Illustrious",
|
||||||
"Pony",
|
"Pony",
|
||||||
|
"Pony V7",
|
||||||
"HiDream",
|
"HiDream",
|
||||||
"Qwen",
|
"Qwen",
|
||||||
"ZImageTurbo",
|
"ZImageTurbo",
|
||||||
@@ -158,6 +161,9 @@ SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS = frozenset(
|
|||||||
"SVD",
|
"SVD",
|
||||||
"LTXV",
|
"LTXV",
|
||||||
"LTXV2",
|
"LTXV2",
|
||||||
|
"LTXV 2.3",
|
||||||
|
"CogVideoX",
|
||||||
|
"Mochi",
|
||||||
"Wan Video",
|
"Wan Video",
|
||||||
"Wan Video 1.3B t2v",
|
"Wan Video 1.3B t2v",
|
||||||
"Wan Video 14B t2v",
|
"Wan Video 14B t2v",
|
||||||
@@ -166,6 +172,9 @@ SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS = frozenset(
|
|||||||
"Wan Video 2.2 TI2V-5B",
|
"Wan Video 2.2 TI2V-5B",
|
||||||
"Wan Video 2.2 T2V-A14B",
|
"Wan Video 2.2 T2V-A14B",
|
||||||
"Wan Video 2.2 I2V-A14B",
|
"Wan Video 2.2 I2V-A14B",
|
||||||
|
"Wan Video 2.5 T2V",
|
||||||
|
"Wan Video 2.5 I2V",
|
||||||
"Hunyuan Video",
|
"Hunyuan Video",
|
||||||
|
"Anima",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
167
refs/enums.json
Normal file
167
refs/enums.json
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
{
|
||||||
|
"ModelType": [
|
||||||
|
"Checkpoint",
|
||||||
|
"TextualInversion",
|
||||||
|
"Hypernetwork",
|
||||||
|
"AestheticGradient",
|
||||||
|
"LORA",
|
||||||
|
"LoCon",
|
||||||
|
"DoRA",
|
||||||
|
"Controlnet",
|
||||||
|
"Upscaler",
|
||||||
|
"MotionModule",
|
||||||
|
"VAE",
|
||||||
|
"Poses",
|
||||||
|
"Wildcards",
|
||||||
|
"Workflows",
|
||||||
|
"Detection",
|
||||||
|
"Other"
|
||||||
|
],
|
||||||
|
"ModelFileType": [
|
||||||
|
"Model",
|
||||||
|
"Text Encoder",
|
||||||
|
"Pruned Model",
|
||||||
|
"Negative",
|
||||||
|
"Training Data",
|
||||||
|
"VAE",
|
||||||
|
"Config",
|
||||||
|
"Archive"
|
||||||
|
],
|
||||||
|
"ActiveBaseModel": [
|
||||||
|
"Anima",
|
||||||
|
"AuraFlow",
|
||||||
|
"Chroma",
|
||||||
|
"CogVideoX",
|
||||||
|
"Flux.1 S",
|
||||||
|
"Flux.1 D",
|
||||||
|
"Flux.1 Krea",
|
||||||
|
"Flux.1 Kontext",
|
||||||
|
"Flux.2 D",
|
||||||
|
"Flux.2 Klein 9B",
|
||||||
|
"Flux.2 Klein 9B-base",
|
||||||
|
"Flux.2 Klein 4B",
|
||||||
|
"Flux.2 Klein 4B-base",
|
||||||
|
"HiDream",
|
||||||
|
"Hunyuan 1",
|
||||||
|
"Hunyuan Video",
|
||||||
|
"Illustrious",
|
||||||
|
"Kolors",
|
||||||
|
"LTXV",
|
||||||
|
"LTXV2",
|
||||||
|
"LTXV 2.3",
|
||||||
|
"Lumina",
|
||||||
|
"Mochi",
|
||||||
|
"NoobAI",
|
||||||
|
"Other",
|
||||||
|
"PixArt a",
|
||||||
|
"PixArt E",
|
||||||
|
"Pony",
|
||||||
|
"Pony V7",
|
||||||
|
"Qwen",
|
||||||
|
"SD 1.4",
|
||||||
|
"SD 1.5",
|
||||||
|
"SD 1.5 LCM",
|
||||||
|
"SD 1.5 Hyper",
|
||||||
|
"SD 2.0",
|
||||||
|
"SD 2.1",
|
||||||
|
"SDXL 1.0",
|
||||||
|
"SDXL Lightning",
|
||||||
|
"SDXL Hyper",
|
||||||
|
"Wan Video 1.3B t2v",
|
||||||
|
"Wan Video 14B t2v",
|
||||||
|
"Wan Video 14B i2v 480p",
|
||||||
|
"Wan Video 14B i2v 720p",
|
||||||
|
"Wan Video 2.2 TI2V-5B",
|
||||||
|
"Wan Video 2.2 I2V-A14B",
|
||||||
|
"Wan Video 2.2 T2V-A14B",
|
||||||
|
"Wan Video 2.5 T2V",
|
||||||
|
"Wan Video 2.5 I2V",
|
||||||
|
"ZImageTurbo",
|
||||||
|
"ZImageBase"
|
||||||
|
],
|
||||||
|
"BaseModel": [
|
||||||
|
"Anima",
|
||||||
|
"AuraFlow",
|
||||||
|
"Chroma",
|
||||||
|
"CogVideoX",
|
||||||
|
"Flux.1 S",
|
||||||
|
"Flux.1 D",
|
||||||
|
"Flux.1 Krea",
|
||||||
|
"Flux.1 Kontext",
|
||||||
|
"Flux.2 D",
|
||||||
|
"Flux.2 Klein 9B",
|
||||||
|
"Flux.2 Klein 9B-base",
|
||||||
|
"Flux.2 Klein 4B",
|
||||||
|
"Flux.2 Klein 4B-base",
|
||||||
|
"HiDream",
|
||||||
|
"Hunyuan 1",
|
||||||
|
"Hunyuan Video",
|
||||||
|
"Illustrious",
|
||||||
|
"Imagen4",
|
||||||
|
"Kling",
|
||||||
|
"Kolors",
|
||||||
|
"LTXV",
|
||||||
|
"LTXV2",
|
||||||
|
"LTXV 2.3",
|
||||||
|
"Lumina",
|
||||||
|
"Mochi",
|
||||||
|
"Nano Banana",
|
||||||
|
"NoobAI",
|
||||||
|
"ODOR",
|
||||||
|
"OpenAI",
|
||||||
|
"Other",
|
||||||
|
"PixArt a",
|
||||||
|
"PixArt E",
|
||||||
|
"Playground v2",
|
||||||
|
"Pony",
|
||||||
|
"Pony V7",
|
||||||
|
"Qwen",
|
||||||
|
"Stable Cascade",
|
||||||
|
"SD 1.4",
|
||||||
|
"SD 1.5",
|
||||||
|
"SD 1.5 LCM",
|
||||||
|
"SD 1.5 Hyper",
|
||||||
|
"SD 2.0",
|
||||||
|
"SD 2.0 768",
|
||||||
|
"SD 2.1",
|
||||||
|
"SD 2.1 768",
|
||||||
|
"SD 2.1 Unclip",
|
||||||
|
"SD 3",
|
||||||
|
"SD 3.5",
|
||||||
|
"SD 3.5 Large",
|
||||||
|
"SD 3.5 Large Turbo",
|
||||||
|
"SD 3.5 Medium",
|
||||||
|
"Sora 2",
|
||||||
|
"SDXL 0.9",
|
||||||
|
"SDXL 1.0",
|
||||||
|
"SDXL 1.0 LCM",
|
||||||
|
"SDXL Lightning",
|
||||||
|
"SDXL Hyper",
|
||||||
|
"SDXL Turbo",
|
||||||
|
"SDXL Distilled",
|
||||||
|
"Seedance",
|
||||||
|
"Seedream",
|
||||||
|
"SVD",
|
||||||
|
"SVD XT",
|
||||||
|
"Veo 3",
|
||||||
|
"Vidu Q1",
|
||||||
|
"Wan Video",
|
||||||
|
"Wan Video 1.3B t2v",
|
||||||
|
"Wan Video 14B t2v",
|
||||||
|
"Wan Video 14B i2v 480p",
|
||||||
|
"Wan Video 14B i2v 720p",
|
||||||
|
"Wan Video 2.2 TI2V-5B",
|
||||||
|
"Wan Video 2.2 I2V-A14B",
|
||||||
|
"Wan Video 2.2 T2V-A14B",
|
||||||
|
"Wan Video 2.5 T2V",
|
||||||
|
"Wan Video 2.5 I2V",
|
||||||
|
"ZImageTurbo",
|
||||||
|
"ZImageBase"
|
||||||
|
],
|
||||||
|
"BaseModelType": [
|
||||||
|
"Standard",
|
||||||
|
"Inpainting",
|
||||||
|
"Refiner",
|
||||||
|
"Pix2Pix"
|
||||||
|
]
|
||||||
|
}
|
||||||
164
static/js/api/civitaiBaseModelApi.js
Normal file
164
static/js/api/civitaiBaseModelApi.js
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
/**
|
||||||
|
* API client for Civitai base model management
|
||||||
|
* Handles fetching and refreshing base models from Civitai API
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { showToast } from '../utils/uiHelpers.js';
|
||||||
|
|
||||||
|
const BASE_MODEL_ENDPOINTS = {
|
||||||
|
getModels: '/api/lm/base-models',
|
||||||
|
refresh: '/api/lm/base-models/refresh',
|
||||||
|
categories: '/api/lm/base-models/categories',
|
||||||
|
cacheStatus: '/api/lm/base-models/cache-status',
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Civitai Base Model API Client
|
||||||
|
*/
|
||||||
|
export class CivitaiBaseModelApi {
|
||||||
|
constructor() {
|
||||||
|
this.cache = null;
|
||||||
|
this.cacheTimestamp = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get base models (with caching)
|
||||||
|
* @param {boolean} forceRefresh - Force refresh from API
|
||||||
|
* @returns {Promise<Object>} Response with models, source, and counts
|
||||||
|
*/
|
||||||
|
async getBaseModels(forceRefresh = false) {
|
||||||
|
try {
|
||||||
|
const url = new URL(BASE_MODEL_ENDPOINTS.getModels, window.location.origin);
|
||||||
|
if (forceRefresh) {
|
||||||
|
url.searchParams.append('refresh', 'true');
|
||||||
|
}
|
||||||
|
|
||||||
|
const response = await fetch(url);
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Failed to fetch base models: ${response.statusText}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.success) {
|
||||||
|
this.cache = data.data;
|
||||||
|
this.cacheTimestamp = Date.now();
|
||||||
|
return data.data;
|
||||||
|
} else {
|
||||||
|
throw new Error(data.error || 'Failed to fetch base models');
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error fetching base models:', error);
|
||||||
|
showToast('Failed to fetch base models', { message: error.message }, 'error');
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Force refresh base models from Civitai API
|
||||||
|
* @returns {Promise<Object>} Refreshed data
|
||||||
|
*/
|
||||||
|
async refreshBaseModels() {
|
||||||
|
try {
|
||||||
|
const response = await fetch(BASE_MODEL_ENDPOINTS.refresh, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' }
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Failed to refresh base models: ${response.statusText}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.success) {
|
||||||
|
this.cache = data.data;
|
||||||
|
this.cacheTimestamp = Date.now();
|
||||||
|
showToast('Base models refreshed successfully', {}, 'success');
|
||||||
|
return data.data;
|
||||||
|
} else {
|
||||||
|
throw new Error(data.error || 'Failed to refresh base models');
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error refreshing base models:', error);
|
||||||
|
showToast('Failed to refresh base models', { message: error.message }, 'error');
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get base model categories
|
||||||
|
* @returns {Promise<Object>} Categories with model lists
|
||||||
|
*/
|
||||||
|
async getCategories() {
|
||||||
|
try {
|
||||||
|
const response = await fetch(BASE_MODEL_ENDPOINTS.categories);
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Failed to fetch categories: ${response.statusText}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.success) {
|
||||||
|
return data.data;
|
||||||
|
} else {
|
||||||
|
throw new Error(data.error || 'Failed to fetch categories');
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error fetching categories:', error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get cache status
|
||||||
|
* @returns {Promise<Object>} Cache status information
|
||||||
|
*/
|
||||||
|
async getCacheStatus() {
|
||||||
|
try {
|
||||||
|
const response = await fetch(BASE_MODEL_ENDPOINTS.cacheStatus);
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Failed to fetch cache status: ${response.statusText}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.success) {
|
||||||
|
return data.data;
|
||||||
|
} else {
|
||||||
|
throw new Error(data.error || 'Failed to fetch cache status');
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error fetching cache status:', error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get cached models (if available)
|
||||||
|
* @returns {Object|null} Cached data or null
|
||||||
|
*/
|
||||||
|
getCachedModels() {
|
||||||
|
return this.cache;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if cache is available
|
||||||
|
* @returns {boolean}
|
||||||
|
*/
|
||||||
|
hasCache() {
|
||||||
|
return this.cache !== null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get cache age in milliseconds
|
||||||
|
* @returns {number|null} Age in ms or null if no cache
|
||||||
|
*/
|
||||||
|
getCacheAge() {
|
||||||
|
if (!this.cacheTimestamp) return null;
|
||||||
|
return Date.now() - this.cacheTimestamp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Export singleton instance
|
||||||
|
export const civitaiBaseModelApi = new CivitaiBaseModelApi();
|
||||||
@@ -17,6 +17,8 @@ import { onboardingManager } from './managers/OnboardingManager.js';
|
|||||||
import { BulkContextMenu } from './components/ContextMenu/BulkContextMenu.js';
|
import { BulkContextMenu } from './components/ContextMenu/BulkContextMenu.js';
|
||||||
import { createPageContextMenu, createGlobalContextMenu } from './components/ContextMenu/index.js';
|
import { createPageContextMenu, createGlobalContextMenu } from './components/ContextMenu/index.js';
|
||||||
import { initializeEventManagement } from './utils/eventManagementInit.js';
|
import { initializeEventManagement } from './utils/eventManagementInit.js';
|
||||||
|
import { civitaiBaseModelApi } from './api/civitaiBaseModelApi.js';
|
||||||
|
import { setDynamicBaseModels } from './utils/constants.js';
|
||||||
|
|
||||||
// Core application class
|
// Core application class
|
||||||
export class AppCore {
|
export class AppCore {
|
||||||
@@ -42,6 +44,10 @@ export class AppCore {
|
|||||||
await settingsManager.waitForInitialization();
|
await settingsManager.waitForInitialization();
|
||||||
console.log('AppCore: Settings initialized');
|
console.log('AppCore: Settings initialized');
|
||||||
|
|
||||||
|
// Initialize dynamic base models (async, non-blocking)
|
||||||
|
console.log('AppCore: Initializing dynamic base models...');
|
||||||
|
this.initializeDynamicBaseModels();
|
||||||
|
|
||||||
// Initialize managers
|
// Initialize managers
|
||||||
state.loadingManager = new LoadingManager();
|
state.loadingManager = new LoadingManager();
|
||||||
modalManager.initialize();
|
modalManager.initialize();
|
||||||
@@ -116,6 +122,21 @@ export class AppCore {
|
|||||||
window.globalContextMenuInstance = createGlobalContextMenu();
|
window.globalContextMenuInstance = createGlobalContextMenu();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize dynamic base models from Civitai API
|
||||||
|
// This is non-blocking - runs in background
|
||||||
|
async initializeDynamicBaseModels() {
|
||||||
|
try {
|
||||||
|
const result = await civitaiBaseModelApi.getBaseModels();
|
||||||
|
if (result && result.models) {
|
||||||
|
setDynamicBaseModels(result.models, result.last_updated);
|
||||||
|
console.log(`AppCore: Loaded ${result.merged_count} base models (${result.hardcoded_count} hardcoded + ${result.remote_count} remote)`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.warn('AppCore: Failed to load dynamic base models:', error);
|
||||||
|
// Non-critical error - app continues with hardcoded models
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create and export a singleton instance
|
// Create and export a singleton instance
|
||||||
|
|||||||
@@ -2,7 +2,14 @@ import { modalManager } from './ModalManager.js';
|
|||||||
import { showToast } from '../utils/uiHelpers.js';
|
import { showToast } from '../utils/uiHelpers.js';
|
||||||
import { state, createDefaultSettings } from '../state/index.js';
|
import { state, createDefaultSettings } from '../state/index.js';
|
||||||
import { resetAndReload } from '../api/modelApiFactory.js';
|
import { resetAndReload } from '../api/modelApiFactory.js';
|
||||||
import { DOWNLOAD_PATH_TEMPLATES, MAPPABLE_BASE_MODELS, PATH_TEMPLATE_PLACEHOLDERS, DEFAULT_PATH_TEMPLATES, DEFAULT_PRIORITY_TAG_CONFIG } from '../utils/constants.js';
|
import {
|
||||||
|
DOWNLOAD_PATH_TEMPLATES,
|
||||||
|
MAPPABLE_BASE_MODELS,
|
||||||
|
PATH_TEMPLATE_PLACEHOLDERS,
|
||||||
|
DEFAULT_PATH_TEMPLATES,
|
||||||
|
DEFAULT_PRIORITY_TAG_CONFIG,
|
||||||
|
getMappableBaseModelsDynamic
|
||||||
|
} from '../utils/constants.js';
|
||||||
import { translate } from '../utils/i18nHelpers.js';
|
import { translate } from '../utils/i18nHelpers.js';
|
||||||
import { i18n } from '../i18n/index.js';
|
import { i18n } from '../i18n/index.js';
|
||||||
import { configureModelCardVideo } from '../components/shared/ModelCard.js';
|
import { configureModelCardVideo } from '../components/shared/ModelCard.js';
|
||||||
@@ -184,7 +191,9 @@ export class SettingsManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
getAvailableDownloadSkipBaseModels() {
|
getAvailableDownloadSkipBaseModels() {
|
||||||
return MAPPABLE_BASE_MODELS.filter(model => model !== 'Other');
|
// Use dynamic base models if available, fallback to hardcoded
|
||||||
|
const models = getMappableBaseModelsDynamic();
|
||||||
|
return models.filter(model => model !== 'Other');
|
||||||
}
|
}
|
||||||
|
|
||||||
normalizeDownloadSkipBaseModels(value) {
|
normalizeDownloadSkipBaseModels(value) {
|
||||||
@@ -1517,7 +1526,7 @@ export class SettingsManager {
|
|||||||
const row = document.createElement('div');
|
const row = document.createElement('div');
|
||||||
row.className = 'mapping-row';
|
row.className = 'mapping-row';
|
||||||
|
|
||||||
const availableModels = MAPPABLE_BASE_MODELS.filter(model => {
|
const availableModels = getMappableBaseModelsDynamic().filter(model => {
|
||||||
const existingMappings = state.global.settings.base_model_path_mappings || {};
|
const existingMappings = state.global.settings.base_model_path_mappings || {};
|
||||||
return !existingMappings.hasOwnProperty(model) || model === baseModel;
|
return !existingMappings.hasOwnProperty(model) || model === baseModel;
|
||||||
});
|
});
|
||||||
@@ -1619,7 +1628,7 @@ export class SettingsManager {
|
|||||||
const currentValue = select.value;
|
const currentValue = select.value;
|
||||||
|
|
||||||
// Get available models (not already mapped, except current)
|
// Get available models (not already mapped, except current)
|
||||||
const availableModels = MAPPABLE_BASE_MODELS.filter(model =>
|
const availableModels = getMappableBaseModelsDynamic().filter(model =>
|
||||||
!existingMappings.hasOwnProperty(model) || model === currentValue
|
!existingMappings.hasOwnProperty(model) || model === currentValue
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -50,6 +50,9 @@ export const BASE_MODELS = {
|
|||||||
SVD: "SVD",
|
SVD: "SVD",
|
||||||
LTXV: "LTXV",
|
LTXV: "LTXV",
|
||||||
LTXV2: "LTXV2",
|
LTXV2: "LTXV2",
|
||||||
|
LTXV_2_3: "LTXV 2.3",
|
||||||
|
COGVIDE_X: "CogVideoX",
|
||||||
|
MOCHI: "Mochi",
|
||||||
WAN_VIDEO: "Wan Video",
|
WAN_VIDEO: "Wan Video",
|
||||||
WAN_VIDEO_1_3B_T2V: "Wan Video 1.3B t2v",
|
WAN_VIDEO_1_3B_T2V: "Wan Video 1.3B t2v",
|
||||||
WAN_VIDEO_14B_T2V: "Wan Video 14B t2v",
|
WAN_VIDEO_14B_T2V: "Wan Video 14B t2v",
|
||||||
@@ -58,7 +61,12 @@ export const BASE_MODELS = {
|
|||||||
WAN_VIDEO_2_2_TI2V_5B: "Wan Video 2.2 TI2V-5B",
|
WAN_VIDEO_2_2_TI2V_5B: "Wan Video 2.2 TI2V-5B",
|
||||||
WAN_VIDEO_2_2_T2V_A14B: "Wan Video 2.2 T2V-A14B",
|
WAN_VIDEO_2_2_T2V_A14B: "Wan Video 2.2 T2V-A14B",
|
||||||
WAN_VIDEO_2_2_I2V_A14B: "Wan Video 2.2 I2V-A14B",
|
WAN_VIDEO_2_2_I2V_A14B: "Wan Video 2.2 I2V-A14B",
|
||||||
|
WAN_VIDEO_2_5_T2V: "Wan Video 2.5 T2V",
|
||||||
|
WAN_VIDEO_2_5_I2V: "Wan Video 2.5 I2V",
|
||||||
HUNYUAN_VIDEO: "Hunyuan Video",
|
HUNYUAN_VIDEO: "Hunyuan Video",
|
||||||
|
// Other models
|
||||||
|
ANIMA: "Anima",
|
||||||
|
PONY_V7: "Pony V7",
|
||||||
// Default
|
// Default
|
||||||
UNKNOWN: "Other"
|
UNKNOWN: "Other"
|
||||||
};
|
};
|
||||||
@@ -151,6 +159,9 @@ export const BASE_MODEL_ABBREVIATIONS = {
|
|||||||
[BASE_MODELS.SVD]: 'SVD',
|
[BASE_MODELS.SVD]: 'SVD',
|
||||||
[BASE_MODELS.LTXV]: 'LTXV',
|
[BASE_MODELS.LTXV]: 'LTXV',
|
||||||
[BASE_MODELS.LTXV2]: 'LTV2',
|
[BASE_MODELS.LTXV2]: 'LTV2',
|
||||||
|
[BASE_MODELS.LTXV_2_3]: 'LTX',
|
||||||
|
[BASE_MODELS.COGVIDE_X]: 'CVX',
|
||||||
|
[BASE_MODELS.MOCHI]: 'MCHI',
|
||||||
[BASE_MODELS.WAN_VIDEO]: 'WAN',
|
[BASE_MODELS.WAN_VIDEO]: 'WAN',
|
||||||
[BASE_MODELS.WAN_VIDEO_1_3B_T2V]: 'WAN',
|
[BASE_MODELS.WAN_VIDEO_1_3B_T2V]: 'WAN',
|
||||||
[BASE_MODELS.WAN_VIDEO_14B_T2V]: 'WAN',
|
[BASE_MODELS.WAN_VIDEO_14B_T2V]: 'WAN',
|
||||||
@@ -159,8 +170,28 @@ export const BASE_MODEL_ABBREVIATIONS = {
|
|||||||
[BASE_MODELS.WAN_VIDEO_2_2_TI2V_5B]: 'WAN',
|
[BASE_MODELS.WAN_VIDEO_2_2_TI2V_5B]: 'WAN',
|
||||||
[BASE_MODELS.WAN_VIDEO_2_2_T2V_A14B]: 'WAN',
|
[BASE_MODELS.WAN_VIDEO_2_2_T2V_A14B]: 'WAN',
|
||||||
[BASE_MODELS.WAN_VIDEO_2_2_I2V_A14B]: 'WAN',
|
[BASE_MODELS.WAN_VIDEO_2_2_I2V_A14B]: 'WAN',
|
||||||
|
[BASE_MODELS.WAN_VIDEO_2_5_T2V]: 'WAN',
|
||||||
|
[BASE_MODELS.WAN_VIDEO_2_5_I2V]: 'WAN',
|
||||||
[BASE_MODELS.HUNYUAN_VIDEO]: 'HYV',
|
[BASE_MODELS.HUNYUAN_VIDEO]: 'HYV',
|
||||||
|
|
||||||
|
// Other diffusion models
|
||||||
|
[BASE_MODELS.AURAFLOW]: 'AF',
|
||||||
|
[BASE_MODELS.CHROMA]: 'CHR',
|
||||||
|
[BASE_MODELS.PIXART_A]: 'PXA',
|
||||||
|
[BASE_MODELS.PIXART_E]: 'PXE',
|
||||||
|
[BASE_MODELS.HUNYUAN_1]: 'HY',
|
||||||
|
[BASE_MODELS.LUMINA]: 'L',
|
||||||
|
[BASE_MODELS.KOLORS]: 'KLR',
|
||||||
|
[BASE_MODELS.NOOBAI]: 'NAI',
|
||||||
|
[BASE_MODELS.ILLUSTRIOUS]: 'IL',
|
||||||
|
[BASE_MODELS.PONY]: 'PONY',
|
||||||
|
[BASE_MODELS.PONY_V7]: 'PNY7',
|
||||||
|
[BASE_MODELS.HIDREAM]: 'HID',
|
||||||
|
[BASE_MODELS.QWEN]: 'QWEN',
|
||||||
|
[BASE_MODELS.ZIMAGE_TURBO]: 'ZIT',
|
||||||
|
[BASE_MODELS.ZIMAGE_BASE]: 'ZIB',
|
||||||
|
[BASE_MODELS.ANIMA]: 'ANI',
|
||||||
|
|
||||||
// Default
|
// Default
|
||||||
[BASE_MODELS.UNKNOWN]: 'OTH'
|
[BASE_MODELS.UNKNOWN]: 'OTH'
|
||||||
};
|
};
|
||||||
@@ -349,18 +380,20 @@ export const BASE_MODEL_CATEGORIES = {
|
|||||||
'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO],
|
'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO],
|
||||||
'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
|
'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
|
||||||
'Video Models': [
|
'Video Models': [
|
||||||
BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.LTXV2, BASE_MODELS.HUNYUAN_VIDEO, BASE_MODELS.WAN_VIDEO,
|
BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.LTXV2, BASE_MODELS.LTXV_2_3,
|
||||||
BASE_MODELS.WAN_VIDEO_1_3B_T2V, BASE_MODELS.WAN_VIDEO_14B_T2V,
|
BASE_MODELS.COGVIDE_X, BASE_MODELS.MOCHI, BASE_MODELS.HUNYUAN_VIDEO,
|
||||||
|
BASE_MODELS.WAN_VIDEO, BASE_MODELS.WAN_VIDEO_1_3B_T2V, BASE_MODELS.WAN_VIDEO_14B_T2V,
|
||||||
BASE_MODELS.WAN_VIDEO_14B_I2V_480P, BASE_MODELS.WAN_VIDEO_14B_I2V_720P,
|
BASE_MODELS.WAN_VIDEO_14B_I2V_480P, BASE_MODELS.WAN_VIDEO_14B_I2V_720P,
|
||||||
BASE_MODELS.WAN_VIDEO_2_2_TI2V_5B, BASE_MODELS.WAN_VIDEO_2_2_T2V_A14B,
|
BASE_MODELS.WAN_VIDEO_2_2_TI2V_5B, BASE_MODELS.WAN_VIDEO_2_2_T2V_A14B,
|
||||||
BASE_MODELS.WAN_VIDEO_2_2_I2V_A14B
|
BASE_MODELS.WAN_VIDEO_2_2_I2V_A14B, BASE_MODELS.WAN_VIDEO_2_5_T2V,
|
||||||
|
BASE_MODELS.WAN_VIDEO_2_5_I2V
|
||||||
],
|
],
|
||||||
'Flux Models': [BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.FLUX_1_KONTEXT, BASE_MODELS.FLUX_1_KREA, BASE_MODELS.FLUX_2_D, BASE_MODELS.FLUX_2_KLEIN_9B, BASE_MODELS.FLUX_2_KLEIN_9B_BASE, BASE_MODELS.FLUX_2_KLEIN_4B, BASE_MODELS.FLUX_2_KLEIN_4B_BASE],
|
'Flux Models': [BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.FLUX_1_KONTEXT, BASE_MODELS.FLUX_1_KREA, BASE_MODELS.FLUX_2_D, BASE_MODELS.FLUX_2_KLEIN_9B, BASE_MODELS.FLUX_2_KLEIN_9B_BASE, BASE_MODELS.FLUX_2_KLEIN_4B, BASE_MODELS.FLUX_2_KLEIN_4B_BASE],
|
||||||
'Other Models': [
|
'Other Models': [
|
||||||
BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.HIDREAM,
|
BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.PONY_V7, BASE_MODELS.HIDREAM,
|
||||||
BASE_MODELS.QWEN, BASE_MODELS.AURAFLOW, BASE_MODELS.CHROMA, BASE_MODELS.ZIMAGE_TURBO, BASE_MODELS.ZIMAGE_BASE,
|
BASE_MODELS.QWEN, BASE_MODELS.AURAFLOW, BASE_MODELS.CHROMA, BASE_MODELS.ZIMAGE_TURBO, BASE_MODELS.ZIMAGE_BASE,
|
||||||
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
|
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
|
||||||
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI,
|
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI, BASE_MODELS.ANIMA,
|
||||||
BASE_MODELS.UNKNOWN
|
BASE_MODELS.UNKNOWN
|
||||||
]
|
]
|
||||||
};
|
};
|
||||||
@@ -378,3 +411,94 @@ export const DEFAULT_PRIORITY_TAG_CONFIG = {
|
|||||||
checkpoint: DEFAULT_PRIORITY_TAG_ENTRIES.join(', '),
|
checkpoint: DEFAULT_PRIORITY_TAG_ENTRIES.join(', '),
|
||||||
embedding: DEFAULT_PRIORITY_TAG_ENTRIES.join(', ')
|
embedding: DEFAULT_PRIORITY_TAG_ENTRIES.join(', ')
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Dynamic Base Model Support
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Dynamic base model cache
|
||||||
|
* Stores models fetched from Civitai API
|
||||||
|
*/
|
||||||
|
let dynamicBaseModels = null;
|
||||||
|
let dynamicBaseModelsTimestamp = null;
|
||||||
|
const CACHE_TTL_MS = 7 * 24 * 60 * 60 * 1000; // 7 days
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set dynamic base models (called after fetching from API)
|
||||||
|
* @param {Array} models - Array of base model names
|
||||||
|
* @param {string} timestamp - ISO timestamp of fetch
|
||||||
|
*/
|
||||||
|
export function setDynamicBaseModels(models, timestamp) {
|
||||||
|
dynamicBaseModels = models;
|
||||||
|
dynamicBaseModelsTimestamp = timestamp;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get dynamic base models
|
||||||
|
* @returns {Object|null} { models, timestamp } or null if not set
|
||||||
|
*/
|
||||||
|
export function getDynamicBaseModels() {
|
||||||
|
if (!dynamicBaseModels) return null;
|
||||||
|
|
||||||
|
// Check if cache is expired
|
||||||
|
if (dynamicBaseModelsTimestamp) {
|
||||||
|
const age = Date.now() - new Date(dynamicBaseModelsTimestamp).getTime();
|
||||||
|
if (age > CACHE_TTL_MS) {
|
||||||
|
dynamicBaseModels = null;
|
||||||
|
dynamicBaseModelsTimestamp = null;
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
models: dynamicBaseModels,
|
||||||
|
timestamp: dynamicBaseModelsTimestamp
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get merged base models (hardcoded + dynamic)
|
||||||
|
* Returns unique sorted list of all available base models
|
||||||
|
* @returns {Array} Sorted array of base model names
|
||||||
|
*/
|
||||||
|
export function getMergedBaseModels() {
|
||||||
|
const hardcoded = Object.values(BASE_MODELS);
|
||||||
|
const dynamic = getDynamicBaseModels();
|
||||||
|
|
||||||
|
if (!dynamic || !dynamic.models) {
|
||||||
|
return hardcoded.sort();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge and deduplicate
|
||||||
|
const merged = new Set([...hardcoded, ...dynamic.models]);
|
||||||
|
return Array.from(merged).sort();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get mappable base models (for UI selection)
|
||||||
|
* Excludes 'Other' value
|
||||||
|
* @returns {Array} Sorted array of base model names (excluding 'Other')
|
||||||
|
*/
|
||||||
|
export function getMappableBaseModelsDynamic() {
|
||||||
|
const merged = getMergedBaseModels();
|
||||||
|
return merged.filter(model => model !== 'Other');
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clear dynamic base models cache
|
||||||
|
*/
|
||||||
|
export function clearDynamicBaseModels() {
|
||||||
|
dynamicBaseModels = null;
|
||||||
|
dynamicBaseModelsTimestamp = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if dynamic base models cache is valid
|
||||||
|
* @returns {boolean}
|
||||||
|
*/
|
||||||
|
export function isDynamicBaseModelsCacheValid() {
|
||||||
|
if (!dynamicBaseModels || !dynamicBaseModelsTimestamp) return false;
|
||||||
|
const age = Date.now() - new Date(dynamicBaseModelsTimestamp).getTime();
|
||||||
|
return age <= CACHE_TTL_MS;
|
||||||
|
}
|
||||||
|
|||||||
133
tests/services/test_civitai_base_model_service.py
Normal file
133
tests/services/test_civitai_base_model_service.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""Tests for CivitaiBaseModelService."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from py.services.civitai_base_model_service import CivitaiBaseModelService
|
||||||
|
|
||||||
|
|
||||||
|
class TestCivitaiBaseModelService:
|
||||||
|
"""Test suite for CivitaiBaseModelService."""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_service(self):
|
||||||
|
"""Create a fresh service instance for each test."""
|
||||||
|
self.service = CivitaiBaseModelService()
|
||||||
|
# Reset cache
|
||||||
|
self.service._cache = None
|
||||||
|
self.service._cache_timestamp = None
|
||||||
|
yield
|
||||||
|
|
||||||
|
def test_generate_abbreviation_known_models(self):
|
||||||
|
"""Test abbreviation generation for known models."""
|
||||||
|
test_cases = [
|
||||||
|
("SD 1.5", "SD1"),
|
||||||
|
("SDXL 1.0", "XL"),
|
||||||
|
("Flux.1 D", "F1D"),
|
||||||
|
("Wan Video 2.5 T2V", "WAN"),
|
||||||
|
("Pony V7", "PNY7"),
|
||||||
|
("CogVideoX", "CVX"),
|
||||||
|
("Mochi", "MCHI"),
|
||||||
|
("Anima", "ANI"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for model_name, expected in test_cases:
|
||||||
|
result = self.service.generate_abbreviation(model_name)
|
||||||
|
assert result == expected, (
|
||||||
|
f"Failed for {model_name}: got {result}, expected {expected}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_generate_abbreviation_unknown_models(self):
|
||||||
|
"""Test abbreviation generation for unknown models."""
|
||||||
|
result = self.service.generate_abbreviation("New Model 2.0")
|
||||||
|
assert len(result) <= 4
|
||||||
|
assert result.isupper()
|
||||||
|
|
||||||
|
def test_generate_abbreviation_edge_cases(self):
|
||||||
|
"""Test abbreviation generation edge cases."""
|
||||||
|
assert self.service.generate_abbreviation("") == "OTH"
|
||||||
|
assert self.service.generate_abbreviation(None) == "OTH"
|
||||||
|
|
||||||
|
def test_cache_status_no_cache(self):
|
||||||
|
"""Test cache status when no cache exists."""
|
||||||
|
status = self.service.get_cache_status()
|
||||||
|
|
||||||
|
assert status["has_cache"] is False
|
||||||
|
assert status["last_updated"] is None
|
||||||
|
assert status["is_expired"] is True
|
||||||
|
assert status["age_seconds"] is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_base_models_fallback(self):
|
||||||
|
"""Test that fallback to hardcoded models works."""
|
||||||
|
with patch.object(self.service, "_fetch_from_civitai", return_value=None):
|
||||||
|
result = await self.service.get_base_models()
|
||||||
|
|
||||||
|
assert result["source"] == "fallback"
|
||||||
|
assert len(result["models"]) > 0
|
||||||
|
assert result["hardcoded_count"] > 0
|
||||||
|
assert result["remote_count"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_base_models_from_api(self):
|
||||||
|
"""Test fetching models from API."""
|
||||||
|
mock_models = {"SD 1.5", "SDXL 1.0", "New Model"}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
self.service, "_fetch_from_civitai", return_value=mock_models
|
||||||
|
):
|
||||||
|
result = await self.service.get_base_models()
|
||||||
|
|
||||||
|
assert result["source"] == "api"
|
||||||
|
assert result["remote_count"] == 3
|
||||||
|
assert "New Model" in result["models"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_base_models_uses_cache(self):
|
||||||
|
"""Test that cached data is used when available and not expired."""
|
||||||
|
# First call - populate cache
|
||||||
|
mock_models = {"SD 1.5", "SDXL 1.0"}
|
||||||
|
with patch.object(
|
||||||
|
self.service, "_fetch_from_civitai", return_value=mock_models
|
||||||
|
):
|
||||||
|
await self.service.get_base_models()
|
||||||
|
|
||||||
|
# Second call - should use cache
|
||||||
|
with patch.object(self.service, "_fetch_from_civitai") as mock_fetch:
|
||||||
|
result = await self.service.get_base_models()
|
||||||
|
mock_fetch.assert_not_called()
|
||||||
|
|
||||||
|
assert result["source"] == "cache"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_cache(self):
|
||||||
|
"""Test force refresh clears cache and fetches fresh data."""
|
||||||
|
# Populate cache
|
||||||
|
mock_models = {"SD 1.5"}
|
||||||
|
with patch.object(
|
||||||
|
self.service, "_fetch_from_civitai", return_value=mock_models
|
||||||
|
):
|
||||||
|
await self.service.get_base_models()
|
||||||
|
|
||||||
|
# Force refresh with different data
|
||||||
|
new_models = {"SD 1.5", "SDXL 1.0", "New Model"}
|
||||||
|
with patch.object(self.service, "_fetch_from_civitai", return_value=new_models):
|
||||||
|
result = await self.service.refresh_cache()
|
||||||
|
|
||||||
|
assert result["source"] == "api"
|
||||||
|
assert result["remote_count"] == 3
|
||||||
|
|
||||||
|
def test_get_model_categories(self):
|
||||||
|
"""Test model categories are returned."""
|
||||||
|
categories = self.service.get_model_categories()
|
||||||
|
|
||||||
|
assert "Stable Diffusion 1.x" in categories
|
||||||
|
assert "Video Models" in categories
|
||||||
|
assert "Flux Models" in categories
|
||||||
|
assert "Other Models" in categories
|
||||||
|
|
||||||
|
# Check that video models include new additions
|
||||||
|
video_models = categories["Video Models"]
|
||||||
|
assert "CogVideoX" in video_models
|
||||||
|
assert "Mochi" in video_models
|
||||||
|
assert "Wan Video 2.5 T2V" in video_models
|
||||||
Reference in New Issue
Block a user