feat: Dynamic base model fetching from Civitai API (#854)

Implement automatic fetching of base models from Civitai API to keep
data up-to-date without manual updates.

Backend:
- Add CivitaiBaseModelService with 7-day TTL caching
- Add /api/lm/base-models endpoints for fetching and refreshing
- Merge hardcoded and remote models for backward compatibility
- Smart abbreviation generation for unknown models

Frontend:
- Add civitaiBaseModelApi client for API communication
- Dynamic base model loading on app initialization
- Update SettingsManager to use merged model lists
- Add support for 8 new models: Anima, CogVideoX, LTXV 2.3, Mochi,
  Pony V7, Wan Video 2.5 T2V/I2V

API Endpoints:
- GET /api/lm/base-models - Get merged models
- POST /api/lm/base-models/refresh - Force refresh
- GET /api/lm/base-models/categories - Get categories
- GET /api/lm/base-models/cache-status - Check cache status

Closes #854
This commit is contained in:
Will Miao
2026-03-29 00:18:15 +08:00
parent 89b1675ec7
commit 00f5c1e887
12 changed files with 1227 additions and 9 deletions

View File

@@ -0,0 +1,141 @@
"""Handlers for base model related endpoints."""
from __future__ import annotations
import logging
from typing import Any, Awaitable, Callable, Dict
from aiohttp import web
from ...services.civitai_base_model_service import get_civitai_base_model_service
logger = logging.getLogger(__name__)
class BaseModelHandlerSet:
"""Collection of handlers for base model operations."""
def __init__(
self,
base_model_service_factory: Callable[[], Any] = get_civitai_base_model_service,
) -> None:
self._base_model_service_factory = base_model_service_factory
def to_route_mapping(
self,
) -> Dict[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
"""Return mapping of route names to handler methods."""
return {
"get_base_models": self.get_base_models,
"refresh_base_models": self.refresh_base_models,
"get_base_model_categories": self.get_base_model_categories,
"get_base_model_cache_status": self.get_base_model_cache_status,
}
async def get_base_models(self, request: web.Request) -> web.Response:
"""Get merged base models (hardcoded + remote from Civitai).
Query Parameters:
refresh: If 'true', force refresh from API
Returns:
JSON response with:
- models: List of base model names
- source: 'cache', 'api', or 'fallback'
- last_updated: ISO timestamp
- counts: hardcoded_count, remote_count, merged_count
"""
try:
service = await self._base_model_service_factory()
# Check for refresh parameter
force_refresh = request.query.get("refresh", "").lower() == "true"
result = await service.get_base_models(force_refresh=force_refresh)
return web.json_response(
{
"success": True,
"data": result,
}
)
except Exception as e:
logger.error(f"Error in get_base_models: {e}")
return web.json_response(
{"success": False, "error": str(e)},
status=500,
)
async def refresh_base_models(self, request: web.Request) -> web.Response:
"""Force refresh base models from Civitai API.
Returns:
JSON response with refreshed data
"""
try:
service = await self._base_model_service_factory()
result = await service.refresh_cache()
return web.json_response(
{
"success": True,
"data": result,
"message": "Base models cache refreshed successfully",
}
)
except Exception as e:
logger.error(f"Error in refresh_base_models: {e}")
return web.json_response(
{"success": False, "error": str(e)},
status=500,
)
async def get_base_model_categories(self, request: web.Request) -> web.Response:
"""Get categorized base models.
Returns:
JSON response with categorized models
"""
try:
service = await self._base_model_service_factory()
categories = service.get_model_categories()
return web.json_response(
{
"success": True,
"data": categories,
}
)
except Exception as e:
logger.error(f"Error in get_base_model_categories: {e}")
return web.json_response(
{"success": False, "error": str(e)},
status=500,
)
async def get_base_model_cache_status(self, request: web.Request) -> web.Response:
"""Get cache status for base models.
Returns:
JSON response with cache status
"""
try:
service = await self._base_model_service_factory()
status = service.get_cache_status()
return web.json_response(
{
"success": True,
"data": status,
}
)
except Exception as e:
logger.error(f"Error in get_base_model_cache_status: {e}")
return web.json_response(
{"success": False, "error": str(e)},
status=500,
)

View File

@@ -40,6 +40,7 @@ from ...utils.civitai_utils import rewrite_preview_url
from ...utils.example_images_paths import is_valid_example_images_root from ...utils.example_images_paths import is_valid_example_images_root
from ...utils.lora_metadata import extract_trained_words from ...utils.lora_metadata import extract_trained_words
from ...utils.usage_stats import UsageStats from ...utils.usage_stats import UsageStats
from .base_model_handlers import BaseModelHandlerSet
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -1618,6 +1619,7 @@ class MiscHandlerSet:
custom_words: CustomWordsHandler, custom_words: CustomWordsHandler,
supporters: SupportersHandler, supporters: SupportersHandler,
example_workflows: ExampleWorkflowsHandler, example_workflows: ExampleWorkflowsHandler,
base_model: BaseModelHandlerSet,
) -> None: ) -> None:
self.health = health self.health = health
self.settings = settings self.settings = settings
@@ -1632,6 +1634,7 @@ class MiscHandlerSet:
self.custom_words = custom_words self.custom_words = custom_words
self.supporters = supporters self.supporters = supporters
self.example_workflows = example_workflows self.example_workflows = example_workflows
self.base_model = base_model
def to_route_mapping( def to_route_mapping(
self, self,
@@ -1663,6 +1666,11 @@ class MiscHandlerSet:
"get_supporters": self.supporters.get_supporters, "get_supporters": self.supporters.get_supporters,
"get_example_workflows": self.example_workflows.get_example_workflows, "get_example_workflows": self.example_workflows.get_example_workflows,
"get_example_workflow": self.example_workflows.get_example_workflow, "get_example_workflow": self.example_workflows.get_example_workflow,
# Base model handlers
"get_base_models": self.base_model.get_base_models,
"refresh_base_models": self.base_model.refresh_base_models,
"get_base_model_categories": self.base_model.get_base_model_categories,
"get_base_model_cache_status": self.base_model.get_base_model_cache_status,
} }

View File

