feat(misc): add VAE and Upscaler model management page

This commit is contained in:
Will Miao
2026-01-31 07:28:10 +08:00
parent b86bd44c65
commit 0a340d397c
37 changed files with 1164 additions and 38 deletions

View File

@@ -89,8 +89,11 @@ class Config:
self.checkpoints_roots = None
self.unet_roots = None
self.embeddings_roots = None
self.vae_roots = None
self.upscaler_roots = None
self.base_models_roots = self._init_checkpoint_paths()
self.embeddings_roots = self._init_embedding_paths()
self.misc_roots = self._init_misc_paths()
# Scan symbolic links during initialization
self._initialize_symlink_mappings()
@@ -151,6 +154,8 @@ class Config:
'checkpoints': list(self.checkpoints_roots or []),
'unet': list(self.unet_roots or []),
'embeddings': list(self.embeddings_roots or []),
'vae': list(self.vae_roots or []),
'upscale_models': list(self.upscaler_roots or []),
}
normalized_target_paths = _normalize_folder_paths_for_comparison(target_folder_paths)
@@ -250,6 +255,7 @@ class Config:
roots.extend(self.loras_roots or [])
roots.extend(self.base_models_roots or [])
roots.extend(self.embeddings_roots or [])
roots.extend(self.misc_roots or [])
return roots
def _build_symlink_fingerprint(self) -> Dict[str, object]:
@@ -599,6 +605,8 @@ class Config:
preview_roots.update(self._expand_preview_root(root))
for root in self.embeddings_roots or []:
preview_roots.update(self._expand_preview_root(root))
for root in self.misc_roots or []:
preview_roots.update(self._expand_preview_root(root))
for target, link in self._path_mappings.items():
preview_roots.update(self._expand_preview_root(target))
@@ -606,11 +614,12 @@ class Config:
self._preview_root_paths = {path for path in preview_roots if path.is_absolute()}
logger.debug(
"Preview roots rebuilt: %d paths from %d lora roots, %d checkpoint roots, %d embedding roots, %d symlink mappings",
"Preview roots rebuilt: %d paths from %d lora roots, %d checkpoint roots, %d embedding roots, %d misc roots, %d symlink mappings",
len(self._preview_root_paths),
len(self.loras_roots or []),
len(self.base_models_roots or []),
len(self.embeddings_roots or []),
len(self.misc_roots or []),
len(self._path_mappings),
)
@@ -769,6 +778,49 @@ class Config:
logger.warning(f"Error initializing embedding paths: {e}")
return []
def _init_misc_paths(self) -> List[str]:
"""Initialize and validate misc (VAE and upscaler) paths from ComfyUI settings"""
try:
raw_vae_paths = folder_paths.get_folder_paths("vae")
raw_upscaler_paths = folder_paths.get_folder_paths("upscale_models")
unique_paths = self._prepare_misc_paths(raw_vae_paths, raw_upscaler_paths)
logger.info("Found misc roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
if not unique_paths:
logger.warning("No valid VAE or upscaler folders found in ComfyUI configuration")
return []
return unique_paths
except Exception as e:
logger.warning(f"Error initializing misc paths: {e}")
return []
def _prepare_misc_paths(
self, vae_paths: Iterable[str], upscaler_paths: Iterable[str]
) -> List[str]:
vae_map = self._dedupe_existing_paths(vae_paths)
upscaler_map = self._dedupe_existing_paths(upscaler_paths)
merged_map: Dict[str, str] = {}
for real_path, original in {**vae_map, **upscaler_map}.items():
if real_path not in merged_map:
merged_map[real_path] = original
unique_paths = sorted(merged_map.values(), key=lambda p: p.lower())
vae_values = set(vae_map.values())
upscaler_values = set(upscaler_map.values())
self.vae_roots = [p for p in unique_paths if p in vae_values]
self.upscaler_roots = [p for p in unique_paths if p in upscaler_values]
for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
if real_path != original_path:
self.add_path_mapping(original_path, real_path)
return unique_paths
def get_preview_static_url(self, preview_path: str) -> str:
if not preview_path:
return ""

View File

@@ -184,15 +184,17 @@ class LoraManager:
lora_scanner = await ServiceRegistry.get_lora_scanner()
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
misc_scanner = await ServiceRegistry.get_misc_scanner()
# Initialize recipe scanner if needed
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
# Create low-priority initialization tasks
init_tasks = [
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init'),
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init'),
asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init'),
asyncio.create_task(misc_scanner.initialize_in_background(), name='misc_cache_init'),
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
]
@@ -252,8 +254,9 @@ class LoraManager:
# Collect all model roots
all_roots = set()
all_roots.update(config.loras_roots)
all_roots.update(config.base_models_roots)
all_roots.update(config.base_models_roots)
all_roots.update(config.embeddings_roots)
all_roots.update(config.misc_roots or [])
total_deleted = 0
total_size_freed = 0

View File

@@ -0,0 +1,112 @@
import logging
from typing import Dict
from aiohttp import web
from .base_model_routes import BaseModelRoutes
from .model_route_registrar import ModelRouteRegistrar
from ..services.misc_service import MiscService
from ..services.service_registry import ServiceRegistry
from ..config import config
logger = logging.getLogger(__name__)
class MiscModelRoutes(BaseModelRoutes):
"""Misc-specific route controller (VAE, Upscaler)"""
def __init__(self):
"""Initialize Misc routes with Misc service"""
super().__init__()
self.template_name = "misc.html"
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
misc_scanner = await ServiceRegistry.get_misc_scanner()
update_service = await ServiceRegistry.get_model_update_service()
self.service = MiscService(misc_scanner, update_service=update_service)
self.set_model_update_service(update_service)
# Attach service dependencies
self.attach_service(self.service)
def setup_routes(self, app: web.Application):
"""Setup Misc routes"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
# Setup common routes with 'misc' prefix (includes page route)
super().setup_routes(app, 'misc')
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
"""Setup Misc-specific routes"""
# Misc info by name
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/info/{name}', prefix, self.get_misc_info)
# VAE roots and Upscaler roots
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/vae_roots', prefix, self.get_vae_roots)
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/upscaler_roots', prefix, self.get_upscaler_roots)
def _validate_civitai_model_type(self, model_type: str) -> bool:
"""Validate CivitAI model type for Misc (VAE or Upscaler)"""
return model_type.lower() in ['vae', 'upscaler']
def _get_expected_model_types(self) -> str:
"""Get expected model types string for error messages"""
return "VAE or Upscaler"
def _parse_specific_params(self, request: web.Request) -> Dict:
"""Parse Misc-specific parameters"""
params: Dict = {}
if 'misc_hash' in request.query:
params['hash_filters'] = {'single_hash': request.query['misc_hash'].lower()}
elif 'misc_hashes' in request.query:
params['hash_filters'] = {
'multiple_hashes': [h.lower() for h in request.query['misc_hashes'].split(',')]
}
return params
async def get_misc_info(self, request: web.Request) -> web.Response:
"""Get detailed information for a specific misc model by name"""
try:
name = request.match_info.get('name', '')
misc_info = await self.service.get_model_info_by_name(name)
if misc_info:
return web.json_response(misc_info)
else:
return web.json_response({"error": "Misc model not found"}, status=404)
except Exception as e:
logger.error(f"Error in get_misc_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_vae_roots(self, request: web.Request) -> web.Response:
"""Return the list of VAE roots from config"""
try:
roots = config.vae_roots
return web.json_response({
"success": True,
"roots": roots
})
except Exception as e:
logger.error(f"Error getting VAE roots: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def get_upscaler_roots(self, request: web.Request) -> web.Response:
"""Return the list of upscaler roots from config"""
try:
roots = config.upscaler_roots
return web.json_response({
"success": True,
"roots": roots
})
except Exception as e:
logger.error(f"Error getting upscaler roots: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)

View File

@@ -9,7 +9,7 @@ from collections import OrderedDict
import uuid
from typing import Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata, MiscMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, DIFFUSION_MODEL_BASE_MODELS, VALID_LORA_TYPES
from ..utils.civitai_utils import rewrite_preview_url
from ..utils.preview_selection import select_preview_media
@@ -60,6 +60,10 @@ class DownloadManager:
"""Get the checkpoint scanner from registry"""
return await ServiceRegistry.get_checkpoint_scanner()
async def _get_misc_scanner(self):
"""Get the misc scanner from registry"""
return await ServiceRegistry.get_misc_scanner()
async def download_from_civitai(
self,
model_id: int = None,
@@ -275,6 +279,7 @@ class DownloadManager:
lora_scanner = await self._get_lora_scanner()
checkpoint_scanner = await self._get_checkpoint_scanner()
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
misc_scanner = await self._get_misc_scanner()
# Check lora scanner first
if await lora_scanner.check_model_version_exists(model_version_id):
@@ -299,6 +304,13 @@ class DownloadManager:
"error": "Model version already exists in embedding library",
}
# Check misc scanner (VAE, Upscaler)
if await misc_scanner.check_model_version_exists(model_version_id):
return {
"success": False,
"error": "Model version already exists in misc library",
}
# Use CivArchive provider directly when source is 'civarchive'
# This prioritizes CivArchive metadata (with mirror availability info) over Civitai
if source == "civarchive":
@@ -337,6 +349,10 @@ class DownloadManager:
model_type = "lora"
elif model_type_from_info == "textualinversion":
model_type = "embedding"
elif model_type_from_info == "vae":
model_type = "misc"
elif model_type_from_info == "upscaler":
model_type = "misc"
else:
return {
"success": False,
@@ -379,6 +395,14 @@ class DownloadManager:
"success": False,
"error": "Model version already exists in embedding library",
}
elif model_type == "misc":
# Check misc scanner (VAE, Upscaler)
misc_scanner = await self._get_misc_scanner()
if await misc_scanner.check_model_version_exists(version_id):
return {
"success": False,
"error": "Model version already exists in misc library",
}
# Handle use_default_paths
if use_default_paths:
@@ -413,6 +437,26 @@ class DownloadManager:
"error": "Default embedding root path not set in settings",
}
save_dir = default_path
elif model_type == "misc":
from ..config import config
civitai_type = version_info.get("model", {}).get("type", "").lower()
if civitai_type == "vae":
default_paths = config.vae_roots
error_msg = "VAE root path not configured"
elif civitai_type == "upscaler":
default_paths = config.upscaler_roots
error_msg = "Upscaler root path not configured"
else:
default_paths = config.misc_roots
error_msg = "Misc root path not configured"
if not default_paths:
return {
"success": False,
"error": error_msg,
}
save_dir = default_paths[0] if default_paths else ""
# Calculate relative path using template
relative_path = self._calculate_relative_path(version_info, model_type)
@@ -515,6 +559,11 @@ class DownloadManager:
version_info, file_info, save_path
)
logger.info(f"Creating EmbeddingMetadata for {file_name}")
elif model_type == "misc":
metadata = MiscMetadata.from_civitai_info(
version_info, file_info, save_path
)
logger.info(f"Creating MiscMetadata for {file_name}")
# 6. Start download process
result = await self._execute_download(
@@ -620,6 +669,8 @@ class DownloadManager:
scanner = await self._get_checkpoint_scanner()
elif model_type == "embedding":
scanner = await ServiceRegistry.get_embedding_scanner()
elif model_type == "misc":
scanner = await self._get_misc_scanner()
except Exception as exc:
logger.debug("Failed to acquire scanner for %s models: %s", model_type, exc)
@@ -1016,6 +1067,9 @@ class DownloadManager:
elif model_type == "embedding":
scanner = await ServiceRegistry.get_embedding_scanner()
logger.info(f"Updating embedding cache for {actual_file_paths[0]}")
elif model_type == "misc":
scanner = await self._get_misc_scanner()
logger.info(f"Updating misc cache for {actual_file_paths[0]}")
adjust_cached_entry = (
getattr(scanner, "adjust_cached_entry", None)
@@ -1125,6 +1179,14 @@ class DownloadManager:
".pkl",
".sft",
}
if model_type == "misc":
return {
".ckpt",
".pt",
".bin",
".pth",
".safetensors",
}
return {".safetensors"}
async def _extract_model_files_from_archive(

View File

@@ -0,0 +1,55 @@
import logging
from typing import Any, Dict, List, Optional
from ..utils.models import MiscMetadata
from ..config import config
from .model_scanner import ModelScanner
from .model_hash_index import ModelHashIndex
logger = logging.getLogger(__name__)
class MiscScanner(ModelScanner):
"""Service for scanning and managing misc files (VAE, Upscaler)"""
def __init__(self):
# Define supported file extensions (combined from VAE and upscaler)
file_extensions = {'.safetensors', '.pt', '.bin', '.ckpt', '.pth'}
super().__init__(
model_type="misc",
model_class=MiscMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex()
)
def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]:
"""Resolve the sub-type based on the root path."""
if not root_path:
return None
if config.vae_roots and root_path in config.vae_roots:
return "vae"
if config.upscaler_roots and root_path in config.upscaler_roots:
return "upscaler"
return None
def adjust_metadata(self, metadata, file_path, root_path):
"""Adjust metadata during scanning to set sub_type."""
sub_type = self._resolve_sub_type(root_path)
if sub_type:
metadata.sub_type = sub_type
return metadata
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
"""Adjust entries loaded from the persisted cache to ensure sub_type is set."""
sub_type = self._resolve_sub_type(
self._find_root_for_file(entry.get("file_path"))
)
if sub_type:
entry["sub_type"] = sub_type
return entry
def get_model_roots(self) -> List[str]:
"""Get misc root directories (VAE and upscaler)"""
return config.misc_roots

View File

@@ -0,0 +1,55 @@
import os
import logging
from typing import Dict
from .base_model_service import BaseModelService
from ..utils.models import MiscMetadata
from ..config import config
logger = logging.getLogger(__name__)
class MiscService(BaseModelService):
"""Misc-specific service implementation (VAE, Upscaler)"""
def __init__(self, scanner, update_service=None):
"""Initialize Misc service
Args:
scanner: Misc scanner instance
update_service: Optional service for remote update tracking.
"""
super().__init__("misc", scanner, MiscMetadata, update_service=update_service)
async def format_response(self, misc_data: Dict) -> Dict:
"""Format Misc data for API response"""
# Get sub_type from cache entry (new canonical field)
sub_type = misc_data.get("sub_type", "vae")
return {
"model_name": misc_data["model_name"],
"file_name": misc_data["file_name"],
"preview_url": config.get_preview_static_url(misc_data.get("preview_url", "")),
"preview_nsfw_level": misc_data.get("preview_nsfw_level", 0),
"base_model": misc_data.get("base_model", ""),
"folder": misc_data["folder"],
"sha256": misc_data.get("sha256", ""),
"file_path": misc_data["file_path"].replace(os.sep, "/"),
"file_size": misc_data.get("size", 0),
"modified": misc_data.get("modified", ""),
"tags": misc_data.get("tags", []),
"from_civitai": misc_data.get("from_civitai", True),
"usage_count": misc_data.get("usage_count", 0),
"notes": misc_data.get("notes", ""),
"sub_type": sub_type,
"favorite": misc_data.get("favorite", False),
"update_available": bool(misc_data.get("update_available", False)),
"civitai": self.filter_civitai_data(misc_data.get("civitai", {}), minimal=True)
}
def find_duplicate_hashes(self) -> Dict:
"""Find Misc models with duplicate SHA256 hashes"""
return self.scanner._hash_index.get_duplicate_hashes()
def find_duplicate_filenames(self) -> Dict:
"""Find Misc models with conflicting filenames"""
return self.scanner._hash_index.get_duplicate_filenames()

View File

@@ -118,19 +118,24 @@ class ModelServiceFactory:
def register_default_model_types():
"""Register the default model types (LoRA, Checkpoint, and Embedding)"""
"""Register the default model types (LoRA, Checkpoint, Embedding, and Misc)"""
from ..services.lora_service import LoraService
from ..services.checkpoint_service import CheckpointService
from ..services.embedding_service import EmbeddingService
from ..services.misc_service import MiscService
from ..routes.lora_routes import LoraRoutes
from ..routes.checkpoint_routes import CheckpointRoutes
from ..routes.embedding_routes import EmbeddingRoutes
from ..routes.misc_model_routes import MiscModelRoutes
# Register LoRA model type
ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes)
# Register Checkpoint model type
ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes)
# Register Embedding model type
ModelServiceFactory.register_model_type('embedding', EmbeddingService, EmbeddingRoutes)
ModelServiceFactory.register_model_type('embedding', EmbeddingService, EmbeddingRoutes)
# Register Misc model type (VAE, Upscaler)
ModelServiceFactory.register_model_type('misc', MiscService, MiscModelRoutes)

View File

@@ -233,23 +233,44 @@ class ServiceRegistry:
async def get_embedding_scanner(cls):
"""Get or create Embedding scanner instance"""
service_name = "embedding_scanner"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .embedding_scanner import EmbeddingScanner
scanner = await EmbeddingScanner.get_instance()
cls._services[service_name] = scanner
logger.debug(f"Created and registered {service_name}")
return scanner
@classmethod
async def get_misc_scanner(cls):
"""Get or create Misc scanner instance (VAE, Upscaler)"""
service_name = "misc_scanner"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .misc_scanner import MiscScanner
scanner = await MiscScanner.get_instance()
cls._services[service_name] = scanner
logger.debug(f"Created and registered {service_name}")
return scanner
@classmethod
def clear_services(cls):
"""Clear all registered services - mainly for testing"""

View File

@@ -49,6 +49,7 @@ SUPPORTED_MEDIA_EXTENSIONS = {
VALID_LORA_SUB_TYPES = ["lora", "locon", "dora"]
VALID_CHECKPOINT_SUB_TYPES = ["checkpoint", "diffusion_model"]
VALID_EMBEDDING_SUB_TYPES = ["embedding"]
VALID_MISC_SUB_TYPES = ["vae", "upscaler"]
# Backward compatibility alias
VALID_LORA_TYPES = VALID_LORA_SUB_TYPES
@@ -94,6 +95,7 @@ DEFAULT_PRIORITY_TAG_CONFIG = {
"lora": ", ".join(CIVITAI_MODEL_TAGS),
"checkpoint": ", ".join(CIVITAI_MODEL_TAGS),
"embedding": ", ".join(CIVITAI_MODEL_TAGS),
"misc": ", ".join(CIVITAI_MODEL_TAGS),
}
# baseModel values from CivitAI that should be treated as diffusion models (unet)

View File

@@ -219,7 +219,7 @@ class EmbeddingMetadata(BaseModelMetadata):
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
sub_type = version_info.get('type', 'embedding')
# Extract tags and description if available
tags = []
description = ""
@@ -228,7 +228,53 @@ class EmbeddingMetadata(BaseModelMetadata):
tags = version_info['model']['tags']
if 'description' in version_info['model']:
description = version_info['model']['description']
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', '').lower(),
base_model=base_model,
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0,
from_civitai=True,
civitai=version_info,
sub_type=sub_type,
tags=tags,
modelDescription=description
)
@dataclass
class MiscMetadata(BaseModelMetadata):
"""Represents the metadata structure for a Misc model (VAE, Upscaler)"""
sub_type: str = "vae"
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'MiscMetadata':
"""Create MiscMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
# Determine sub_type from CivitAI model type
civitai_type = version_info.get('model', {}).get('type', '').lower()
if civitai_type == 'vae':
sub_type = 'vae'
elif civitai_type == 'upscaler':
sub_type = 'upscaler'
else:
sub_type = 'vae' # Default to vae
# Extract tags and description if available
tags = []
description = ""
if 'model' in version_info:
if 'tags' in version_info['model']:
tags = version_info['model']['tags']
if 'description' in version_info['model']:
description = version_info['model']['description']
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),