mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 14:42:11 -03:00
feat(misc): add VAE and Upscaler model management page
This commit is contained in:
54
py/config.py
54
py/config.py
@@ -89,8 +89,11 @@ class Config:
|
||||
self.checkpoints_roots = None
|
||||
self.unet_roots = None
|
||||
self.embeddings_roots = None
|
||||
self.vae_roots = None
|
||||
self.upscaler_roots = None
|
||||
self.base_models_roots = self._init_checkpoint_paths()
|
||||
self.embeddings_roots = self._init_embedding_paths()
|
||||
self.misc_roots = self._init_misc_paths()
|
||||
# Scan symbolic links during initialization
|
||||
self._initialize_symlink_mappings()
|
||||
|
||||
@@ -151,6 +154,8 @@ class Config:
|
||||
'checkpoints': list(self.checkpoints_roots or []),
|
||||
'unet': list(self.unet_roots or []),
|
||||
'embeddings': list(self.embeddings_roots or []),
|
||||
'vae': list(self.vae_roots or []),
|
||||
'upscale_models': list(self.upscaler_roots or []),
|
||||
}
|
||||
|
||||
normalized_target_paths = _normalize_folder_paths_for_comparison(target_folder_paths)
|
||||
@@ -250,6 +255,7 @@ class Config:
|
||||
roots.extend(self.loras_roots or [])
|
||||
roots.extend(self.base_models_roots or [])
|
||||
roots.extend(self.embeddings_roots or [])
|
||||
roots.extend(self.misc_roots or [])
|
||||
return roots
|
||||
|
||||
def _build_symlink_fingerprint(self) -> Dict[str, object]:
|
||||
@@ -599,6 +605,8 @@ class Config:
|
||||
preview_roots.update(self._expand_preview_root(root))
|
||||
for root in self.embeddings_roots or []:
|
||||
preview_roots.update(self._expand_preview_root(root))
|
||||
for root in self.misc_roots or []:
|
||||
preview_roots.update(self._expand_preview_root(root))
|
||||
|
||||
for target, link in self._path_mappings.items():
|
||||
preview_roots.update(self._expand_preview_root(target))
|
||||
@@ -606,11 +614,12 @@ class Config:
|
||||
|
||||
self._preview_root_paths = {path for path in preview_roots if path.is_absolute()}
|
||||
logger.debug(
|
||||
"Preview roots rebuilt: %d paths from %d lora roots, %d checkpoint roots, %d embedding roots, %d symlink mappings",
|
||||
"Preview roots rebuilt: %d paths from %d lora roots, %d checkpoint roots, %d embedding roots, %d misc roots, %d symlink mappings",
|
||||
len(self._preview_root_paths),
|
||||
len(self.loras_roots or []),
|
||||
len(self.base_models_roots or []),
|
||||
len(self.embeddings_roots or []),
|
||||
len(self.misc_roots or []),
|
||||
len(self._path_mappings),
|
||||
)
|
||||
|
||||
@@ -769,6 +778,49 @@ class Config:
|
||||
logger.warning(f"Error initializing embedding paths: {e}")
|
||||
return []
|
||||
|
||||
def _init_misc_paths(self) -> List[str]:
|
||||
"""Initialize and validate misc (VAE and upscaler) paths from ComfyUI settings"""
|
||||
try:
|
||||
raw_vae_paths = folder_paths.get_folder_paths("vae")
|
||||
raw_upscaler_paths = folder_paths.get_folder_paths("upscale_models")
|
||||
unique_paths = self._prepare_misc_paths(raw_vae_paths, raw_upscaler_paths)
|
||||
|
||||
logger.info("Found misc roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
|
||||
|
||||
if not unique_paths:
|
||||
logger.warning("No valid VAE or upscaler folders found in ComfyUI configuration")
|
||||
return []
|
||||
|
||||
return unique_paths
|
||||
except Exception as e:
|
||||
logger.warning(f"Error initializing misc paths: {e}")
|
||||
return []
|
||||
|
||||
def _prepare_misc_paths(
|
||||
self, vae_paths: Iterable[str], upscaler_paths: Iterable[str]
|
||||
) -> List[str]:
|
||||
vae_map = self._dedupe_existing_paths(vae_paths)
|
||||
upscaler_map = self._dedupe_existing_paths(upscaler_paths)
|
||||
|
||||
merged_map: Dict[str, str] = {}
|
||||
for real_path, original in {**vae_map, **upscaler_map}.items():
|
||||
if real_path not in merged_map:
|
||||
merged_map[real_path] = original
|
||||
|
||||
unique_paths = sorted(merged_map.values(), key=lambda p: p.lower())
|
||||
|
||||
vae_values = set(vae_map.values())
|
||||
upscaler_values = set(upscaler_map.values())
|
||||
self.vae_roots = [p for p in unique_paths if p in vae_values]
|
||||
self.upscaler_roots = [p for p in unique_paths if p in upscaler_values]
|
||||
|
||||
for original_path in unique_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return unique_paths
|
||||
|
||||
def get_preview_static_url(self, preview_path: str) -> str:
|
||||
if not preview_path:
|
||||
return ""
|
||||
|
||||
@@ -184,15 +184,17 @@ class LoraManager:
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
misc_scanner = await ServiceRegistry.get_misc_scanner()
|
||||
|
||||
# Initialize recipe scanner if needed
|
||||
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
|
||||
|
||||
# Create low-priority initialization tasks
|
||||
init_tasks = [
|
||||
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init'),
|
||||
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init'),
|
||||
asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init'),
|
||||
asyncio.create_task(misc_scanner.initialize_in_background(), name='misc_cache_init'),
|
||||
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
|
||||
]
|
||||
|
||||
@@ -252,8 +254,9 @@ class LoraManager:
|
||||
# Collect all model roots
|
||||
all_roots = set()
|
||||
all_roots.update(config.loras_roots)
|
||||
all_roots.update(config.base_models_roots)
|
||||
all_roots.update(config.base_models_roots)
|
||||
all_roots.update(config.embeddings_roots)
|
||||
all_roots.update(config.misc_roots or [])
|
||||
|
||||
total_deleted = 0
|
||||
total_size_freed = 0
|
||||
|
||||
112
py/routes/misc_model_routes.py
Normal file
112
py/routes/misc_model_routes.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import logging
|
||||
from typing import Dict
|
||||
from aiohttp import web
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from .model_route_registrar import ModelRouteRegistrar
|
||||
from ..services.misc_service import MiscService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MiscModelRoutes(BaseModelRoutes):
|
||||
"""Misc-specific route controller (VAE, Upscaler)"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Misc routes with Misc service"""
|
||||
super().__init__()
|
||||
self.template_name = "misc.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
misc_scanner = await ServiceRegistry.get_misc_scanner()
|
||||
update_service = await ServiceRegistry.get_model_update_service()
|
||||
self.service = MiscService(misc_scanner, update_service=update_service)
|
||||
self.set_model_update_service(update_service)
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Setup Misc routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Setup common routes with 'misc' prefix (includes page route)
|
||||
super().setup_routes(app, 'misc')
|
||||
|
||||
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
|
||||
"""Setup Misc-specific routes"""
|
||||
# Misc info by name
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/info/{name}', prefix, self.get_misc_info)
|
||||
|
||||
# VAE roots and Upscaler roots
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/vae_roots', prefix, self.get_vae_roots)
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/upscaler_roots', prefix, self.get_upscaler_roots)
|
||||
|
||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||
"""Validate CivitAI model type for Misc (VAE or Upscaler)"""
|
||||
return model_type.lower() in ['vae', 'upscaler']
|
||||
|
||||
def _get_expected_model_types(self) -> str:
|
||||
"""Get expected model types string for error messages"""
|
||||
return "VAE or Upscaler"
|
||||
|
||||
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||
"""Parse Misc-specific parameters"""
|
||||
params: Dict = {}
|
||||
|
||||
if 'misc_hash' in request.query:
|
||||
params['hash_filters'] = {'single_hash': request.query['misc_hash'].lower()}
|
||||
elif 'misc_hashes' in request.query:
|
||||
params['hash_filters'] = {
|
||||
'multiple_hashes': [h.lower() for h in request.query['misc_hashes'].split(',')]
|
||||
}
|
||||
|
||||
return params
|
||||
|
||||
async def get_misc_info(self, request: web.Request) -> web.Response:
|
||||
"""Get detailed information for a specific misc model by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
misc_info = await self.service.get_model_info_by_name(name)
|
||||
|
||||
if misc_info:
|
||||
return web.json_response(misc_info)
|
||||
else:
|
||||
return web.json_response({"error": "Misc model not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_misc_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_vae_roots(self, request: web.Request) -> web.Response:
|
||||
"""Return the list of VAE roots from config"""
|
||||
try:
|
||||
roots = config.vae_roots
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting VAE roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_upscaler_roots(self, request: web.Request) -> web.Response:
|
||||
"""Return the list of upscaler roots from config"""
|
||||
try:
|
||||
roots = config.upscaler_roots
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting upscaler roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
@@ -9,7 +9,7 @@ from collections import OrderedDict
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from urllib.parse import urlparse
|
||||
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
||||
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata, MiscMetadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH, DIFFUSION_MODEL_BASE_MODELS, VALID_LORA_TYPES
|
||||
from ..utils.civitai_utils import rewrite_preview_url
|
||||
from ..utils.preview_selection import select_preview_media
|
||||
@@ -60,6 +60,10 @@ class DownloadManager:
|
||||
"""Get the checkpoint scanner from registry"""
|
||||
return await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
async def _get_misc_scanner(self):
|
||||
"""Get the misc scanner from registry"""
|
||||
return await ServiceRegistry.get_misc_scanner()
|
||||
|
||||
async def download_from_civitai(
|
||||
self,
|
||||
model_id: int = None,
|
||||
@@ -275,6 +279,7 @@ class DownloadManager:
|
||||
lora_scanner = await self._get_lora_scanner()
|
||||
checkpoint_scanner = await self._get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
misc_scanner = await self._get_misc_scanner()
|
||||
|
||||
# Check lora scanner first
|
||||
if await lora_scanner.check_model_version_exists(model_version_id):
|
||||
@@ -299,6 +304,13 @@ class DownloadManager:
|
||||
"error": "Model version already exists in embedding library",
|
||||
}
|
||||
|
||||
# Check misc scanner (VAE, Upscaler)
|
||||
if await misc_scanner.check_model_version_exists(model_version_id):
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Model version already exists in misc library",
|
||||
}
|
||||
|
||||
# Use CivArchive provider directly when source is 'civarchive'
|
||||
# This prioritizes CivArchive metadata (with mirror availability info) over Civitai
|
||||
if source == "civarchive":
|
||||
@@ -337,6 +349,10 @@ class DownloadManager:
|
||||
model_type = "lora"
|
||||
elif model_type_from_info == "textualinversion":
|
||||
model_type = "embedding"
|
||||
elif model_type_from_info == "vae":
|
||||
model_type = "misc"
|
||||
elif model_type_from_info == "upscaler":
|
||||
model_type = "misc"
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
@@ -379,6 +395,14 @@ class DownloadManager:
|
||||
"success": False,
|
||||
"error": "Model version already exists in embedding library",
|
||||
}
|
||||
elif model_type == "misc":
|
||||
# Check misc scanner (VAE, Upscaler)
|
||||
misc_scanner = await self._get_misc_scanner()
|
||||
if await misc_scanner.check_model_version_exists(version_id):
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Model version already exists in misc library",
|
||||
}
|
||||
|
||||
# Handle use_default_paths
|
||||
if use_default_paths:
|
||||
@@ -413,6 +437,26 @@ class DownloadManager:
|
||||
"error": "Default embedding root path not set in settings",
|
||||
}
|
||||
save_dir = default_path
|
||||
elif model_type == "misc":
|
||||
from ..config import config
|
||||
|
||||
civitai_type = version_info.get("model", {}).get("type", "").lower()
|
||||
if civitai_type == "vae":
|
||||
default_paths = config.vae_roots
|
||||
error_msg = "VAE root path not configured"
|
||||
elif civitai_type == "upscaler":
|
||||
default_paths = config.upscaler_roots
|
||||
error_msg = "Upscaler root path not configured"
|
||||
else:
|
||||
default_paths = config.misc_roots
|
||||
error_msg = "Misc root path not configured"
|
||||
|
||||
if not default_paths:
|
||||
return {
|
||||
"success": False,
|
||||
"error": error_msg,
|
||||
}
|
||||
save_dir = default_paths[0] if default_paths else ""
|
||||
|
||||
# Calculate relative path using template
|
||||
relative_path = self._calculate_relative_path(version_info, model_type)
|
||||
@@ -515,6 +559,11 @@ class DownloadManager:
|
||||
version_info, file_info, save_path
|
||||
)
|
||||
logger.info(f"Creating EmbeddingMetadata for {file_name}")
|
||||
elif model_type == "misc":
|
||||
metadata = MiscMetadata.from_civitai_info(
|
||||
version_info, file_info, save_path
|
||||
)
|
||||
logger.info(f"Creating MiscMetadata for {file_name}")
|
||||
|
||||
# 6. Start download process
|
||||
result = await self._execute_download(
|
||||
@@ -620,6 +669,8 @@ class DownloadManager:
|
||||
scanner = await self._get_checkpoint_scanner()
|
||||
elif model_type == "embedding":
|
||||
scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
elif model_type == "misc":
|
||||
scanner = await self._get_misc_scanner()
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to acquire scanner for %s models: %s", model_type, exc)
|
||||
|
||||
@@ -1016,6 +1067,9 @@ class DownloadManager:
|
||||
elif model_type == "embedding":
|
||||
scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
logger.info(f"Updating embedding cache for {actual_file_paths[0]}")
|
||||
elif model_type == "misc":
|
||||
scanner = await self._get_misc_scanner()
|
||||
logger.info(f"Updating misc cache for {actual_file_paths[0]}")
|
||||
|
||||
adjust_cached_entry = (
|
||||
getattr(scanner, "adjust_cached_entry", None)
|
||||
@@ -1125,6 +1179,14 @@ class DownloadManager:
|
||||
".pkl",
|
||||
".sft",
|
||||
}
|
||||
if model_type == "misc":
|
||||
return {
|
||||
".ckpt",
|
||||
".pt",
|
||||
".bin",
|
||||
".pth",
|
||||
".safetensors",
|
||||
}
|
||||
return {".safetensors"}
|
||||
|
||||
async def _extract_model_files_from_archive(
|
||||
|
||||
55
py/services/misc_scanner.py
Normal file
55
py/services/misc_scanner.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..utils.models import MiscMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MiscScanner(ModelScanner):
|
||||
"""Service for scanning and managing misc files (VAE, Upscaler)"""
|
||||
|
||||
def __init__(self):
|
||||
# Define supported file extensions (combined from VAE and upscaler)
|
||||
file_extensions = {'.safetensors', '.pt', '.bin', '.ckpt', '.pth'}
|
||||
super().__init__(
|
||||
model_type="misc",
|
||||
model_class=MiscMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
|
||||
def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]:
|
||||
"""Resolve the sub-type based on the root path."""
|
||||
if not root_path:
|
||||
return None
|
||||
|
||||
if config.vae_roots and root_path in config.vae_roots:
|
||||
return "vae"
|
||||
|
||||
if config.upscaler_roots and root_path in config.upscaler_roots:
|
||||
return "upscaler"
|
||||
|
||||
return None
|
||||
|
||||
def adjust_metadata(self, metadata, file_path, root_path):
|
||||
"""Adjust metadata during scanning to set sub_type."""
|
||||
sub_type = self._resolve_sub_type(root_path)
|
||||
if sub_type:
|
||||
metadata.sub_type = sub_type
|
||||
return metadata
|
||||
|
||||
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Adjust entries loaded from the persisted cache to ensure sub_type is set."""
|
||||
sub_type = self._resolve_sub_type(
|
||||
self._find_root_for_file(entry.get("file_path"))
|
||||
)
|
||||
if sub_type:
|
||||
entry["sub_type"] = sub_type
|
||||
return entry
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get misc root directories (VAE and upscaler)"""
|
||||
return config.misc_roots
|
||||
55
py/services/misc_service.py
Normal file
55
py/services/misc_service.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import MiscMetadata
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MiscService(BaseModelService):
|
||||
"""Misc-specific service implementation (VAE, Upscaler)"""
|
||||
|
||||
def __init__(self, scanner, update_service=None):
|
||||
"""Initialize Misc service
|
||||
|
||||
Args:
|
||||
scanner: Misc scanner instance
|
||||
update_service: Optional service for remote update tracking.
|
||||
"""
|
||||
super().__init__("misc", scanner, MiscMetadata, update_service=update_service)
|
||||
|
||||
async def format_response(self, misc_data: Dict) -> Dict:
|
||||
"""Format Misc data for API response"""
|
||||
# Get sub_type from cache entry (new canonical field)
|
||||
sub_type = misc_data.get("sub_type", "vae")
|
||||
|
||||
return {
|
||||
"model_name": misc_data["model_name"],
|
||||
"file_name": misc_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(misc_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": misc_data.get("preview_nsfw_level", 0),
|
||||
"base_model": misc_data.get("base_model", ""),
|
||||
"folder": misc_data["folder"],
|
||||
"sha256": misc_data.get("sha256", ""),
|
||||
"file_path": misc_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": misc_data.get("size", 0),
|
||||
"modified": misc_data.get("modified", ""),
|
||||
"tags": misc_data.get("tags", []),
|
||||
"from_civitai": misc_data.get("from_civitai", True),
|
||||
"usage_count": misc_data.get("usage_count", 0),
|
||||
"notes": misc_data.get("notes", ""),
|
||||
"sub_type": sub_type,
|
||||
"favorite": misc_data.get("favorite", False),
|
||||
"update_available": bool(misc_data.get("update_available", False)),
|
||||
"civitai": self.filter_civitai_data(misc_data.get("civitai", {}), minimal=True)
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find Misc models with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find Misc models with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
@@ -118,19 +118,24 @@ class ModelServiceFactory:
|
||||
|
||||
|
||||
def register_default_model_types():
|
||||
"""Register the default model types (LoRA, Checkpoint, and Embedding)"""
|
||||
"""Register the default model types (LoRA, Checkpoint, Embedding, and Misc)"""
|
||||
from ..services.lora_service import LoraService
|
||||
from ..services.checkpoint_service import CheckpointService
|
||||
from ..services.embedding_service import EmbeddingService
|
||||
from ..services.misc_service import MiscService
|
||||
from ..routes.lora_routes import LoraRoutes
|
||||
from ..routes.checkpoint_routes import CheckpointRoutes
|
||||
from ..routes.embedding_routes import EmbeddingRoutes
|
||||
|
||||
from ..routes.misc_model_routes import MiscModelRoutes
|
||||
|
||||
# Register LoRA model type
|
||||
ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes)
|
||||
|
||||
|
||||
# Register Checkpoint model type
|
||||
ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes)
|
||||
|
||||
|
||||
# Register Embedding model type
|
||||
ModelServiceFactory.register_model_type('embedding', EmbeddingService, EmbeddingRoutes)
|
||||
ModelServiceFactory.register_model_type('embedding', EmbeddingService, EmbeddingRoutes)
|
||||
|
||||
# Register Misc model type (VAE, Upscaler)
|
||||
ModelServiceFactory.register_model_type('misc', MiscService, MiscModelRoutes)
|
||||
@@ -233,23 +233,44 @@ class ServiceRegistry:
|
||||
async def get_embedding_scanner(cls):
|
||||
"""Get or create Embedding scanner instance"""
|
||||
service_name = "embedding_scanner"
|
||||
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .embedding_scanner import EmbeddingScanner
|
||||
|
||||
|
||||
scanner = await EmbeddingScanner.get_instance()
|
||||
cls._services[service_name] = scanner
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return scanner
|
||||
|
||||
|
||||
@classmethod
|
||||
async def get_misc_scanner(cls):
|
||||
"""Get or create Misc scanner instance (VAE, Upscaler)"""
|
||||
service_name = "misc_scanner"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .misc_scanner import MiscScanner
|
||||
|
||||
scanner = await MiscScanner.get_instance()
|
||||
cls._services[service_name] = scanner
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
def clear_services(cls):
|
||||
"""Clear all registered services - mainly for testing"""
|
||||
|
||||
@@ -49,6 +49,7 @@ SUPPORTED_MEDIA_EXTENSIONS = {
|
||||
VALID_LORA_SUB_TYPES = ["lora", "locon", "dora"]
|
||||
VALID_CHECKPOINT_SUB_TYPES = ["checkpoint", "diffusion_model"]
|
||||
VALID_EMBEDDING_SUB_TYPES = ["embedding"]
|
||||
VALID_MISC_SUB_TYPES = ["vae", "upscaler"]
|
||||
|
||||
# Backward compatibility alias
|
||||
VALID_LORA_TYPES = VALID_LORA_SUB_TYPES
|
||||
@@ -94,6 +95,7 @@ DEFAULT_PRIORITY_TAG_CONFIG = {
|
||||
"lora": ", ".join(CIVITAI_MODEL_TAGS),
|
||||
"checkpoint": ", ".join(CIVITAI_MODEL_TAGS),
|
||||
"embedding": ", ".join(CIVITAI_MODEL_TAGS),
|
||||
"misc": ", ".join(CIVITAI_MODEL_TAGS),
|
||||
}
|
||||
|
||||
# baseModel values from CivitAI that should be treated as diffusion models (unet)
|
||||
|
||||
@@ -219,7 +219,7 @@ class EmbeddingMetadata(BaseModelMetadata):
|
||||
file_name = file_info['name']
|
||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||
sub_type = version_info.get('type', 'embedding')
|
||||
|
||||
|
||||
# Extract tags and description if available
|
||||
tags = []
|
||||
description = ""
|
||||
@@ -228,7 +228,53 @@ class EmbeddingMetadata(BaseModelMetadata):
|
||||
tags = version_info['model']['tags']
|
||||
if 'description' in version_info['model']:
|
||||
description = version_info['model']['description']
|
||||
|
||||
|
||||
return cls(
|
||||
file_name=os.path.splitext(file_name)[0],
|
||||
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||
file_path=save_path.replace(os.sep, '/'),
|
||||
size=file_info.get('sizeKB', 0) * 1024,
|
||||
modified=datetime.now().timestamp(),
|
||||
sha256=file_info['hashes'].get('SHA256', '').lower(),
|
||||
base_model=base_model,
|
||||
preview_url=None, # Will be updated after preview download
|
||||
preview_nsfw_level=0,
|
||||
from_civitai=True,
|
||||
civitai=version_info,
|
||||
sub_type=sub_type,
|
||||
tags=tags,
|
||||
modelDescription=description
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class MiscMetadata(BaseModelMetadata):
|
||||
"""Represents the metadata structure for a Misc model (VAE, Upscaler)"""
|
||||
sub_type: str = "vae"
|
||||
|
||||
@classmethod
|
||||
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'MiscMetadata':
|
||||
"""Create MiscMetadata instance from Civitai version info"""
|
||||
file_name = file_info['name']
|
||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||
|
||||
# Determine sub_type from CivitAI model type
|
||||
civitai_type = version_info.get('model', {}).get('type', '').lower()
|
||||
if civitai_type == 'vae':
|
||||
sub_type = 'vae'
|
||||
elif civitai_type == 'upscaler':
|
||||
sub_type = 'upscaler'
|
||||
else:
|
||||
sub_type = 'vae' # Default to vae
|
||||
|
||||
# Extract tags and description if available
|
||||
tags = []
|
||||
description = ""
|
||||
if 'model' in version_info:
|
||||
if 'tags' in version_info['model']:
|
||||
tags = version_info['model']['tags']
|
||||
if 'description' in version_info['model']:
|
||||
description = version_info['model']['description']
|
||||
|
||||
return cls(
|
||||
file_name=os.path.splitext(file_name)[0],
|
||||
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||
|
||||
Reference in New Issue
Block a user