feat(misc): add VAE and Upscaler model management page

This commit is contained in:
Will Miao
2026-01-31 07:28:10 +08:00
parent b86bd44c65
commit 0a340d397c
37 changed files with 1164 additions and 38 deletions

View File

@@ -9,7 +9,7 @@ from collections import OrderedDict
import uuid
from typing import Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata, MiscMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, DIFFUSION_MODEL_BASE_MODELS, VALID_LORA_TYPES
from ..utils.civitai_utils import rewrite_preview_url
from ..utils.preview_selection import select_preview_media
@@ -60,6 +60,10 @@ class DownloadManager:
"""Get the checkpoint scanner from registry"""
return await ServiceRegistry.get_checkpoint_scanner()
async def _get_misc_scanner(self):
"""Get the misc scanner from registry"""
return await ServiceRegistry.get_misc_scanner()
async def download_from_civitai(
self,
model_id: int = None,
@@ -275,6 +279,7 @@ class DownloadManager:
lora_scanner = await self._get_lora_scanner()
checkpoint_scanner = await self._get_checkpoint_scanner()
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
misc_scanner = await self._get_misc_scanner()
# Check lora scanner first
if await lora_scanner.check_model_version_exists(model_version_id):
@@ -299,6 +304,13 @@ class DownloadManager:
"error": "Model version already exists in embedding library",
}
# Check misc scanner (VAE, Upscaler)
if await misc_scanner.check_model_version_exists(model_version_id):
return {
"success": False,
"error": "Model version already exists in misc library",
}
# Use CivArchive provider directly when source is 'civarchive'
# This prioritizes CivArchive metadata (with mirror availability info) over Civitai
if source == "civarchive":
@@ -337,6 +349,10 @@ class DownloadManager:
model_type = "lora"
elif model_type_from_info == "textualinversion":
model_type = "embedding"
elif model_type_from_info == "vae":
model_type = "misc"
elif model_type_from_info == "upscaler":
model_type = "misc"
else:
return {
"success": False,
@@ -379,6 +395,14 @@ class DownloadManager:
"success": False,
"error": "Model version already exists in embedding library",
}
elif model_type == "misc":
# Check misc scanner (VAE, Upscaler)
misc_scanner = await self._get_misc_scanner()
if await misc_scanner.check_model_version_exists(version_id):
return {
"success": False,
"error": "Model version already exists in misc library",
}
# Handle use_default_paths
if use_default_paths:
@@ -413,6 +437,26 @@ class DownloadManager:
"error": "Default embedding root path not set in settings",
}
save_dir = default_path
elif model_type == "misc":
from ..config import config
civitai_type = version_info.get("model", {}).get("type", "").lower()
if civitai_type == "vae":
default_paths = config.vae_roots
error_msg = "VAE root path not configured"
elif civitai_type == "upscaler":
default_paths = config.upscaler_roots
error_msg = "Upscaler root path not configured"
else:
default_paths = config.misc_roots
error_msg = "Misc root path not configured"
if not default_paths:
return {
"success": False,
"error": error_msg,
}
save_dir = default_paths[0] if default_paths else ""
# Calculate relative path using template
relative_path = self._calculate_relative_path(version_info, model_type)
@@ -515,6 +559,11 @@ class DownloadManager:
version_info, file_info, save_path
)
logger.info(f"Creating EmbeddingMetadata for {file_name}")
elif model_type == "misc":
metadata = MiscMetadata.from_civitai_info(
version_info, file_info, save_path
)
logger.info(f"Creating MiscMetadata for {file_name}")
# 6. Start download process
result = await self._execute_download(
@@ -620,6 +669,8 @@ class DownloadManager:
scanner = await self._get_checkpoint_scanner()
elif model_type == "embedding":
scanner = await ServiceRegistry.get_embedding_scanner()
elif model_type == "misc":
scanner = await self._get_misc_scanner()
except Exception as exc:
logger.debug("Failed to acquire scanner for %s models: %s", model_type, exc)
@@ -1016,6 +1067,9 @@ class DownloadManager:
elif model_type == "embedding":
scanner = await ServiceRegistry.get_embedding_scanner()
logger.info(f"Updating embedding cache for {actual_file_paths[0]}")
elif model_type == "misc":
scanner = await self._get_misc_scanner()
logger.info(f"Updating misc cache for {actual_file_paths[0]}")
adjust_cached_entry = (
getattr(scanner, "adjust_cached_entry", None)
@@ -1125,6 +1179,14 @@ class DownloadManager:
".pkl",
".sft",
}
if model_type == "misc":
return {
".ckpt",
".pt",
".bin",
".pth",
".safetensors",
}
return {".safetensors"}
async def _extract_model_files_from_archive(

View File

@@ -0,0 +1,55 @@
import logging
from typing import Any, Dict, List, Optional
from ..utils.models import MiscMetadata
from ..config import config
from .model_scanner import ModelScanner
from .model_hash_index import ModelHashIndex
logger = logging.getLogger(__name__)
class MiscScanner(ModelScanner):
"""Service for scanning and managing misc files (VAE, Upscaler)"""
def __init__(self):
# Define supported file extensions (combined from VAE and upscaler)
file_extensions = {'.safetensors', '.pt', '.bin', '.ckpt', '.pth'}
super().__init__(
model_type="misc",
model_class=MiscMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex()
)
def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]:
"""Resolve the sub-type based on the root path."""
if not root_path:
return None
if config.vae_roots and root_path in config.vae_roots:
return "vae"
if config.upscaler_roots and root_path in config.upscaler_roots:
return "upscaler"
return None
def adjust_metadata(self, metadata, file_path, root_path):
"""Adjust metadata during scanning to set sub_type."""
sub_type = self._resolve_sub_type(root_path)
if sub_type:
metadata.sub_type = sub_type
return metadata
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
"""Adjust entries loaded from the persisted cache to ensure sub_type is set."""
sub_type = self._resolve_sub_type(
self._find_root_for_file(entry.get("file_path"))
)
if sub_type:
entry["sub_type"] = sub_type
return entry
def get_model_roots(self) -> List[str]:
"""Get misc root directories (VAE and upscaler)"""
return config.misc_roots

View File

@@ -0,0 +1,55 @@
import os
import logging
from typing import Dict
from .base_model_service import BaseModelService
from ..utils.models import MiscMetadata
from ..config import config
logger = logging.getLogger(__name__)
class MiscService(BaseModelService):
"""Misc-specific service implementation (VAE, Upscaler)"""
def __init__(self, scanner, update_service=None):
"""Initialize Misc service
Args:
scanner: Misc scanner instance
update_service: Optional service for remote update tracking.
"""
super().__init__("misc", scanner, MiscMetadata, update_service=update_service)
async def format_response(self, misc_data: Dict) -> Dict:
"""Format Misc data for API response"""
# Get sub_type from cache entry (new canonical field)
sub_type = misc_data.get("sub_type", "vae")
return {
"model_name": misc_data["model_name"],
"file_name": misc_data["file_name"],
"preview_url": config.get_preview_static_url(misc_data.get("preview_url", "")),
"preview_nsfw_level": misc_data.get("preview_nsfw_level", 0),
"base_model": misc_data.get("base_model", ""),
"folder": misc_data["folder"],
"sha256": misc_data.get("sha256", ""),
"file_path": misc_data["file_path"].replace(os.sep, "/"),
"file_size": misc_data.get("size", 0),
"modified": misc_data.get("modified", ""),
"tags": misc_data.get("tags", []),
"from_civitai": misc_data.get("from_civitai", True),
"usage_count": misc_data.get("usage_count", 0),
"notes": misc_data.get("notes", ""),
"sub_type": sub_type,
"favorite": misc_data.get("favorite", False),
"update_available": bool(misc_data.get("update_available", False)),
"civitai": self.filter_civitai_data(misc_data.get("civitai", {}), minimal=True)
}
def find_duplicate_hashes(self) -> Dict:
"""Find Misc models with duplicate SHA256 hashes"""
return self.scanner._hash_index.get_duplicate_hashes()
def find_duplicate_filenames(self) -> Dict:
"""Find Misc models with conflicting filenames"""
return self.scanner._hash_index.get_duplicate_filenames()

View File

@@ -118,19 +118,24 @@ class ModelServiceFactory:
def register_default_model_types():
"""Register the default model types (LoRA, Checkpoint, and Embedding)"""
"""Register the default model types (LoRA, Checkpoint, Embedding, and Misc)"""
from ..services.lora_service import LoraService
from ..services.checkpoint_service import CheckpointService
from ..services.embedding_service import EmbeddingService
from ..services.misc_service import MiscService
from ..routes.lora_routes import LoraRoutes
from ..routes.checkpoint_routes import CheckpointRoutes
from ..routes.embedding_routes import EmbeddingRoutes
from ..routes.misc_model_routes import MiscModelRoutes
# Register LoRA model type
ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes)
# Register Checkpoint model type
ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes)
# Register Embedding model type
ModelServiceFactory.register_model_type('embedding', EmbeddingService, EmbeddingRoutes)
ModelServiceFactory.register_model_type('embedding', EmbeddingService, EmbeddingRoutes)
# Register Misc model type (VAE, Upscaler)
ModelServiceFactory.register_model_type('misc', MiscService, MiscModelRoutes)

View File

@@ -233,23 +233,44 @@ class ServiceRegistry:
async def get_embedding_scanner(cls):
"""Get or create Embedding scanner instance"""
service_name = "embedding_scanner"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .embedding_scanner import EmbeddingScanner
scanner = await EmbeddingScanner.get_instance()
cls._services[service_name] = scanner
logger.debug(f"Created and registered {service_name}")
return scanner
@classmethod
async def get_misc_scanner(cls):
"""Get or create Misc scanner instance (VAE, Upscaler)"""
service_name = "misc_scanner"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .misc_scanner import MiscScanner
scanner = await MiscScanner.get_instance()
cls._services[service_name] = scanner
logger.debug(f"Created and registered {service_name}")
return scanner
@classmethod
def clear_services(cls):
"""Clear all registered services - mainly for testing"""