@@ -56,6 +56,15 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition( RouteDefinition(
"GET", "/api/lm/example-workflows/{filename}", "get_example_workflow" "GET", "/api/lm/example-workflows/{filename}", "get_example_workflow"
), ),
# Base model management routes
RouteDefinition("GET", "/api/lm/base-models", "get_base_models"),
RouteDefinition("POST", "/api/lm/base-models/refresh", "refresh_base_models"),
RouteDefinition(
"GET", "/api/lm/base-models/categories", "get_base_model_categories"
),
RouteDefinition(
"GET", "/api/lm/base-models/cache-status", "get_base_model_cache_status"
),
) )

View File

@@ -35,6 +35,7 @@ from .handlers.misc_handlers import (
UsageStatsHandler, UsageStatsHandler,
build_service_registry_adapter, build_service_registry_adapter,
) )
from .handlers.base_model_handlers import BaseModelHandlerSet
from .misc_route_registrar import MiscRouteRegistrar from .misc_route_registrar import MiscRouteRegistrar
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -128,6 +129,7 @@ class MiscRoutes:
custom_words = CustomWordsHandler() custom_words = CustomWordsHandler()
supporters = SupportersHandler() supporters = SupportersHandler()
example_workflows = ExampleWorkflowsHandler() example_workflows = ExampleWorkflowsHandler()
base_model = BaseModelHandlerSet()
return self._handler_set_factory( return self._handler_set_factory(
health=health, health=health,
@@ -143,6 +145,7 @@ class MiscRoutes:
custom_words=custom_words, custom_words=custom_words,
supporters=supporters, supporters=supporters,
example_workflows=example_workflows, example_workflows=example_workflows,
base_model=base_model,
) )

View File

