diff --git a/py/routes/handlers/base_model_handlers.py b/py/routes/handlers/base_model_handlers.py new file mode 100644 index 00000000..834ecedb --- /dev/null +++ b/py/routes/handlers/base_model_handlers.py @@ -0,0 +1,141 @@ +"""Handlers for base model related endpoints.""" + +from __future__ import annotations + +import logging +from typing import Any, Awaitable, Callable, Dict + +from aiohttp import web + +from ...services.civitai_base_model_service import get_civitai_base_model_service + +logger = logging.getLogger(__name__) + + +class BaseModelHandlerSet: + """Collection of handlers for base model operations.""" + + def __init__( + self, + base_model_service_factory: Callable[[], Any] = get_civitai_base_model_service, + ) -> None: + self._base_model_service_factory = base_model_service_factory + + def to_route_mapping( + self, + ) -> Dict[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]: + """Return mapping of route names to handler methods.""" + return { + "get_base_models": self.get_base_models, + "refresh_base_models": self.refresh_base_models, + "get_base_model_categories": self.get_base_model_categories, + "get_base_model_cache_status": self.get_base_model_cache_status, + } + + async def get_base_models(self, request: web.Request) -> web.Response: + """Get merged base models (hardcoded + remote from Civitai). + + Query Parameters: + refresh: If 'true', force refresh from API + + Returns: + JSON response with: + - models: List of base model names + - source: 'cache', 'api', or 'fallback' + - last_updated: ISO timestamp + - counts: hardcoded_count, remote_count, merged_count + """ + try: + service = await self._base_model_service_factory() + + # Check for refresh parameter + force_refresh = request.query.get("refresh", "").lower() == "true" + + result = await service.get_base_models(force_refresh=force_refresh) + + return web.json_response( + { + "success": True, + "data": result, + } + ) + + except Exception as e: + logger.error(f"Error in get_base_models: {e}") + return web.json_response( + {"success": False, "error": str(e)}, + status=500, + ) + + async def refresh_base_models(self, request: web.Request) -> web.Response: + """Force refresh base models from Civitai API. + + Returns: + JSON response with refreshed data + """ + try: + service = await self._base_model_service_factory() + result = await service.refresh_cache() + + return web.json_response( + { + "success": True, + "data": result, + "message": "Base models cache refreshed successfully", + } + ) + + except Exception as e: + logger.error(f"Error in refresh_base_models: {e}") + return web.json_response( + {"success": False, "error": str(e)}, + status=500, + ) + + async def get_base_model_categories(self, request: web.Request) -> web.Response: + """Get categorized base models. + + Returns: + JSON response with categorized models + """ + try: + service = await self._base_model_service_factory() + categories = service.get_model_categories() + + return web.json_response( + { + "success": True, + "data": categories, + } + ) + + except Exception as e: + logger.error(f"Error in get_base_model_categories: {e}") + return web.json_response( + {"success": False, "error": str(e)}, + status=500, + ) + + async def get_base_model_cache_status(self, request: web.Request) -> web.Response: + """Get cache status for base models. + + Returns: + JSON response with cache status + """ + try: + service = await self._base_model_service_factory() + status = service.get_cache_status() + + return web.json_response( + { + "success": True, + "data": status, + } + ) + + except Exception as e: + logger.error(f"Error in get_base_model_cache_status: {e}") + return web.json_response( + {"success": False, "error": str(e)}, + status=500, + ) diff --git a/py/routes/handlers/misc_handlers.py b/py/routes/handlers/misc_handlers.py index 79658b7f..3f9d0ffa 100644 --- a/py/routes/handlers/misc_handlers.py +++ b/py/routes/handlers/misc_handlers.py @@ -40,6 +40,7 @@ from ...utils.civitai_utils import rewrite_preview_url from ...utils.example_images_paths import is_valid_example_images_root from ...utils.lora_metadata import extract_trained_words from ...utils.usage_stats import UsageStats +from .base_model_handlers import BaseModelHandlerSet logger = logging.getLogger(__name__) @@ -1618,6 +1619,7 @@ class MiscHandlerSet: custom_words: CustomWordsHandler, supporters: SupportersHandler, example_workflows: ExampleWorkflowsHandler, + base_model: BaseModelHandlerSet, ) -> None: self.health = health self.settings = settings @@ -1632,6 +1634,7 @@ class MiscHandlerSet: self.custom_words = custom_words self.supporters = supporters self.example_workflows = example_workflows + self.base_model = base_model def to_route_mapping( self, @@ -1663,6 +1666,11 @@ class MiscHandlerSet: "get_supporters": self.supporters.get_supporters, "get_example_workflows": self.example_workflows.get_example_workflows, "get_example_workflow": self.example_workflows.get_example_workflow, + # Base model handlers + "get_base_models": self.base_model.get_base_models, + "refresh_base_models": self.base_model.refresh_base_models, + "get_base_model_categories": self.base_model.get_base_model_categories, + "get_base_model_cache_status": self.base_model.get_base_model_cache_status, } diff --git a/py/routes/misc_route_registrar.py b/py/routes/misc_route_registrar.py index 5ab34b3b..e77ed579 100644 --- a/py/routes/misc_route_registrar.py +++ b/py/routes/misc_route_registrar.py @@ -56,6 +56,15 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition( "GET", "/api/lm/example-workflows/{filename}", "get_example_workflow" ), + # Base model management routes + RouteDefinition("GET", "/api/lm/base-models", "get_base_models"), + RouteDefinition("POST", "/api/lm/base-models/refresh", "refresh_base_models"), + RouteDefinition( + "GET", "/api/lm/base-models/categories", "get_base_model_categories" + ), + RouteDefinition( + "GET", "/api/lm/base-models/cache-status", "get_base_model_cache_status" + ), ) diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 73494b42..a800bc86 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -35,6 +35,7 @@ from .handlers.misc_handlers import ( UsageStatsHandler, build_service_registry_adapter, ) +from .handlers.base_model_handlers import BaseModelHandlerSet from .misc_route_registrar import MiscRouteRegistrar logger = logging.getLogger(__name__) @@ -128,6 +129,7 @@ class MiscRoutes: custom_words = CustomWordsHandler() supporters = SupportersHandler() example_workflows = ExampleWorkflowsHandler() + base_model = BaseModelHandlerSet() return self._handler_set_factory( health=health, @@ -143,6 +145,7 @@ class MiscRoutes: custom_words=custom_words, supporters=supporters, example_workflows=example_workflows, + base_model=base_model, ) diff --git a/py/services/civitai_base_model_service.py b/py/services/civitai_base_model_service.py new file mode 100644 index 00000000..b07f3840 --- /dev/null +++ b/py/services/civitai_base_model_service.py @@ -0,0 +1,430 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import re +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, Set, Tuple + +from ..utils.constants import SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS +from .downloader import get_downloader + +logger = logging.getLogger(__name__) + + +class CivitaiBaseModelService: + """Service for fetching and managing Civitai base models. + + This service provides: + - Fetching base models from Civitai API + - Caching with TTL (7 days default) + - Merging hardcoded and remote base models + - Generating abbreviations for new/unknown models + """ + + _instance: Optional[CivitaiBaseModelService] = None + _lock = asyncio.Lock() + + # Default TTL for cache in seconds (7 days) + DEFAULT_CACHE_TTL = 7 * 24 * 60 * 60 + + # Civitai API endpoint for enums + CIVITAI_ENUMS_URL = "https://civitai.com/api/v1/enums" + + @classmethod + async def get_instance(cls) -> CivitaiBaseModelService: + """Get singleton instance of the service.""" + async with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self): + """Initialize the service.""" + if hasattr(self, "_initialized"): + return + self._initialized = True + + # Cache storage + self._cache: Optional[Dict[str, Any]] = None + self._cache_timestamp: Optional[datetime] = None + self._cache_ttl = self.DEFAULT_CACHE_TTL + + # Hardcoded models for fallback + self._hardcoded_models = set(SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS) + + logger.info("CivitaiBaseModelService initialized") + + async def get_base_models(self, force_refresh: bool = False) -> Dict[str, Any]: + """Get merged base models (hardcoded + remote). + + Args: + force_refresh: If True, fetch from API regardless of cache state. + + Returns: + Dictionary containing: + - models: List of merged base model names + - source: 'cache', 'api', or 'fallback' + - last_updated: ISO timestamp of last successful API fetch + - hardcoded_count: Number of hardcoded models + - remote_count: Number of remote models + - merged_count: Total unique models + """ + # Check if cache is valid + if not force_refresh and self._is_cache_valid(): + logger.debug("Returning cached base models") + return self._build_response("cache") + + # Try to fetch from API + try: + remote_models = await self._fetch_from_civitai() + if remote_models: + self._update_cache(remote_models) + return self._build_response("api") + except Exception as e: + logger.error(f"Failed to fetch base models from Civitai: {e}") + + # Fallback to hardcoded models + return self._build_response("fallback") + + async def refresh_cache(self) -> Dict[str, Any]: + """Force refresh the cache from Civitai API. + + Returns: + Response dict same as get_base_models() + """ + return await self.get_base_models(force_refresh=True) + + def get_cache_status(self) -> Dict[str, Any]: + """Get current cache status. + + Returns: + Dictionary containing: + - has_cache: Whether cache exists + - last_updated: ISO timestamp or None + - is_expired: Whether cache is expired + - ttl_seconds: TTL in seconds + - age_seconds: Age of cache in seconds (if exists) + """ + if self._cache is None or self._cache_timestamp is None: + return { + "has_cache": False, + "last_updated": None, + "is_expired": True, + "ttl_seconds": self._cache_ttl, + "age_seconds": None, + } + + age = (datetime.now(timezone.utc) - self._cache_timestamp).total_seconds() + return { + "has_cache": True, + "last_updated": self._cache_timestamp.isoformat(), + "is_expired": age > self._cache_ttl, + "ttl_seconds": self._cache_ttl, + "age_seconds": int(age), + } + + def generate_abbreviation(self, model_name: str) -> str: + """Generate abbreviation for a base model name. + + Algorithm: + 1. Extract version patterns (e.g., "2.5" from "Wan Video 2.5") + 2. Extract main acronym (e.g., "SD" from "SD 1.5") + 3. Handle special cases (Flux, Wan, etc.) + 4. Fallback to first letters of words (max 4 chars) + + Args: + model_name: Full base model name + + Returns: + Generated abbreviation (max 4 characters) + """ + if not model_name or not isinstance(model_name, str): + return "OTH" + + name = model_name.strip() + if not name: + return "OTH" + + # Check if it's already in hardcoded abbreviations + # This is a simplified check - in practice you'd have a mapping + lower_name = name.lower() + + # Special cases + special_cases = { + "sd 1.4": "SD1", + "sd 1.5": "SD1", + "sd 1.5 lcm": "SD1", + "sd 1.5 hyper": "SD1", + "sd 2.0": "SD2", + "sd 2.1": "SD2", + "sd 3": "SD3", + "sd 3.5": "SD3", + "sd 3.5 medium": "SD3", + "sd 3.5 large": "SD3", + "sd 3.5 large turbo": "SD3", + "sdxl 1.0": "XL", + "sdxl lightning": "XL", + "sdxl hyper": "XL", + "flux.1 d": "F1D", + "flux.1 s": "F1S", + "flux.1 krea": "F1KR", + "flux.1 kontext": "F1KX", + "flux.2 d": "F2D", + "flux.2 klein 9b": "FK9", + "flux.2 klein 9b-base": "FK9B", + "flux.2 klein 4b": "FK4", + "flux.2 klein 4b-base": "FK4B", + "auraflow": "AF", + "chroma": "CHR", + "pixart a": "PXA", + "pixart e": "PXE", + "hunyuan 1": "HY", + "hunyuan video": "HYV", + "lumina": "L", + "kolors": "KLR", + "noobai": "NAI", + "illustrious": "IL", + "pony": "PONY", + "pony v7": "PNY7", + "hidream": "HID", + "qwen": "QWEN", + "zimageturbo": "ZIT", + "zimagebase": "ZIB", + "anima": "ANI", + "svd": "SVD", + "ltxv": "LTXV", + "ltxv2": "LTV2", + "ltxv 2.3": "LTX", + "cogvideox": "CVX", + "mochi": "MCHI", + "wan video": "WAN", + "wan video 1.3b t2v": "WAN", + "wan video 14b t2v": "WAN", + "wan video 14b i2v 480p": "WAN", + "wan video 14b i2v 720p": "WAN", + "wan video 2.2 ti2v-5b": "WAN", + "wan video 2.2 t2v-a14b": "WAN", + "wan video 2.2 i2v-a14b": "WAN", + "wan video 2.5 t2v": "WAN", + "wan video 2.5 i2v": "WAN", + } + + if lower_name in special_cases: + return special_cases[lower_name] + + # Try to extract acronym from version pattern + # e.g., "Model Name 2.5" -> "MN25" + version_match = re.search(r"(\d+(?:\.\d+)?)", name) + version = version_match.group(1) if version_match else "" + + # Remove version and common words + words = re.sub(r"\d+(?:\.\d+)?", "", name) + words = re.sub( + r"\b(model|video|diffusion|checkpoint|textualinversion)\b", + "", + words, + flags=re.I, + ) + words = words.strip() + + # Get first letters of remaining words + tokens = re.findall(r"[A-Za-z]+", words) + if tokens: + # Build abbreviation from first letters + abbrev = "".join(token[0].upper() for token in tokens) + # Add version if present + if version: + # Clean version (remove dots for abbreviation) + version_clean = version.replace(".", "") + abbrev = abbrev[: 4 - len(version_clean)] + version_clean + return abbrev[:4] + + # Final fallback: just take first 4 alphanumeric chars + alphanumeric = re.sub(r"[^A-Za-z0-9]", "", name) + if alphanumeric: + return alphanumeric[:4].upper() + + return "OTH" + + async def _fetch_from_civitai(self) -> Optional[Set[str]]: + """Fetch base models from Civitai API. + + Returns: + Set of base model names, or None if failed + """ + try: + downloader = await get_downloader() + success, result = await downloader.make_request( + "GET", + self.CIVITAI_ENUMS_URL, + use_auth=False, # enums endpoint doesn't require auth + ) + + if not success: + logger.warning(f"Failed to fetch enums from Civitai: {result}") + return None + + if isinstance(result, str): + data = json.loads(result) + else: + data = result + + # Extract base models from response + base_models = set() + + # Use ActiveBaseModel if available (recommended active models) + if "ActiveBaseModel" in data: + base_models.update(data["ActiveBaseModel"]) + logger.info(f"Fetched {len(base_models)} models from ActiveBaseModel") + # Fallback to full BaseModel list + elif "BaseModel" in data: + base_models.update(data["BaseModel"]) + logger.info(f"Fetched {len(base_models)} models from BaseModel") + else: + logger.warning("No base model data found in Civitai response") + return None + + return base_models + + except Exception as e: + logger.error(f"Error fetching from Civitai: {e}") + return None + + def _update_cache(self, remote_models: Set[str]) -> None: + """Update internal cache with fetched models. + + Args: + remote_models: Set of base model names from API + """ + self._cache = { + "remote_models": sorted(remote_models), + "hardcoded_models": sorted(self._hardcoded_models), + } + self._cache_timestamp = datetime.now(timezone.utc) + logger.info(f"Cache updated with {len(remote_models)} remote models") + + def _is_cache_valid(self) -> bool: + """Check if current cache is valid (not expired). + + Returns: + True if cache exists and is not expired + """ + if self._cache is None or self._cache_timestamp is None: + return False + + age = (datetime.now(timezone.utc) - self._cache_timestamp).total_seconds() + return age <= self._cache_ttl + + def _build_response(self, source: str) -> Dict[str, Any]: + """Build response dictionary. + + Args: + source: 'cache', 'api', or 'fallback' + + Returns: + Response dictionary + """ + if source == "fallback" or self._cache is None: + # Use only hardcoded models + merged = sorted(self._hardcoded_models) + return { + "models": merged, + "source": source, + "last_updated": None, + "hardcoded_count": len(self._hardcoded_models), + "remote_count": 0, + "merged_count": len(merged), + } + + # Merge hardcoded and remote models + remote_set = set(self._cache.get("remote_models", [])) + merged = sorted(self._hardcoded_models | remote_set) + + return { + "models": merged, + "source": source, + "last_updated": self._cache_timestamp.isoformat() + if self._cache_timestamp + else None, + "hardcoded_count": len(self._hardcoded_models), + "remote_count": len(remote_set), + "merged_count": len(merged), + } + + def get_model_categories(self) -> Dict[str, List[str]]: + """Get categorized base models. + + Returns: + Dictionary mapping category names to lists of model names + """ + # Define category patterns + categories = { + "Stable Diffusion 1.x": ["SD 1.4", "SD 1.5", "SD 1.5 LCM", "SD 1.5 Hyper"], + "Stable Diffusion 2.x": ["SD 2.0", "SD 2.1"], + "Stable Diffusion 3.x": [ + "SD 3", + "SD 3.5", + "SD 3.5 Medium", + "SD 3.5 Large", + "SD 3.5 Large Turbo", + ], + "SDXL": ["SDXL 1.0", "SDXL Lightning", "SDXL Hyper"], + "Flux Models": [ + "Flux.1 D", + "Flux.1 S", + "Flux.1 Krea", + "Flux.1 Kontext", + "Flux.2 D", + "Flux.2 Klein 9B", + "Flux.2 Klein 9B-base", + "Flux.2 Klein 4B", + "Flux.2 Klein 4B-base", + ], + "Video Models": [ + "SVD", + "LTXV", + "LTXV2", + "LTXV 2.3", + "CogVideoX", + "Mochi", + "Hunyuan Video", + "Wan Video", + "Wan Video 1.3B t2v", + "Wan Video 14B t2v", + "Wan Video 14B i2v 480p", + "Wan Video 14B i2v 720p", + "Wan Video 2.2 TI2V-5B", + "Wan Video 2.2 T2V-A14B", + "Wan Video 2.2 I2V-A14B", + "Wan Video 2.5 T2V", + "Wan Video 2.5 I2V", + ], + "Other Models": [ + "Illustrious", + "Pony", + "Pony V7", + "HiDream", + "Qwen", + "AuraFlow", + "Chroma", + "ZImageTurbo", + "ZImageBase", + "PixArt a", + "PixArt E", + "Hunyuan 1", + "Lumina", + "Kolors", + "NoobAI", + "Anima", + ], + } + + return categories + + +# Convenience function for getting the singleton instance +async def get_civitai_base_model_service() -> CivitaiBaseModelService: + """Get the singleton instance of CivitaiBaseModelService.""" + return await CivitaiBaseModelService.get_instance() diff --git a/py/utils/constants.py b/py/utils/constants.py index 6215c946..c8a44528 100644 --- a/py/utils/constants.py +++ b/py/utils/constants.py @@ -110,6 +110,8 @@ DIFFUSION_MODEL_BASE_MODELS = frozenset( "Wan Video 2.2 T2V-A14B", "Wan Video 2.5 T2V", "Wan Video 2.5 I2V", + "CogVideoX", + "Mochi", "Qwen", ] ) @@ -151,6 +153,7 @@ SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS = frozenset( "NoobAI", "Illustrious", "Pony", + "Pony V7", "HiDream", "Qwen", "ZImageTurbo", @@ -158,6 +161,9 @@ SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS = frozenset( "SVD", "LTXV", "LTXV2", + "LTXV 2.3", + "CogVideoX", + "Mochi", "Wan Video", "Wan Video 1.3B t2v", "Wan Video 14B t2v", @@ -166,6 +172,9 @@ SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS = frozenset( "Wan Video 2.2 TI2V-5B", "Wan Video 2.2 T2V-A14B", "Wan Video 2.2 I2V-A14B", + "Wan Video 2.5 T2V", + "Wan Video 2.5 I2V", "Hunyuan Video", + "Anima", ] ) diff --git a/refs/enums.json b/refs/enums.json new file mode 100644 index 00000000..6a7559b6 --- /dev/null +++ b/refs/enums.json @@ -0,0 +1,167 @@ +{ + "ModelType": [ + "Checkpoint", + "TextualInversion", + "Hypernetwork", + "AestheticGradient", + "LORA", + "LoCon", + "DoRA", + "Controlnet", + "Upscaler", + "MotionModule", + "VAE", + "Poses", + "Wildcards", + "Workflows", + "Detection", + "Other" + ], + "ModelFileType": [ + "Model", + "Text Encoder", + "Pruned Model", + "Negative", + "Training Data", + "VAE", + "Config", + "Archive" + ], + "ActiveBaseModel": [ + "Anima", + "AuraFlow", + "Chroma", + "CogVideoX", + "Flux.1 S", + "Flux.1 D", + "Flux.1 Krea", + "Flux.1 Kontext", + "Flux.2 D", + "Flux.2 Klein 9B", + "Flux.2 Klein 9B-base", + "Flux.2 Klein 4B", + "Flux.2 Klein 4B-base", + "HiDream", + "Hunyuan 1", + "Hunyuan Video", + "Illustrious", + "Kolors", + "LTXV", + "LTXV2", + "LTXV 2.3", + "Lumina", + "Mochi", + "NoobAI", + "Other", + "PixArt a", + "PixArt E", + "Pony", + "Pony V7", + "Qwen", + "SD 1.4", + "SD 1.5", + "SD 1.5 LCM", + "SD 1.5 Hyper", + "SD 2.0", + "SD 2.1", + "SDXL 1.0", + "SDXL Lightning", + "SDXL Hyper", + "Wan Video 1.3B t2v", + "Wan Video 14B t2v", + "Wan Video 14B i2v 480p", + "Wan Video 14B i2v 720p", + "Wan Video 2.2 TI2V-5B", + "Wan Video 2.2 I2V-A14B", + "Wan Video 2.2 T2V-A14B", + "Wan Video 2.5 T2V", + "Wan Video 2.5 I2V", + "ZImageTurbo", + "ZImageBase" + ], + "BaseModel": [ + "Anima", + "AuraFlow", + "Chroma", + "CogVideoX", + "Flux.1 S", + "Flux.1 D", + "Flux.1 Krea", + "Flux.1 Kontext", + "Flux.2 D", + "Flux.2 Klein 9B", + "Flux.2 Klein 9B-base", + "Flux.2 Klein 4B", + "Flux.2 Klein 4B-base", + "HiDream", + "Hunyuan 1", + "Hunyuan Video", + "Illustrious", + "Imagen4", + "Kling", + "Kolors", + "LTXV", + "LTXV2", + "LTXV 2.3", + "Lumina", + "Mochi", + "Nano Banana", + "NoobAI", + "ODOR", + "OpenAI", + "Other", + "PixArt a", + "PixArt E", + "Playground v2", + "Pony", + "Pony V7", + "Qwen", + "Stable Cascade", + "SD 1.4", + "SD 1.5", + "SD 1.5 LCM", + "SD 1.5 Hyper", + "SD 2.0", + "SD 2.0 768", + "SD 2.1", + "SD 2.1 768", + "SD 2.1 Unclip", + "SD 3", + "SD 3.5", + "SD 3.5 Large", + "SD 3.5 Large Turbo", + "SD 3.5 Medium", + "Sora 2", + "SDXL 0.9", + "SDXL 1.0", + "SDXL 1.0 LCM", + "SDXL Lightning", + "SDXL Hyper", + "SDXL Turbo", + "SDXL Distilled", + "Seedance", + "Seedream", + "SVD", + "SVD XT", + "Veo 3", + "Vidu Q1", + "Wan Video", + "Wan Video 1.3B t2v", + "Wan Video 14B t2v", + "Wan Video 14B i2v 480p", + "Wan Video 14B i2v 720p", + "Wan Video 2.2 TI2V-5B", + "Wan Video 2.2 I2V-A14B", + "Wan Video 2.2 T2V-A14B", + "Wan Video 2.5 T2V", + "Wan Video 2.5 I2V", + "ZImageTurbo", + "ZImageBase" + ], + "BaseModelType": [ + "Standard", + "Inpainting", + "Refiner", + "Pix2Pix" + ] +} \ No newline at end of file diff --git a/static/js/api/civitaiBaseModelApi.js b/static/js/api/civitaiBaseModelApi.js new file mode 100644 index 00000000..e2f52274 --- /dev/null +++ b/static/js/api/civitaiBaseModelApi.js @@ -0,0 +1,164 @@ +/** + * API client for Civitai base model management + * Handles fetching and refreshing base models from Civitai API + */ + +import { showToast } from '../utils/uiHelpers.js'; + +const BASE_MODEL_ENDPOINTS = { + getModels: '/api/lm/base-models', + refresh: '/api/lm/base-models/refresh', + categories: '/api/lm/base-models/categories', + cacheStatus: '/api/lm/base-models/cache-status', +}; + +/** + * Civitai Base Model API Client + */ +export class CivitaiBaseModelApi { + constructor() { + this.cache = null; + this.cacheTimestamp = null; + } + + /** + * Get base models (with caching) + * @param {boolean} forceRefresh - Force refresh from API + * @returns {Promise} Response with models, source, and counts + */ + async getBaseModels(forceRefresh = false) { + try { + const url = new URL(BASE_MODEL_ENDPOINTS.getModels, window.location.origin); + if (forceRefresh) { + url.searchParams.append('refresh', 'true'); + } + + const response = await fetch(url); + if (!response.ok) { + throw new Error(`Failed to fetch base models: ${response.statusText}`); + } + + const data = await response.json(); + + if (data.success) { + this.cache = data.data; + this.cacheTimestamp = Date.now(); + return data.data; + } else { + throw new Error(data.error || 'Failed to fetch base models'); + } + } catch (error) { + console.error('Error fetching base models:', error); + showToast('Failed to fetch base models', { message: error.message }, 'error'); + throw error; + } + } + + /** + * Force refresh base models from Civitai API + * @returns {Promise} Refreshed data + */ + async refreshBaseModels() { + try { + const response = await fetch(BASE_MODEL_ENDPOINTS.refresh, { + method: 'POST', + headers: { 'Content-Type': 'application/json' } + }); + + if (!response.ok) { + throw new Error(`Failed to refresh base models: ${response.statusText}`); + } + + const data = await response.json(); + + if (data.success) { + this.cache = data.data; + this.cacheTimestamp = Date.now(); + showToast('Base models refreshed successfully', {}, 'success'); + return data.data; + } else { + throw new Error(data.error || 'Failed to refresh base models'); + } + } catch (error) { + console.error('Error refreshing base models:', error); + showToast('Failed to refresh base models', { message: error.message }, 'error'); + throw error; + } + } + + /** + * Get base model categories + * @returns {Promise} Categories with model lists + */ + async getCategories() { + try { + const response = await fetch(BASE_MODEL_ENDPOINTS.categories); + if (!response.ok) { + throw new Error(`Failed to fetch categories: ${response.statusText}`); + } + + const data = await response.json(); + + if (data.success) { + return data.data; + } else { + throw new Error(data.error || 'Failed to fetch categories'); + } + } catch (error) { + console.error('Error fetching categories:', error); + throw error; + } + } + + /** + * Get cache status + * @returns {Promise} Cache status information + */ + async getCacheStatus() { + try { + const response = await fetch(BASE_MODEL_ENDPOINTS.cacheStatus); + if (!response.ok) { + throw new Error(`Failed to fetch cache status: ${response.statusText}`); + } + + const data = await response.json(); + + if (data.success) { + return data.data; + } else { + throw new Error(data.error || 'Failed to fetch cache status'); + } + } catch (error) { + console.error('Error fetching cache status:', error); + throw error; + } + } + + /** + * Get cached models (if available) + * @returns {Object|null} Cached data or null + */ + getCachedModels() { + return this.cache; + } + + /** + * Check if cache is available + * @returns {boolean} + */ + hasCache() { + return this.cache !== null; + } + + /** + * Get cache age in milliseconds + * @returns {number|null} Age in ms or null if no cache + */ + getCacheAge() { + if (!this.cacheTimestamp) return null; + return Date.now() - this.cacheTimestamp; + } +} + +// Export singleton instance +export const civitaiBaseModelApi = new CivitaiBaseModelApi(); diff --git a/static/js/core.js b/static/js/core.js index 27175114..3a018d68 100644 --- a/static/js/core.js +++ b/static/js/core.js @@ -17,6 +17,8 @@ import { onboardingManager } from './managers/OnboardingManager.js'; import { BulkContextMenu } from './components/ContextMenu/BulkContextMenu.js'; import { createPageContextMenu, createGlobalContextMenu } from './components/ContextMenu/index.js'; import { initializeEventManagement } from './utils/eventManagementInit.js'; +import { civitaiBaseModelApi } from './api/civitaiBaseModelApi.js'; +import { setDynamicBaseModels } from './utils/constants.js'; // Core application class export class AppCore { @@ -42,6 +44,10 @@ export class AppCore { await settingsManager.waitForInitialization(); console.log('AppCore: Settings initialized'); + // Initialize dynamic base models (async, non-blocking) + console.log('AppCore: Initializing dynamic base models...'); + this.initializeDynamicBaseModels(); + // Initialize managers state.loadingManager = new LoadingManager(); modalManager.initialize(); @@ -116,6 +122,21 @@ export class AppCore { window.globalContextMenuInstance = createGlobalContextMenu(); } } + + // Initialize dynamic base models from Civitai API + // This is non-blocking - runs in background + async initializeDynamicBaseModels() { + try { + const result = await civitaiBaseModelApi.getBaseModels(); + if (result && result.models) { + setDynamicBaseModels(result.models, result.last_updated); + console.log(`AppCore: Loaded ${result.merged_count} base models (${result.hardcoded_count} hardcoded + ${result.remote_count} remote)`); + } + } catch (error) { + console.warn('AppCore: Failed to load dynamic base models:', error); + // Non-critical error - app continues with hardcoded models + } + } } // Create and export a singleton instance diff --git a/static/js/managers/SettingsManager.js b/static/js/managers/SettingsManager.js index 9d7195a0..d1b07ff4 100644 --- a/static/js/managers/SettingsManager.js +++ b/static/js/managers/SettingsManager.js @@ -2,7 +2,14 @@ import { modalManager } from './ModalManager.js'; import { showToast } from '../utils/uiHelpers.js'; import { state, createDefaultSettings } from '../state/index.js'; import { resetAndReload } from '../api/modelApiFactory.js'; -import { DOWNLOAD_PATH_TEMPLATES, MAPPABLE_BASE_MODELS, PATH_TEMPLATE_PLACEHOLDERS, DEFAULT_PATH_TEMPLATES, DEFAULT_PRIORITY_TAG_CONFIG } from '../utils/constants.js'; +import { + DOWNLOAD_PATH_TEMPLATES, + MAPPABLE_BASE_MODELS, + PATH_TEMPLATE_PLACEHOLDERS, + DEFAULT_PATH_TEMPLATES, + DEFAULT_PRIORITY_TAG_CONFIG, + getMappableBaseModelsDynamic +} from '../utils/constants.js'; import { translate } from '../utils/i18nHelpers.js'; import { i18n } from '../i18n/index.js'; import { configureModelCardVideo } from '../components/shared/ModelCard.js'; @@ -184,7 +191,9 @@ export class SettingsManager { } getAvailableDownloadSkipBaseModels() { - return MAPPABLE_BASE_MODELS.filter(model => model !== 'Other'); + // Use dynamic base models if available, fallback to hardcoded + const models = getMappableBaseModelsDynamic(); + return models.filter(model => model !== 'Other'); } normalizeDownloadSkipBaseModels(value) { @@ -1517,7 +1526,7 @@ export class SettingsManager { const row = document.createElement('div'); row.className = 'mapping-row'; - const availableModels = MAPPABLE_BASE_MODELS.filter(model => { + const availableModels = getMappableBaseModelsDynamic().filter(model => { const existingMappings = state.global.settings.base_model_path_mappings || {}; return !existingMappings.hasOwnProperty(model) || model === baseModel; }); @@ -1619,7 +1628,7 @@ export class SettingsManager { const currentValue = select.value; // Get available models (not already mapped, except current) - const availableModels = MAPPABLE_BASE_MODELS.filter(model => + const availableModels = getMappableBaseModelsDynamic().filter(model => !existingMappings.hasOwnProperty(model) || model === currentValue ); diff --git a/static/js/utils/constants.js b/static/js/utils/constants.js index 2c50deda..6630d5fe 100644 --- a/static/js/utils/constants.js +++ b/static/js/utils/constants.js @@ -50,6 +50,9 @@ export const BASE_MODELS = { SVD: "SVD", LTXV: "LTXV", LTXV2: "LTXV2", + LTXV_2_3: "LTXV 2.3", + COGVIDE_X: "CogVideoX", + MOCHI: "Mochi", WAN_VIDEO: "Wan Video", WAN_VIDEO_1_3B_T2V: "Wan Video 1.3B t2v", WAN_VIDEO_14B_T2V: "Wan Video 14B t2v", @@ -58,7 +61,12 @@ export const BASE_MODELS = { WAN_VIDEO_2_2_TI2V_5B: "Wan Video 2.2 TI2V-5B", WAN_VIDEO_2_2_T2V_A14B: "Wan Video 2.2 T2V-A14B", WAN_VIDEO_2_2_I2V_A14B: "Wan Video 2.2 I2V-A14B", + WAN_VIDEO_2_5_T2V: "Wan Video 2.5 T2V", + WAN_VIDEO_2_5_I2V: "Wan Video 2.5 I2V", HUNYUAN_VIDEO: "Hunyuan Video", + // Other models + ANIMA: "Anima", + PONY_V7: "Pony V7", // Default UNKNOWN: "Other" }; @@ -151,6 +159,9 @@ export const BASE_MODEL_ABBREVIATIONS = { [BASE_MODELS.SVD]: 'SVD', [BASE_MODELS.LTXV]: 'LTXV', [BASE_MODELS.LTXV2]: 'LTV2', + [BASE_MODELS.LTXV_2_3]: 'LTX', + [BASE_MODELS.COGVIDE_X]: 'CVX', + [BASE_MODELS.MOCHI]: 'MCHI', [BASE_MODELS.WAN_VIDEO]: 'WAN', [BASE_MODELS.WAN_VIDEO_1_3B_T2V]: 'WAN', [BASE_MODELS.WAN_VIDEO_14B_T2V]: 'WAN', @@ -159,8 +170,28 @@ export const BASE_MODEL_ABBREVIATIONS = { [BASE_MODELS.WAN_VIDEO_2_2_TI2V_5B]: 'WAN', [BASE_MODELS.WAN_VIDEO_2_2_T2V_A14B]: 'WAN', [BASE_MODELS.WAN_VIDEO_2_2_I2V_A14B]: 'WAN', + [BASE_MODELS.WAN_VIDEO_2_5_T2V]: 'WAN', + [BASE_MODELS.WAN_VIDEO_2_5_I2V]: 'WAN', [BASE_MODELS.HUNYUAN_VIDEO]: 'HYV', + // Other diffusion models + [BASE_MODELS.AURAFLOW]: 'AF', + [BASE_MODELS.CHROMA]: 'CHR', + [BASE_MODELS.PIXART_A]: 'PXA', + [BASE_MODELS.PIXART_E]: 'PXE', + [BASE_MODELS.HUNYUAN_1]: 'HY', + [BASE_MODELS.LUMINA]: 'L', + [BASE_MODELS.KOLORS]: 'KLR', + [BASE_MODELS.NOOBAI]: 'NAI', + [BASE_MODELS.ILLUSTRIOUS]: 'IL', + [BASE_MODELS.PONY]: 'PONY', + [BASE_MODELS.PONY_V7]: 'PNY7', + [BASE_MODELS.HIDREAM]: 'HID', + [BASE_MODELS.QWEN]: 'QWEN', + [BASE_MODELS.ZIMAGE_TURBO]: 'ZIT', + [BASE_MODELS.ZIMAGE_BASE]: 'ZIB', + [BASE_MODELS.ANIMA]: 'ANI', + // Default [BASE_MODELS.UNKNOWN]: 'OTH' }; @@ -349,18 +380,20 @@ export const BASE_MODEL_CATEGORIES = { 'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO], 'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER], 'Video Models': [ - BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.LTXV2, BASE_MODELS.HUNYUAN_VIDEO, BASE_MODELS.WAN_VIDEO, - BASE_MODELS.WAN_VIDEO_1_3B_T2V, BASE_MODELS.WAN_VIDEO_14B_T2V, + BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.LTXV2, BASE_MODELS.LTXV_2_3, + BASE_MODELS.COGVIDE_X, BASE_MODELS.MOCHI, BASE_MODELS.HUNYUAN_VIDEO, + BASE_MODELS.WAN_VIDEO, BASE_MODELS.WAN_VIDEO_1_3B_T2V, BASE_MODELS.WAN_VIDEO_14B_T2V, BASE_MODELS.WAN_VIDEO_14B_I2V_480P, BASE_MODELS.WAN_VIDEO_14B_I2V_720P, BASE_MODELS.WAN_VIDEO_2_2_TI2V_5B, BASE_MODELS.WAN_VIDEO_2_2_T2V_A14B, - BASE_MODELS.WAN_VIDEO_2_2_I2V_A14B + BASE_MODELS.WAN_VIDEO_2_2_I2V_A14B, BASE_MODELS.WAN_VIDEO_2_5_T2V, + BASE_MODELS.WAN_VIDEO_2_5_I2V ], 'Flux Models': [BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.FLUX_1_KONTEXT, BASE_MODELS.FLUX_1_KREA, BASE_MODELS.FLUX_2_D, BASE_MODELS.FLUX_2_KLEIN_9B, BASE_MODELS.FLUX_2_KLEIN_9B_BASE, BASE_MODELS.FLUX_2_KLEIN_4B, BASE_MODELS.FLUX_2_KLEIN_4B_BASE], 'Other Models': [ - BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.HIDREAM, + BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.PONY_V7, BASE_MODELS.HIDREAM, BASE_MODELS.QWEN, BASE_MODELS.AURAFLOW, BASE_MODELS.CHROMA, BASE_MODELS.ZIMAGE_TURBO, BASE_MODELS.ZIMAGE_BASE, BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1, - BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI, + BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI, BASE_MODELS.ANIMA, BASE_MODELS.UNKNOWN ] }; @@ -378,3 +411,94 @@ export const DEFAULT_PRIORITY_TAG_CONFIG = { checkpoint: DEFAULT_PRIORITY_TAG_ENTRIES.join(', '), embedding: DEFAULT_PRIORITY_TAG_ENTRIES.join(', ') }; + +// ============================================================================ +// Dynamic Base Model Support +// ============================================================================ + +/** + * Dynamic base model cache + * Stores models fetched from Civitai API + */ +let dynamicBaseModels = null; +let dynamicBaseModelsTimestamp = null; +const CACHE_TTL_MS = 7 * 24 * 60 * 60 * 1000; // 7 days + +/** + * Set dynamic base models (called after fetching from API) + * @param {Array} models - Array of base model names + * @param {string} timestamp - ISO timestamp of fetch + */ +export function setDynamicBaseModels(models, timestamp) { + dynamicBaseModels = models; + dynamicBaseModelsTimestamp = timestamp; +} + +/** + * Get dynamic base models + * @returns {Object|null} { models, timestamp } or null if not set + */ +export function getDynamicBaseModels() { + if (!dynamicBaseModels) return null; + + // Check if cache is expired + if (dynamicBaseModelsTimestamp) { + const age = Date.now() - new Date(dynamicBaseModelsTimestamp).getTime(); + if (age > CACHE_TTL_MS) { + dynamicBaseModels = null; + dynamicBaseModelsTimestamp = null; + return null; + } + } + + return { + models: dynamicBaseModels, + timestamp: dynamicBaseModelsTimestamp + }; +} + +/** + * Get merged base models (hardcoded + dynamic) + * Returns unique sorted list of all available base models + * @returns {Array} Sorted array of base model names + */ +export function getMergedBaseModels() { + const hardcoded = Object.values(BASE_MODELS); + const dynamic = getDynamicBaseModels(); + + if (!dynamic || !dynamic.models) { + return hardcoded.sort(); + } + + // Merge and deduplicate + const merged = new Set([...hardcoded, ...dynamic.models]); + return Array.from(merged).sort(); +} + +/** + * Get mappable base models (for UI selection) + * Excludes 'Other' value + * @returns {Array} Sorted array of base model names (excluding 'Other') + */ +export function getMappableBaseModelsDynamic() { + const merged = getMergedBaseModels(); + return merged.filter(model => model !== 'Other'); +} + +/** + * Clear dynamic base models cache + */ +export function clearDynamicBaseModels() { + dynamicBaseModels = null; + dynamicBaseModelsTimestamp = null; +} + +/** + * Check if dynamic base models cache is valid + * @returns {boolean} + */ +export function isDynamicBaseModelsCacheValid() { + if (!dynamicBaseModels || !dynamicBaseModelsTimestamp) return false; + const age = Date.now() - new Date(dynamicBaseModelsTimestamp).getTime(); + return age <= CACHE_TTL_MS; +} diff --git a/tests/services/test_civitai_base_model_service.py b/tests/services/test_civitai_base_model_service.py new file mode 100644 index 00000000..80c2bdbe --- /dev/null +++ b/tests/services/test_civitai_base_model_service.py @@ -0,0 +1,133 @@ +"""Tests for CivitaiBaseModelService.""" + +import pytest +from unittest.mock import patch + +from py.services.civitai_base_model_service import CivitaiBaseModelService + + +class TestCivitaiBaseModelService: + """Test suite for CivitaiBaseModelService.""" + + @pytest.fixture(autouse=True) + def setup_service(self): + """Create a fresh service instance for each test.""" + self.service = CivitaiBaseModelService() + # Reset cache + self.service._cache = None + self.service._cache_timestamp = None + yield + + def test_generate_abbreviation_known_models(self): + """Test abbreviation generation for known models.""" + test_cases = [ + ("SD 1.5", "SD1"), + ("SDXL 1.0", "XL"), + ("Flux.1 D", "F1D"), + ("Wan Video 2.5 T2V", "WAN"), + ("Pony V7", "PNY7"), + ("CogVideoX", "CVX"), + ("Mochi", "MCHI"), + ("Anima", "ANI"), + ] + + for model_name, expected in test_cases: + result = self.service.generate_abbreviation(model_name) + assert result == expected, ( + f"Failed for {model_name}: got {result}, expected {expected}" + ) + + def test_generate_abbreviation_unknown_models(self): + """Test abbreviation generation for unknown models.""" + result = self.service.generate_abbreviation("New Model 2.0") + assert len(result) <= 4 + assert result.isupper() + + def test_generate_abbreviation_edge_cases(self): + """Test abbreviation generation edge cases.""" + assert self.service.generate_abbreviation("") == "OTH" + assert self.service.generate_abbreviation(None) == "OTH" + + def test_cache_status_no_cache(self): + """Test cache status when no cache exists.""" + status = self.service.get_cache_status() + + assert status["has_cache"] is False + assert status["last_updated"] is None + assert status["is_expired"] is True + assert status["age_seconds"] is None + + @pytest.mark.asyncio + async def test_get_base_models_fallback(self): + """Test that fallback to hardcoded models works.""" + with patch.object(self.service, "_fetch_from_civitai", return_value=None): + result = await self.service.get_base_models() + + assert result["source"] == "fallback" + assert len(result["models"]) > 0 + assert result["hardcoded_count"] > 0 + assert result["remote_count"] == 0 + + @pytest.mark.asyncio + async def test_get_base_models_from_api(self): + """Test fetching models from API.""" + mock_models = {"SD 1.5", "SDXL 1.0", "New Model"} + + with patch.object( + self.service, "_fetch_from_civitai", return_value=mock_models + ): + result = await self.service.get_base_models() + + assert result["source"] == "api" + assert result["remote_count"] == 3 + assert "New Model" in result["models"] + + @pytest.mark.asyncio + async def test_get_base_models_uses_cache(self): + """Test that cached data is used when available and not expired.""" + # First call - populate cache + mock_models = {"SD 1.5", "SDXL 1.0"} + with patch.object( + self.service, "_fetch_from_civitai", return_value=mock_models + ): + await self.service.get_base_models() + + # Second call - should use cache + with patch.object(self.service, "_fetch_from_civitai") as mock_fetch: + result = await self.service.get_base_models() + mock_fetch.assert_not_called() + + assert result["source"] == "cache" + + @pytest.mark.asyncio + async def test_refresh_cache(self): + """Test force refresh clears cache and fetches fresh data.""" + # Populate cache + mock_models = {"SD 1.5"} + with patch.object( + self.service, "_fetch_from_civitai", return_value=mock_models + ): + await self.service.get_base_models() + + # Force refresh with different data + new_models = {"SD 1.5", "SDXL 1.0", "New Model"} + with patch.object(self.service, "_fetch_from_civitai", return_value=new_models): + result = await self.service.refresh_cache() + + assert result["source"] == "api" + assert result["remote_count"] == 3 + + def test_get_model_categories(self): + """Test model categories are returned.""" + categories = self.service.get_model_categories() + + assert "Stable Diffusion 1.x" in categories + assert "Video Models" in categories + assert "Flux Models" in categories + assert "Other Models" in categories + + # Check that video models include new additions + video_models = categories["Video Models"] + assert "CogVideoX" in video_models + assert "Mochi" in video_models + assert "Wan Video 2.5 T2V" in video_models