@@ -0,0 +1,430 @@
from __future__ import annotations
import asyncio
import json
import logging
import re
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set, Tuple
from ..utils.constants import SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS
from .downloader import get_downloader
logger = logging.getLogger(__name__)
class CivitaiBaseModelService:
"""Service for fetching and managing Civitai base models.
This service provides:
- Fetching base models from Civitai API
- Caching with TTL (7 days default)
- Merging hardcoded and remote base models
- Generating abbreviations for new/unknown models
"""
_instance: Optional[CivitaiBaseModelService] = None
_lock = asyncio.Lock()
# Default TTL for cache in seconds (7 days)
DEFAULT_CACHE_TTL = 7 * 24 * 60 * 60
# Civitai API endpoint for enums
CIVITAI_ENUMS_URL = "https://civitai.com/api/v1/enums"
@classmethod
async def get_instance(cls) -> CivitaiBaseModelService:
"""Get singleton instance of the service."""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
"""Initialize the service."""
if hasattr(self, "_initialized"):
return
self._initialized = True
# Cache storage
self._cache: Optional[Dict[str, Any]] = None
self._cache_timestamp: Optional[datetime] = None
self._cache_ttl = self.DEFAULT_CACHE_TTL
# Hardcoded models for fallback
self._hardcoded_models = set(SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS)
logger.info("CivitaiBaseModelService initialized")
async def get_base_models(self, force_refresh: bool = False) -> Dict[str, Any]:
"""Get merged base models (hardcoded + remote).
Args:
force_refresh: If True, fetch from API regardless of cache state.
Returns:
Dictionary containing:
- models: List of merged base model names
- source: 'cache', 'api', or 'fallback'
- last_updated: ISO timestamp of last successful API fetch
- hardcoded_count: Number of hardcoded models
- remote_count: Number of remote models
- merged_count: Total unique models
"""
# Check if cache is valid
if not force_refresh and self._is_cache_valid():
logger.debug("Returning cached base models")
return self._build_response("cache")
# Try to fetch from API
try:
remote_models = await self._fetch_from_civitai()
if remote_models:
self._update_cache(remote_models)
return self._build_response("api")
except Exception as e:
logger.error(f"Failed to fetch base models from Civitai: {e}")
# Fallback to hardcoded models
return self._build_response("fallback")
async def refresh_cache(self) -> Dict[str, Any]:
"""Force refresh the cache from Civitai API.
Returns:
Response dict same as get_base_models()
"""
return await self.get_base_models(force_refresh=True)
def get_cache_status(self) -> Dict[str, Any]:
"""Get current cache status.
Returns:
Dictionary containing:
- has_cache: Whether cache exists
- last_updated: ISO timestamp or None
- is_expired: Whether cache is expired
- ttl_seconds: TTL in seconds
- age_seconds: Age of cache in seconds (if exists)
"""
if self._cache is None or self._cache_timestamp is None:
return {
"has_cache": False,
"last_updated": None,
"is_expired": True,
"ttl_seconds": self._cache_ttl,
"age_seconds": None,
}
age = (datetime.now(timezone.utc) - self._cache_timestamp).total_seconds()
return {
"has_cache": True,
"last_updated": self._cache_timestamp.isoformat(),
"is_expired": age > self._cache_ttl,
"ttl_seconds": self._cache_ttl,
"age_seconds": int(age),
}
def generate_abbreviation(self, model_name: str) -> str:
"""Generate abbreviation for a base model name.
Algorithm:
1. Extract version patterns (e.g., "2.5" from "Wan Video 2.5")
2. Extract main acronym (e.g., "SD" from "SD 1.5")
3. Handle special cases (Flux, Wan, etc.)
4. Fallback to first letters of words (max 4 chars)
Args:
model_name: Full base model name
Returns:
Generated abbreviation (max 4 characters)
"""
if not model_name or not isinstance(model_name, str):
return "OTH"
name = model_name.strip()
if not name:
return "OTH"
# Check if it's already in hardcoded abbreviations
# This is a simplified check - in practice you'd have a mapping
lower_name = name.lower()
# Special cases
special_cases = {
"sd 1.4": "SD1",
"sd 1.5": "SD1",
"sd 1.5 lcm": "SD1",
"sd 1.5 hyper": "SD1",
"sd 2.0": "SD2",
"sd 2.1": "SD2",
"sd 3": "SD3",
"sd 3.5": "SD3",
"sd 3.5 medium": "SD3",
"sd 3.5 large": "SD3",
"sd 3.5 large turbo": "SD3",
"sdxl 1.0": "XL",
"sdxl lightning": "XL",
"sdxl hyper": "XL",
"flux.1 d": "F1D",
"flux.1 s": "F1S",
"flux.1 krea": "F1KR",
"flux.1 kontext": "F1KX",
"flux.2 d": "F2D",
"flux.2 klein 9b": "FK9",
"flux.2 klein 9b-base": "FK9B",
"flux.2 klein 4b": "FK4",
"flux.2 klein 4b-base": "FK4B",
"auraflow": "AF",
"chroma": "CHR",
"pixart a": "PXA",
"pixart e": "PXE",
"hunyuan 1": "HY",
"hunyuan video": "HYV",
"lumina": "L",
"kolors": "KLR",
"noobai": "NAI",
"illustrious": "IL",
"pony": "PONY",
"pony v7": "PNY7",
"hidream": "HID",
"qwen": "QWEN",
"zimageturbo": "ZIT",
"zimagebase": "ZIB",
"anima": "ANI",
"svd": "SVD",
"ltxv": "LTXV",
"ltxv2": "LTV2",
"ltxv 2.3": "LTX",
"cogvideox": "CVX",
"mochi": "MCHI",
"wan video": "WAN",
"wan video 1.3b t2v": "WAN",
"wan video 14b t2v": "WAN",
"wan video 14b i2v 480p": "WAN",
"wan video 14b i2v 720p": "WAN",
"wan video 2.2 ti2v-5b": "WAN",
"wan video 2.2 t2v-a14b": "WAN",
"wan video 2.2 i2v-a14b": "WAN",
"wan video 2.5 t2v": "WAN",
"wan video 2.5 i2v": "WAN",
}
if lower_name in special_cases:
return special_cases[lower_name]
# Try to extract acronym from version pattern
# e.g., "Model Name 2.5" -> "MN25"
version_match = re.search(r"(\d+(?:\.\d+)?)", name)
version = version_match.group(1) if version_match else ""
# Remove version and common words
words = re.sub(r"\d+(?:\.\d+)?", "", name)
words = re.sub(
r"\b(model|video|diffusion|checkpoint|textualinversion)\b",
"",
words,
flags=re.I,
)
words = words.strip()
# Get first letters of remaining words
tokens = re.findall(r"[A-Za-z]+", words)
if tokens:
# Build abbreviation from first letters
abbrev = "".join(token[0].upper() for token in tokens)
# Add version if present
if version:
# Clean version (remove dots for abbreviation)
version_clean = version.replace(".", "")
abbrev = abbrev[: 4 - len(version_clean)] + version_clean
return abbrev[:4]
# Final fallback: just take first 4 alphanumeric chars
alphanumeric = re.sub(r"[^A-Za-z0-9]", "", name)
if alphanumeric:
return alphanumeric[:4].upper()
return "OTH"
async def _fetch_from_civitai(self) -> Optional[Set[str]]:
"""Fetch base models from Civitai API.
Returns:
Set of base model names, or None if failed
"""
try:
downloader = await get_downloader()
success, result = await downloader.make_request(
"GET",
self.CIVITAI_ENUMS_URL,
use_auth=False, # enums endpoint doesn't require auth
)
if not success:
logger.warning(f"Failed to fetch enums from Civitai: {result}")
return None
if isinstance(result, str):
data = json.loads(result)
else:
data = result
# Extract base models from response
base_models = set()
# Use ActiveBaseModel if available (recommended active models)
if "ActiveBaseModel" in data:
base_models.update(data["ActiveBaseModel"])
logger.info(f"Fetched {len(base_models)} models from ActiveBaseModel")
# Fallback to full BaseModel list
elif "BaseModel" in data:
base_models.update(data["BaseModel"])
logger.info(f"Fetched {len(base_models)} models from BaseModel")
else:
logger.warning("No base model data found in Civitai response")
return None
return base_models
except Exception as e:
logger.error(f"Error fetching from Civitai: {e}")
return None
def _update_cache(self, remote_models: Set[str]) -> None:
"""Update internal cache with fetched models.
Args:
remote_models: Set of base model names from API
"""
self._cache = {
"remote_models": sorted(remote_models),
"hardcoded_models": sorted(self._hardcoded_models),
}
self._cache_timestamp = datetime.now(timezone.utc)
logger.info(f"Cache updated with {len(remote_models)} remote models")
def _is_cache_valid(self) -> bool:
"""Check if current cache is valid (not expired).
Returns:
True if cache exists and is not expired
"""
if self._cache is None or self._cache_timestamp is None:
return False
age = (datetime.now(timezone.utc) - self._cache_timestamp).total_seconds()
return age <= self._cache_ttl
def _build_response(self, source: str) -> Dict[str, Any]:
"""Build response dictionary.
Args:
source: 'cache', 'api', or 'fallback'
Returns:
Response dictionary
"""
if source == "fallback" or self._cache is None:
# Use only hardcoded models
merged = sorted(self._hardcoded_models)
return {
"models": merged,
"source": source,
"last_updated": None,
"hardcoded_count": len(self._hardcoded_models),
"remote_count": 0,
"merged_count": len(merged),
}
# Merge hardcoded and remote models
remote_set = set(self._cache.get("remote_models", []))
merged = sorted(self._hardcoded_models | remote_set)
return {
"models": merged,
"source": source,
"last_updated": self._cache_timestamp.isoformat()
if self._cache_timestamp
else None,
"hardcoded_count": len(self._hardcoded_models),
"remote_count": len(remote_set),
"merged_count": len(merged),
}
def get_model_categories(self) -> Dict[str, List[str]]:
"""Get categorized base models.
Returns:
Dictionary mapping category names to lists of model names
"""
# Define category patterns
categories = {
"Stable Diffusion 1.x": ["SD 1.4", "SD 1.5", "SD 1.5 LCM", "SD 1.5 Hyper"],
"Stable Diffusion 2.x": ["SD 2.0", "SD 2.1"],
"Stable Diffusion 3.x": [
"SD 3",
"SD 3.5",
"SD 3.5 Medium",
"SD 3.5 Large",
"SD 3.5 Large Turbo",
],
"SDXL": ["SDXL 1.0", "SDXL Lightning", "SDXL Hyper"],
"Flux Models": [
"Flux.1 D",
"Flux.1 S",
"Flux.1 Krea",
"Flux.1 Kontext",
"Flux.2 D",
"Flux.2 Klein 9B",
"Flux.2 Klein 9B-base",
"Flux.2 Klein 4B",
"Flux.2 Klein 4B-base",
],
"Video Models": [
"SVD",
"LTXV",
"LTXV2",
"LTXV 2.3",
"CogVideoX",
"Mochi",
"Hunyuan Video",
"Wan Video",
"Wan Video 1.3B t2v",
"Wan Video 14B t2v",
"Wan Video 14B i2v 480p",
"Wan Video 14B i2v 720p",
"Wan Video 2.2 TI2V-5B",
"Wan Video 2.2 T2V-A14B",
"Wan Video 2.2 I2V-A14B",
"Wan Video 2.5 T2V",
"Wan Video 2.5 I2V",
],
"Other Models": [
"Illustrious",
"Pony",
"Pony V7",
"HiDream",
"Qwen",
"AuraFlow",
"Chroma",
"ZImageTurbo",
"ZImageBase",
"PixArt a",
"PixArt E",
"Hunyuan 1",
"Lumina",
"Kolors",
"NoobAI",
"Anima",
],
}
return categories
# Convenience function for getting the singleton instance
async def get_civitai_base_model_service() -> CivitaiBaseModelService:
"""Get the singleton instance of CivitaiBaseModelService."""
return await CivitaiBaseModelService.get_instance()

View File

@@ -110,6 +110,8 @@ DIFFUSION_MODEL_BASE_MODELS = frozenset(
"Wan Video 2.2 T2V-A14B", "Wan Video 2.2 T2V-A14B",
"Wan Video 2.5 T2V", "Wan Video 2.5 T2V",
"Wan Video 2.5 I2V", "Wan Video 2.5 I2V",
"CogVideoX",
"Mochi",
"Qwen", "Qwen",
] ]
) )
@@ -151,6 +153,7 @@ SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS = frozenset(
"NoobAI", "NoobAI",
"Illustrious", "Illustrious",
"Pony", "Pony",
"Pony V7",
"HiDream", "HiDream",
"Qwen", "Qwen",
"ZImageTurbo", "ZImageTurbo",
@@ -158,6 +161,9 @@ SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS = frozenset(
"SVD", "SVD",
"LTXV", "LTXV",
"LTXV2", "LTXV2",
"LTXV 2.3",
"CogVideoX",
"Mochi",
"Wan Video", "Wan Video",
"Wan Video 1.3B t2v", "Wan Video 1.3B t2v",
"Wan Video 14B t2v", "Wan Video 14B t2v",
@@ -166,6 +172,9 @@ SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS = frozenset(
"Wan Video 2.2 TI2V-5B", "Wan Video 2.2 TI2V-5B",
"Wan Video 2.2 T2V-A14B", "Wan Video 2.2 T2V-A14B",
"Wan Video 2.2 I2V-A14B", "Wan Video 2.2 I2V-A14B",
"Wan Video 2.5 T2V",
"Wan Video 2.5 I2V",
"Hunyuan Video", "Hunyuan Video",
"Anima",
] ]
) )

167
refs/enums.json Normal file
View File

@@ -0,0 +1,167 @@
{
"ModelType": [
"Checkpoint",
"TextualInversion",
"Hypernetwork",
"AestheticGradient",
"LORA",
"LoCon",
"DoRA",
"Controlnet",
"Upscaler",
"MotionModule",
"VAE",
"Poses",
"Wildcards",
"Workflows",
"Detection",
"Other"
],
"ModelFileType": [
"Model",
"Text Encoder",
"Pruned Model",
"Negative",
"Training Data",
"VAE",
"Config",
"Archive"
],
"ActiveBaseModel": [
"Anima",
"AuraFlow",
"Chroma",
"CogVideoX",
"Flux.1 S",
"Flux.1 D",
"Flux.1 Krea",
"Flux.1 Kontext",
"Flux.2 D",
"Flux.2 Klein 9B",
"Flux.2 Klein 9B-base",
"Flux.2 Klein 4B",
"Flux.2 Klein 4B-base",
"HiDream",
"Hunyuan 1",
"Hunyuan Video",
"Illustrious",
"Kolors",
"LTXV",
"LTXV2",
"LTXV 2.3",
"Lumina",
"Mochi",
"NoobAI",
"Other",
"PixArt a",
"PixArt E",
"Pony",
"Pony V7",
"Qwen",
"SD 1.4",
"SD 1.5",
"SD 1.5 LCM",
"SD 1.5 Hyper",
"SD 2.0",
"SD 2.1",
"SDXL 1.0",
"SDXL Lightning",
"SDXL Hyper",
"Wan Video 1.3B t2v",
"Wan Video 14B t2v",
"Wan Video 14B i2v 480p",
"Wan Video 14B i2v 720p",
"Wan Video 2.2 TI2V-5B",
"Wan Video 2.2 I2V-A14B",
"Wan Video 2.2 T2V-A14B",
"Wan Video 2.5 T2V",
"Wan Video 2.5 I2V",
"ZImageTurbo",
"ZImageBase"
],
"BaseModel": [
"Anima",
"AuraFlow",
"Chroma",
"CogVideoX",
"Flux.1 S",
"Flux.1 D",
"Flux.1 Krea",
"Flux.1 Kontext",
"Flux.2 D",
"Flux.2 Klein 9B",
"Flux.2 Klein 9B-base",
"Flux.2 Klein 4B",
"Flux.2 Klein 4B-base",
"HiDream",
"Hunyuan 1",
"Hunyuan Video",
"Illustrious",
"Imagen4",
"Kling",
"Kolors",
"LTXV",
"LTXV2",
"LTXV 2.3",
"Lumina",
"Mochi",
"Nano Banana",
"NoobAI",
"ODOR",
"OpenAI",
"Other",
"PixArt a",
"PixArt E",
"Playground v2",
"Pony",
"Pony V7",
"Qwen",
"Stable Cascade",
"SD 1.4",
"SD 1.5",
"SD 1.5 LCM",
"SD 1.5 Hyper",
"SD 2.0",
"SD 2.0 768",
"SD 2.1",
"SD 2.1 768",
"SD 2.1 Unclip",
"SD 3",
"SD 3.5",
"SD 3.5 Large",
"SD 3.5 Large Turbo",
"SD 3.5 Medium",
"Sora 2",
"SDXL 0.9",
"SDXL 1.0",
"SDXL 1.0 LCM",
"SDXL Lightning",
"SDXL Hyper",
"SDXL Turbo",
"SDXL Distilled",
"Seedance",
"Seedream",
"SVD",
"SVD XT",
"Veo 3",
"Vidu Q1",
"Wan Video",
"Wan Video 1.3B t2v",
"Wan Video 14B t2v",
"Wan Video 14B i2v 480p",
"Wan Video 14B i2v 720p",
"Wan Video 2.2 TI2V-5B",
"Wan Video 2.2 I2V-A14B",
"Wan Video 2.2 T2V-A14B",
"Wan Video 2.5 T2V",
"Wan Video 2.5 I2V",
"ZImageTurbo",
"ZImageBase"
],
"BaseModelType": [
"Standard",
"Inpainting",
"Refiner",
"Pix2Pix"
]
}

View File

@@ -0,0 +1,164 @@
/**
* API client for Civitai base model management
* Handles fetching and refreshing base models from Civitai API
*/
import { showToast } from '../utils/uiHelpers.js';
const BASE_MODEL_ENDPOINTS = {
getModels: '/api/lm/base-models',
refresh: '/api/lm/base-models/refresh',
categories: '/api/lm/base-models/categories',
cacheStatus: '/api/lm/base-models/cache-status',
};
/**
* Civitai Base Model API Client
*/
export class CivitaiBaseModelApi {
constructor() {
this.cache = null;
this.cacheTimestamp = null;
}
/**
* Get base models (with caching)
* @param {boolean} forceRefresh - Force refresh from API
* @returns {Promise<Object>} Response with models, source, and counts
*/
async getBaseModels(forceRefresh = false) {
try {
const url = new URL(BASE_MODEL_ENDPOINTS.getModels, window.location.origin);
if (forceRefresh) {
url.searchParams.append('refresh', 'true');
}
const response = await fetch(url);
if (!response.ok) {
throw new Error(`Failed to fetch base models: ${response.statusText}`);
}
const data = await response.json();
if (data.success) {
this.cache = data.data;
this.cacheTimestamp = Date.now();
return data.data;
} else {
throw new Error(data.error || 'Failed to fetch base models');
}
} catch (error) {
console.error('Error fetching base models:', error);
showToast('Failed to fetch base models', { message: error.message }, 'error');
throw error;
}
}
/**
* Force refresh base models from Civitai API
* @returns {Promise<Object>} Refreshed data
*/
async refreshBaseModels() {
try {
const response = await fetch(BASE_MODEL_ENDPOINTS.refresh, {
method: 'POST',
headers: { 'Content-Type': 'application/json' }
});
if (!response.ok) {
throw new Error(`Failed to refresh base models: ${response.statusText}`);
}
const data = await response.json();
if (data.success) {
this.cache = data.data;
this.cacheTimestamp = Date.now();
showToast('Base models refreshed successfully', {}, 'success');
return data.data;
} else {
throw new Error(data.error || 'Failed to refresh base models');
}
} catch (error) {
console.error('Error refreshing base models:', error);
showToast('Failed to refresh base models', { message: error.message }, 'error');
throw error;
}
}
/**
* Get base model categories
* @returns {Promise<Object>} Categories with model lists
*/
async getCategories() {
try {
const response = await fetch(BASE_MODEL_ENDPOINTS.categories);
if (!response.ok) {
throw new Error(`Failed to fetch categories: ${response.statusText}`);
}
const data = await response.json();
if (data.success) {
return data.data;
} else {
throw new Error(data.error || 'Failed to fetch categories');
}
} catch (error) {
console.error('Error fetching categories:', error);
throw error;
}
}
/**
* Get cache status
* @returns {Promise<Object>} Cache status information
*/
async getCacheStatus() {
try {
const response = await fetch(BASE_MODEL_ENDPOINTS.cacheStatus);
if (!response.ok) {
throw new Error(`Failed to fetch cache status: ${response.statusText}`);
}
const data = await response.json();
if (data.success) {
return data.data;
} else {
throw new Error(data.error || 'Failed to fetch cache status');
}
} catch (error) {
console.error('Error fetching cache status:', error);
throw error;
}
}
/**
* Get cached models (if available)
* @returns {Object|null} Cached data or null
*/
getCachedModels() {
return this.cache;
}
/**
* Check if cache is available
* @returns {boolean}
*/
hasCache() {
return this.cache !== null;
}
/**
* Get cache age in milliseconds
* @returns {number|null} Age in ms or null if no cache
*/
getCacheAge() {
if (!this.cacheTimestamp) return null;
return Date.now() - this.cacheTimestamp;
}
}
// Export singleton instance
export const civitaiBaseModelApi = new CivitaiBaseModelApi();

View File

@@ -17,6 +17,8 @@ import { onboardingManager } from './managers/OnboardingManager.js';
import { BulkContextMenu } from './components/ContextMenu/BulkContextMenu.js'; import { BulkContextMenu } from './components/ContextMenu/BulkContextMenu.js';
import { createPageContextMenu, createGlobalContextMenu } from './components/ContextMenu/index.js'; import { createPageContextMenu, createGlobalContextMenu } from './components/ContextMenu/index.js';
import { initializeEventManagement } from './utils/eventManagementInit.js'; import { initializeEventManagement } from './utils/eventManagementInit.js';
import { civitaiBaseModelApi } from './api/civitaiBaseModelApi.js';
import { setDynamicBaseModels } from './utils/constants.js';
// Core application class // Core application class
export class AppCore { export class AppCore {
@@ -42,6 +44,10 @@ export class AppCore {
await settingsManager.waitForInitialization(); await settingsManager.waitForInitialization();
console.log('AppCore: Settings initialized'); console.log('AppCore: Settings initialized');
// Initialize dynamic base models (async, non-blocking)
console.log('AppCore: Initializing dynamic base models...');
this.initializeDynamicBaseModels();
// Initialize managers // Initialize managers
state.loadingManager = new LoadingManager(); state.loadingManager = new LoadingManager();
modalManager.initialize(); modalManager.initialize();
@@ -116,6 +122,21 @@ export class AppCore {
window.globalContextMenuInstance = createGlobalContextMenu(); window.globalContextMenuInstance = createGlobalContextMenu();
} }
} }
// Initialize dynamic base models from Civitai API
// This is non-blocking - runs in background
async initializeDynamicBaseModels() {
try {
const result = await civitaiBaseModelApi.getBaseModels();
if (result && result.models) {
setDynamicBaseModels(result.models, result.last_updated);
console.log(`AppCore: Loaded ${result.merged_count} base models (${result.hardcoded_count} hardcoded + ${result.remote_count} remote)`);
}
} catch (error) {
console.warn('AppCore: Failed to load dynamic base models:', error);
// Non-critical error - app continues with hardcoded models
}
}
} }
// Create and export a singleton instance // Create and export a singleton instance

View File

@@ -2,7 +2,14 @@ import { modalManager } from './ModalManager.js';
import { showToast } from '../utils/uiHelpers.js'; import { showToast } from '../utils/uiHelpers.js';
import { state, createDefaultSettings } from '../state/index.js'; import { state, createDefaultSettings } from '../state/index.js';
import { resetAndReload } from '../api/modelApiFactory.js'; import { resetAndReload } from '../api/modelApiFactory.js';
import { DOWNLOAD_PATH_TEMPLATES, MAPPABLE_BASE_MODELS, PATH_TEMPLATE_PLACEHOLDERS, DEFAULT_PATH_TEMPLATES, DEFAULT_PRIORITY_TAG_CONFIG } from '../utils/constants.js'; import {
DOWNLOAD_PATH_TEMPLATES,
MAPPABLE_BASE_MODELS,
PATH_TEMPLATE_PLACEHOLDERS,
DEFAULT_PATH_TEMPLATES,
DEFAULT_PRIORITY_TAG_CONFIG,
getMappableBaseModelsDynamic
} from '../utils/constants.js';
import { translate } from '../utils/i18nHelpers.js'; import { translate } from '../utils/i18nHelpers.js';
import { i18n } from '../i18n/index.js'; import { i18n } from '../i18n/index.js';
import { configureModelCardVideo } from '../components/shared/ModelCard.js'; import { configureModelCardVideo } from '../components/shared/ModelCard.js';
@@ -184,7 +191,9 @@ export class SettingsManager {
} }
getAvailableDownloadSkipBaseModels() { getAvailableDownloadSkipBaseModels() {
return MAPPABLE_BASE_MODELS.filter(model => model !== 'Other'); // Use dynamic base models if available, fallback to hardcoded
const models = getMappableBaseModelsDynamic();
return models.filter(model => model !== 'Other');
} }
normalizeDownloadSkipBaseModels(value) { normalizeDownloadSkipBaseModels(value) {
@@ -1517,7 +1526,7 @@ export class SettingsManager {
const row = document.createElement('div'); const row = document.createElement('div');
row.className = 'mapping-row'; row.className = 'mapping-row';
const availableModels = MAPPABLE_BASE_MODELS.filter(model => { const availableModels = getMappableBaseModelsDynamic().filter(model => {
const existingMappings = state.global.settings.base_model_path_mappings || {}; const existingMappings = state.global.settings.base_model_path_mappings || {};
return !existingMappings.hasOwnProperty(model) || model === baseModel; return !existingMappings.hasOwnProperty(model) || model === baseModel;
}); });
@@ -1619,7 +1628,7 @@ export class SettingsManager {
const currentValue = select.value; const currentValue = select.value;
// Get available models (not already mapped, except current) // Get available models (not already mapped, except current)
const availableModels = MAPPABLE_BASE_MODELS.filter(model => const availableModels = getMappableBaseModelsDynamic().filter(model =>
!existingMappings.hasOwnProperty(model) || model === currentValue !existingMappings.hasOwnProperty(model) || model === currentValue
); );

View File

@@ -50,6 +50,9 @@ export const BASE_MODELS = {
SVD: "SVD", SVD: "SVD",
LTXV: "LTXV", LTXV: "LTXV",
LTXV2: "LTXV2", LTXV2: "LTXV2",
LTXV_2_3: "LTXV 2.3",
COGVIDE_X: "CogVideoX",
MOCHI: "Mochi",
WAN_VIDEO: "Wan Video", WAN_VIDEO: "Wan Video",
WAN_VIDEO_1_3B_T2V: "Wan Video 1.3B t2v", WAN_VIDEO_1_3B_T2V: "Wan Video 1.3B t2v",
WAN_VIDEO_14B_T2V: "Wan Video 14B t2v", WAN_VIDEO_14B_T2V: "Wan Video 14B t2v",
@@ -58,7 +61,12 @@ export const BASE_MODELS = {
WAN_VIDEO_2_2_TI2V_5B: "Wan Video 2.2 TI2V-5B", WAN_VIDEO_2_2_TI2V_5B: "Wan Video 2.2 TI2V-5B",
WAN_VIDEO_2_2_T2V_A14B: "Wan Video 2.2 T2V-A14B", WAN_VIDEO_2_2_T2V_A14B: "Wan Video 2.2 T2V-A14B",
WAN_VIDEO_2_2_I2V_A14B: "Wan Video 2.2 I2V-A14B", WAN_VIDEO_2_2_I2V_A14B: "Wan Video 2.2 I2V-A14B",
WAN_VIDEO_2_5_T2V: "Wan Video 2.5 T2V",
WAN_VIDEO_2_5_I2V: "Wan Video 2.5 I2V",
HUNYUAN_VIDEO: "Hunyuan Video", HUNYUAN_VIDEO: "Hunyuan Video",
// Other models
ANIMA: "Anima",
PONY_V7: "Pony V7",
// Default // Default
UNKNOWN: "Other" UNKNOWN: "Other"
}; };
@@ -151,6 +159,9 @@ export const BASE_MODEL_ABBREVIATIONS = {
[BASE_MODELS.SVD]: 'SVD', [BASE_MODELS.SVD]: 'SVD',
[BASE_MODELS.LTXV]: 'LTXV', [BASE_MODELS.LTXV]: 'LTXV',
[BASE_MODELS.LTXV2]: 'LTV2', [BASE_MODELS.LTXV2]: 'LTV2',
[BASE_MODELS.LTXV_2_3]: 'LTX',
[BASE_MODELS.COGVIDE_X]: 'CVX',
[BASE_MODELS.MOCHI]: 'MCHI',
[BASE_MODELS.WAN_VIDEO]: 'WAN', [BASE_MODELS.WAN_VIDEO]: 'WAN',
[BASE_MODELS.WAN_VIDEO_1_3B_T2V]: 'WAN', [BASE_MODELS.WAN_VIDEO_1_3B_T2V]: 'WAN',
[BASE_MODELS.WAN_VIDEO_14B_T2V]: 'WAN', [BASE_MODELS.WAN_VIDEO_14B_T2V]: 'WAN',
@@ -159,8 +170,28 @@ export const BASE_MODEL_ABBREVIATIONS = {
[BASE_MODELS.WAN_VIDEO_2_2_TI2V_5B]: 'WAN', [BASE_MODELS.WAN_VIDEO_2_2_TI2V_5B]: 'WAN',
[BASE_MODELS.WAN_VIDEO_2_2_T2V_A14B]: 'WAN', [BASE_MODELS.WAN_VIDEO_2_2_T2V_A14B]: 'WAN',
[BASE_MODELS.WAN_VIDEO_2_2_I2V_A14B]: 'WAN', [BASE_MODELS.WAN_VIDEO_2_2_I2V_A14B]: 'WAN',
[BASE_MODELS.WAN_VIDEO_2_5_T2V]: 'WAN',
[BASE_MODELS.WAN_VIDEO_2_5_I2V]: 'WAN',
[BASE_MODELS.HUNYUAN_VIDEO]: 'HYV', [BASE_MODELS.HUNYUAN_VIDEO]: 'HYV',
// Other diffusion models
[BASE_MODELS.AURAFLOW]: 'AF',
[BASE_MODELS.CHROMA]: 'CHR',
[BASE_MODELS.PIXART_A]: 'PXA',
[BASE_MODELS.PIXART_E]: 'PXE',
[BASE_MODELS.HUNYUAN_1]: 'HY',
[BASE_MODELS.LUMINA]: 'L',
[BASE_MODELS.KOLORS]: 'KLR',
[BASE_MODELS.NOOBAI]: 'NAI',
[BASE_MODELS.ILLUSTRIOUS]: 'IL',
[BASE_MODELS.PONY]: 'PONY',
[BASE_MODELS.PONY_V7]: 'PNY7',
[BASE_MODELS.HIDREAM]: 'HID',
[BASE_MODELS.QWEN]: 'QWEN',
[BASE_MODELS.ZIMAGE_TURBO]: 'ZIT',
[BASE_MODELS.ZIMAGE_BASE]: 'ZIB',
[BASE_MODELS.ANIMA]: 'ANI',
// Default // Default
[BASE_MODELS.UNKNOWN]: 'OTH' [BASE_MODELS.UNKNOWN]: 'OTH'
}; };
@@ -349,18 +380,20 @@ export const BASE_MODEL_CATEGORIES = {
'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO], 'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO],
'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER], 'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
'Video Models': [ 'Video Models': [
BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.LTXV2, BASE_MODELS.HUNYUAN_VIDEO, BASE_MODELS.WAN_VIDEO, BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.LTXV2, BASE_MODELS.LTXV_2_3,
BASE_MODELS.WAN_VIDEO_1_3B_T2V, BASE_MODELS.WAN_VIDEO_14B_T2V, BASE_MODELS.COGVIDE_X, BASE_MODELS.MOCHI, BASE_MODELS.HUNYUAN_VIDEO,
BASE_MODELS.WAN_VIDEO, BASE_MODELS.WAN_VIDEO_1_3B_T2V, BASE_MODELS.WAN_VIDEO_14B_T2V,
BASE_MODELS.WAN_VIDEO_14B_I2V_480P, BASE_MODELS.WAN_VIDEO_14B_I2V_720P, BASE_MODELS.WAN_VIDEO_14B_I2V_480P, BASE_MODELS.WAN_VIDEO_14B_I2V_720P,
BASE_MODELS.WAN_VIDEO_2_2_TI2V_5B, BASE_MODELS.WAN_VIDEO_2_2_T2V_A14B, BASE_MODELS.WAN_VIDEO_2_2_TI2V_5B, BASE_MODELS.WAN_VIDEO_2_2_T2V_A14B,
BASE_MODELS.WAN_VIDEO_2_2_I2V_A14B BASE_MODELS.WAN_VIDEO_2_2_I2V_A14B, BASE_MODELS.WAN_VIDEO_2_5_T2V,
BASE_MODELS.WAN_VIDEO_2_5_I2V
], ],
'Flux Models': [BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.FLUX_1_KONTEXT, BASE_MODELS.FLUX_1_KREA, BASE_MODELS.FLUX_2_D, BASE_MODELS.FLUX_2_KLEIN_9B, BASE_MODELS.FLUX_2_KLEIN_9B_BASE, BASE_MODELS.FLUX_2_KLEIN_4B, BASE_MODELS.FLUX_2_KLEIN_4B_BASE], 'Flux Models': [BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.FLUX_1_KONTEXT, BASE_MODELS.FLUX_1_KREA, BASE_MODELS.FLUX_2_D, BASE_MODELS.FLUX_2_KLEIN_9B, BASE_MODELS.FLUX_2_KLEIN_9B_BASE, BASE_MODELS.FLUX_2_KLEIN_4B, BASE_MODELS.FLUX_2_KLEIN_4B_BASE],
'Other Models': [ 'Other Models': [
BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.HIDREAM, BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.PONY_V7, BASE_MODELS.HIDREAM,
BASE_MODELS.QWEN, BASE_MODELS.AURAFLOW, BASE_MODELS.CHROMA, BASE_MODELS.ZIMAGE_TURBO, BASE_MODELS.ZIMAGE_BASE, BASE_MODELS.QWEN, BASE_MODELS.AURAFLOW, BASE_MODELS.CHROMA, BASE_MODELS.ZIMAGE_TURBO, BASE_MODELS.ZIMAGE_BASE,
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1, BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI, BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI, BASE_MODELS.ANIMA,
BASE_MODELS.UNKNOWN BASE_MODELS.UNKNOWN
] ]
}; };
@@ -378,3 +411,94 @@ export const DEFAULT_PRIORITY_TAG_CONFIG = {
checkpoint: DEFAULT_PRIORITY_TAG_ENTRIES.join(', '), checkpoint: DEFAULT_PRIORITY_TAG_ENTRIES.join(', '),
embedding: DEFAULT_PRIORITY_TAG_ENTRIES.join(', ') embedding: DEFAULT_PRIORITY_TAG_ENTRIES.join(', ')
}; };
// ============================================================================
// Dynamic Base Model Support
// ============================================================================
/**
* Dynamic base model cache
* Stores models fetched from Civitai API
*/
let dynamicBaseModels = null;
let dynamicBaseModelsTimestamp = null;
const CACHE_TTL_MS = 7 * 24 * 60 * 60 * 1000; // 7 days
/**
* Set dynamic base models (called after fetching from API)
* @param {Array} models - Array of base model names
* @param {string} timestamp - ISO timestamp of fetch
*/
export function setDynamicBaseModels(models, timestamp) {
dynamicBaseModels = models;
dynamicBaseModelsTimestamp = timestamp;
}
/**
* Get dynamic base models
* @returns {Object|null} { models, timestamp } or null if not set
*/
export function getDynamicBaseModels() {
if (!dynamicBaseModels) return null;
// Check if cache is expired
if (dynamicBaseModelsTimestamp) {
const age = Date.now() - new Date(dynamicBaseModelsTimestamp).getTime();
if (age > CACHE_TTL_MS) {
dynamicBaseModels = null;
dynamicBaseModelsTimestamp = null;
return null;
}
}
return {
models: dynamicBaseModels,
timestamp: dynamicBaseModelsTimestamp
};
}
/**
* Get merged base models (hardcoded + dynamic)
* Returns unique sorted list of all available base models
* @returns {Array} Sorted array of base model names
*/
export function getMergedBaseModels() {
const hardcoded = Object.values(BASE_MODELS);
const dynamic = getDynamicBaseModels();
if (!dynamic || !dynamic.models) {
return hardcoded.sort();
}
// Merge and deduplicate
const merged = new Set([...hardcoded, ...dynamic.models]);
return Array.from(merged).sort();
}
/**
* Get mappable base models (for UI selection)
* Excludes 'Other' value
* @returns {Array} Sorted array of base model names (excluding 'Other')
*/
export function getMappableBaseModelsDynamic() {
const merged = getMergedBaseModels();
return merged.filter(model => model !== 'Other');
}
/**
* Clear dynamic base models cache
*/
export function clearDynamicBaseModels() {
dynamicBaseModels = null;
dynamicBaseModelsTimestamp = null;
}
/**
* Check if dynamic base models cache is valid
* @returns {boolean}
*/
export function isDynamicBaseModelsCacheValid() {
if (!dynamicBaseModels || !dynamicBaseModelsTimestamp) return false;
const age = Date.now() - new Date(dynamicBaseModelsTimestamp).getTime();
return age <= CACHE_TTL_MS;
}

View File

@@ -0,0 +1,133 @@
"""Tests for CivitaiBaseModelService."""
import pytest
from unittest.mock import patch
from py.services.civitai_base_model_service import CivitaiBaseModelService
class TestCivitaiBaseModelService:
"""Test suite for CivitaiBaseModelService."""
@pytest.fixture(autouse=True)
def setup_service(self):
"""Create a fresh service instance for each test."""
self.service = CivitaiBaseModelService()
# Reset cache
self.service._cache = None
self.service._cache_timestamp = None
yield
def test_generate_abbreviation_known_models(self):
"""Test abbreviation generation for known models."""
test_cases = [
("SD 1.5", "SD1"),
("SDXL 1.0", "XL"),
("Flux.1 D", "F1D"),
("Wan Video 2.5 T2V", "WAN"),
("Pony V7", "PNY7"),
("CogVideoX", "CVX"),
("Mochi", "MCHI"),
("Anima", "ANI"),
]
for model_name, expected in test_cases:
result = self.service.generate_abbreviation(model_name)
assert result == expected, (
f"Failed for {model_name}: got {result}, expected {expected}"
)
def test_generate_abbreviation_unknown_models(self):
"""Test abbreviation generation for unknown models."""
result = self.service.generate_abbreviation("New Model 2.0")
assert len(result) <= 4
assert result.isupper()
def test_generate_abbreviation_edge_cases(self):
"""Test abbreviation generation edge cases."""
assert self.service.generate_abbreviation("") == "OTH"
assert self.service.generate_abbreviation(None) == "OTH"
def test_cache_status_no_cache(self):
"""Test cache status when no cache exists."""
status = self.service.get_cache_status()
assert status["has_cache"] is False
assert status["last_updated"] is None
assert status["is_expired"] is True
assert status["age_seconds"] is None
@pytest.mark.asyncio
async def test_get_base_models_fallback(self):
"""Test that fallback to hardcoded models works."""
with patch.object(self.service, "_fetch_from_civitai", return_value=None):
result = await self.service.get_base_models()
assert result["source"] == "fallback"
assert len(result["models"]) > 0
assert result["hardcoded_count"] > 0
assert result["remote_count"] == 0
@pytest.mark.asyncio
async def test_get_base_models_from_api(self):
"""Test fetching models from API."""
mock_models = {"SD 1.5", "SDXL 1.0", "New Model"}
with patch.object(
self.service, "_fetch_from_civitai", return_value=mock_models
):
result = await self.service.get_base_models()
assert result["source"] == "api"
assert result["remote_count"] == 3
assert "New Model" in result["models"]
@pytest.mark.asyncio
async def test_get_base_models_uses_cache(self):
"""Test that cached data is used when available and not expired."""
# First call - populate cache
mock_models = {"SD 1.5", "SDXL 1.0"}
with patch.object(
self.service, "_fetch_from_civitai", return_value=mock_models
):
await self.service.get_base_models()
# Second call - should use cache
with patch.object(self.service, "_fetch_from_civitai") as mock_fetch:
result = await self.service.get_base_models()
mock_fetch.assert_not_called()
assert result["source"] == "cache"
@pytest.mark.asyncio
async def test_refresh_cache(self):
"""Test force refresh clears cache and fetches fresh data."""
# Populate cache
mock_models = {"SD 1.5"}
with patch.object(
self.service, "_fetch_from_civitai", return_value=mock_models
):
await self.service.get_base_models()
# Force refresh with different data
new_models = {"SD 1.5", "SDXL 1.0", "New Model"}
with patch.object(self.service, "_fetch_from_civitai", return_value=new_models):
result = await self.service.refresh_cache()
assert result["source"] == "api"
assert result["remote_count"] == 3
def test_get_model_categories(self):
"""Test model categories are returned."""
categories = self.service.get_model_categories()
assert "Stable Diffusion 1.x" in categories
assert "Video Models" in categories
assert "Flux Models" in categories
assert "Other Models" in categories
# Check that video models include new additions
video_models = categories["Video Models"]
assert "CogVideoX" in video_models
assert "Mochi" in video_models
assert "Wan Video 2.5 T2V" in video_models