From c5a3af2399146ef691329309e2a2fa253bcc3e08 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 25 Jul 2025 21:14:56 +0800 Subject: [PATCH] feat: add embedding management functionality with routes, services, and UI integration --- py/config.py | 38 +++++++++- py/lora_manager.py | 30 +++++++- py/routes/embedding_routes.py | 105 +++++++++++++++++++++++++++ py/services/embedding_scanner.py | 26 +++++++ py/services/embedding_service.py | 51 +++++++++++++ py/services/model_service_factory.py | 9 ++- py/services/service_registry.py | 21 ++++++ py/utils/models.py | 40 +++++++++- templates/checkpoints.html | 2 - templates/embeddings.html | 62 ++++++++++++++++ 10 files changed, 376 insertions(+), 8 deletions(-) create mode 100644 py/routes/embedding_routes.py create mode 100644 py/services/embedding_scanner.py create mode 100644 py/services/embedding_service.py create mode 100644 templates/embeddings.html diff --git a/py/config.py b/py/config.py index 1196ec80..b0c944c3 100644 --- a/py/config.py +++ b/py/config.py @@ -24,7 +24,9 @@ class Config: self.loras_roots = self._init_lora_paths() self.checkpoints_roots = None self.unet_roots = None + self.embeddings_roots = None self.base_models_roots = self._init_checkpoint_paths() + self.embeddings_roots = self._init_embedding_paths() # Scan symbolic links during initialization self._scan_symbolic_links() @@ -48,6 +50,7 @@ class Config: 'loras': self.loras_roots, 'checkpoints': self.checkpoints_roots, 'unet': self.unet_roots, + 'embeddings': self.embeddings_roots, } # Add default roots if there's only one item and key doesn't exist @@ -83,12 +86,15 @@ class Config: return False def _scan_symbolic_links(self): - """Scan all symbolic links in LoRA and Checkpoint root directories""" + """Scan all symbolic links in LoRA, Checkpoint, and Embedding root directories""" for root in self.loras_roots: self._scan_directory_links(root) for root in self.base_models_roots: self._scan_directory_links(root) + + for root in self.embeddings_roots: + self._scan_directory_links(root) def _scan_directory_links(self, root: str): """Recursively scan symbolic links in a directory""" @@ -223,6 +229,36 @@ class Config: logger.warning(f"Error initializing checkpoint paths: {e}") return [] + def _init_embedding_paths(self) -> List[str]: + """Initialize and validate embedding paths from ComfyUI settings""" + try: + raw_paths = folder_paths.get_folder_paths("embeddings") + + # Normalize and resolve symlinks, store mapping from resolved -> original + path_map = {} + for path in raw_paths: + if os.path.exists(path): + real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') + path_map[real_path] = path_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen + + # Now sort and use only the deduplicated real paths + unique_paths = sorted(path_map.values(), key=lambda p: p.lower()) + logger.info("Found embedding roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")) + + if not unique_paths: + logger.warning("No valid embeddings folders found in ComfyUI configuration") + return [] + + for original_path in unique_paths: + real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') + if real_path != original_path: + self.add_path_mapping(original_path, real_path) + + return unique_paths + except Exception as e: + logger.warning(f"Error initializing embedding paths: {e}") + return [] + def get_preview_static_url(self, preview_path: str) -> str: """Convert local preview path to static URL""" if not preview_path: diff --git a/py/lora_manager.py b/py/lora_manager.py index e211d75e..0c2f5e1d 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -94,21 +94,45 @@ class LoraManager: config.add_route_mapping(real_root, preview_path) added_targets.add(real_root) + # Add static routes for each embedding root + for idx, root in enumerate(config.embeddings_roots, start=1): + preview_path = f'/embeddings_static/root{idx}/preview' + + real_root = root + if root in config._path_mappings.values(): + for target, link in config._path_mappings.items(): + if link == root: + real_root = target + break + # Add static route for original path + app.router.add_static(preview_path, real_root) + logger.info(f"Added static route {preview_path} -> {real_root}") + + # Record route mapping + config.add_route_mapping(real_root, preview_path) + added_targets.add(real_root) + # Add static routes for symlink target paths link_idx = { 'lora': 1, - 'checkpoint': 1 + 'checkpoint': 1, + 'embedding': 1 } for target_path, link_path in config._path_mappings.items(): if target_path not in added_targets: - # Determine if this is a checkpoint or lora link based on path + # Determine if this is a checkpoint, lora, or embedding link based on path is_checkpoint = any(cp_root in link_path for cp_root in config.base_models_roots) is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.base_models_roots) + is_embedding = any(emb_root in link_path for emb_root in config.embeddings_roots) + is_embedding = is_embedding or any(emb_root in target_path for emb_root in config.embeddings_roots) if is_checkpoint: route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview' link_idx["checkpoint"] += 1 + elif is_embedding: + route_path = f'/embeddings_static/link_{link_idx["embedding"]}/preview' + link_idx["embedding"] += 1 else: route_path = f'/loras_static/link_{link_idx["lora"]}/preview' link_idx["lora"] += 1 @@ -168,6 +192,7 @@ class LoraManager: # Initialize scanners in background lora_scanner = await ServiceRegistry.get_lora_scanner() checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() + embedding_scanner = await ServiceRegistry.get_embedding_scanner() # Initialize recipe scanner if needed recipe_scanner = await ServiceRegistry.get_recipe_scanner() @@ -175,6 +200,7 @@ class LoraManager: # Create low-priority initialization tasks asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init') asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init') + asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init') asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init') await ExampleImagesMigration.check_and_run_migrations() diff --git a/py/routes/embedding_routes.py b/py/routes/embedding_routes.py new file mode 100644 index 00000000..eb9a5203 --- /dev/null +++ b/py/routes/embedding_routes.py @@ -0,0 +1,105 @@ +import logging +from aiohttp import web + +from .base_model_routes import BaseModelRoutes +from ..services.embedding_service import EmbeddingService +from ..services.service_registry import ServiceRegistry + +logger = logging.getLogger(__name__) + +class EmbeddingRoutes(BaseModelRoutes): + """Embedding-specific route controller""" + + def __init__(self): + """Initialize Embedding routes with Embedding service""" + # Service will be initialized later via setup_routes + self.service = None + self.civitai_client = None + self.template_name = "embeddings.html" + + async def initialize_services(self): + """Initialize services from ServiceRegistry""" + embedding_scanner = await ServiceRegistry.get_embedding_scanner() + self.service = EmbeddingService(embedding_scanner) + self.civitai_client = await ServiceRegistry.get_civitai_client() + + # Initialize parent with the service + super().__init__(self.service) + + def setup_routes(self, app: web.Application): + """Setup Embedding routes""" + # Schedule service initialization on app startup + app.on_startup.append(lambda _: self.initialize_services()) + + # Setup common routes with 'embeddings' prefix (includes page route) + super().setup_routes(app, 'embeddings') + + def setup_specific_routes(self, app: web.Application, prefix: str): + """Setup Embedding-specific routes""" + # Embedding-specific CivitAI integration + app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_embedding) + + # Embedding info by name + app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_embedding_info) + + async def get_embedding_info(self, request: web.Request) -> web.Response: + """Get detailed information for a specific embedding by name""" + try: + name = request.match_info.get('name', '') + embedding_info = await self.service.get_model_info_by_name(name) + + if embedding_info: + return web.json_response(embedding_info) + else: + return web.json_response({"error": "Embedding not found"}, status=404) + + except Exception as e: + logger.error(f"Error in get_embedding_info: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) + + async def get_civitai_versions_embedding(self, request: web.Request) -> web.Response: + """Get available versions for a Civitai embedding model with local availability info""" + try: + model_id = request.match_info['model_id'] + response = await self.civitai_client.get_model_versions(model_id) + if not response or not response.get('modelVersions'): + return web.Response(status=404, text="Model not found") + + versions = response.get('modelVersions', []) + model_type = response.get('type', '') + + # Check model type - should be TextualInversion (Embedding) + if model_type.lower() not in ['textualinversion', 'embedding']: + return web.json_response({ + 'error': f"Model type mismatch. Expected TextualInversion/Embedding, got {model_type}" + }, status=400) + + # Check local availability for each version + for version in versions: + # Find the primary model file (type="Model" and primary=true) in the files list + model_file = next((file for file in version.get('files', []) + if file.get('type') == 'Model' and file.get('primary') == True), None) + + # If no primary file found, try to find any model file + if not model_file: + model_file = next((file for file in version.get('files', []) + if file.get('type') == 'Model'), None) + + if model_file: + sha256 = model_file.get('hashes', {}).get('SHA256') + if sha256: + # Set existsLocally and localPath at the version level + version['existsLocally'] = self.service.has_hash(sha256) + if version['existsLocally']: + version['localPath'] = self.service.get_path_by_hash(sha256) + + # Also set the model file size at the version level for easier access + version['modelSizeKB'] = model_file.get('sizeKB') + else: + # No model file found in this version + version['existsLocally'] = False + + return web.json_response(versions) + except Exception as e: + logger.error(f"Error fetching embedding model versions: {e}") + return web.Response(status=500, text=str(e)) diff --git a/py/services/embedding_scanner.py b/py/services/embedding_scanner.py new file mode 100644 index 00000000..89257e1e --- /dev/null +++ b/py/services/embedding_scanner.py @@ -0,0 +1,26 @@ +import logging +from typing import List + +from ..utils.models import EmbeddingMetadata +from ..config import config +from .model_scanner import ModelScanner +from .model_hash_index import ModelHashIndex + +logger = logging.getLogger(__name__) + +class EmbeddingScanner(ModelScanner): + """Service for scanning and managing embedding files""" + + def __init__(self): + # Define supported file extensions + file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'} + super().__init__( + model_type="embedding", + model_class=EmbeddingMetadata, + file_extensions=file_extensions, + hash_index=ModelHashIndex() + ) + + def get_model_roots(self) -> List[str]: + """Get embedding root directories""" + return config.embeddings_roots diff --git a/py/services/embedding_service.py b/py/services/embedding_service.py new file mode 100644 index 00000000..b63559d2 --- /dev/null +++ b/py/services/embedding_service.py @@ -0,0 +1,51 @@ +import os +import logging +from typing import Dict, List, Optional + +from .base_model_service import BaseModelService +from ..utils.models import EmbeddingMetadata +from ..config import config +from ..utils.routes_common import ModelRouteUtils + +logger = logging.getLogger(__name__) + +class EmbeddingService(BaseModelService): + """Embedding-specific service implementation""" + + def __init__(self, scanner): + """Initialize Embedding service + + Args: + scanner: Embedding scanner instance + """ + super().__init__("embedding", scanner, EmbeddingMetadata) + + async def format_response(self, embedding_data: Dict) -> Dict: + """Format Embedding data for API response""" + return { + "model_name": embedding_data["model_name"], + "file_name": embedding_data["file_name"], + "preview_url": config.get_preview_static_url(embedding_data.get("preview_url", "")), + "preview_nsfw_level": embedding_data.get("preview_nsfw_level", 0), + "base_model": embedding_data.get("base_model", ""), + "folder": embedding_data["folder"], + "sha256": embedding_data.get("sha256", ""), + "file_path": embedding_data["file_path"].replace(os.sep, "/"), + "file_size": embedding_data.get("size", 0), + "modified": embedding_data.get("modified", ""), + "tags": embedding_data.get("tags", []), + "modelDescription": embedding_data.get("modelDescription", ""), + "from_civitai": embedding_data.get("from_civitai", True), + "notes": embedding_data.get("notes", ""), + "model_type": embedding_data.get("model_type", "embedding"), + "favorite": embedding_data.get("favorite", False), + "civitai": ModelRouteUtils.filter_civitai_data(embedding_data.get("civitai", {})) + } + + def find_duplicate_hashes(self) -> Dict: + """Find Embeddings with duplicate SHA256 hashes""" + return self.scanner._hash_index.get_duplicate_hashes() + + def find_duplicate_filenames(self) -> Dict: + """Find Embeddings with conflicting filenames""" + return self.scanner._hash_index.get_duplicate_filenames() diff --git a/py/services/model_service_factory.py b/py/services/model_service_factory.py index 6cc8a3a3..4d655eed 100644 --- a/py/services/model_service_factory.py +++ b/py/services/model_service_factory.py @@ -122,11 +122,13 @@ class ModelServiceFactory: def register_default_model_types(): - """Register the default model types (LoRA and Checkpoint)""" + """Register the default model types (LoRA, Checkpoint, and Embedding)""" from ..services.lora_service import LoraService from ..services.checkpoint_service import CheckpointService + from ..services.embedding_service import EmbeddingService from ..routes.lora_routes import LoraRoutes from ..routes.checkpoint_routes import CheckpointRoutes + from ..routes.embedding_routes import EmbeddingRoutes # Register LoRA model type ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes) @@ -134,4 +136,7 @@ def register_default_model_types(): # Register Checkpoint model type ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes) - logger.info("Registered default model types: lora, checkpoint") \ No newline at end of file + # Register Embedding model type + ModelServiceFactory.register_model_type('embedding', EmbeddingService, EmbeddingRoutes) + + logger.info("Registered default model types: lora, checkpoint, embedding") \ No newline at end of file diff --git a/py/services/service_registry.py b/py/services/service_registry.py index 6cefb4d4..541d3026 100644 --- a/py/services/service_registry.py +++ b/py/services/service_registry.py @@ -174,6 +174,27 @@ class ServiceRegistry: logger.debug(f"Registered {service_name}") return ws_manager + @classmethod + async def get_embedding_scanner(cls): + """Get or create Embedding scanner instance""" + service_name = "embedding_scanner" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + # Double-check after acquiring lock + if service_name in cls._services: + return cls._services[service_name] + + # Import here to avoid circular imports + from .embedding_scanner import EmbeddingScanner + + scanner = await EmbeddingScanner.get_instance() + cls._services[service_name] = scanner + logger.debug(f"Created and registered {service_name}") + return scanner + @classmethod def clear_services(cls): """Clear all registered services - mainly for testing""" diff --git a/py/utils/models.py b/py/utils/models.py index 3947e671..ac2650d7 100644 --- a/py/utils/models.py +++ b/py/utils/models.py @@ -123,7 +123,7 @@ class LoraMetadata(BaseModelMetadata): @dataclass class CheckpointMetadata(BaseModelMetadata): """Represents the metadata structure for a Checkpoint model""" - model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.) + model_type: str = "checkpoint" # Model type (checkpoint, diffusion_model, etc.) @classmethod def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata': @@ -158,3 +158,41 @@ class CheckpointMetadata(BaseModelMetadata): modelDescription=description ) +@dataclass +class EmbeddingMetadata(BaseModelMetadata): + """Represents the metadata structure for an Embedding model""" + model_type: str = "embedding" # Model type (embedding, textual_inversion, etc.) + + @classmethod + def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'EmbeddingMetadata': + """Create EmbeddingMetadata instance from Civitai version info""" + file_name = file_info['name'] + base_model = determine_base_model(version_info.get('baseModel', '')) + model_type = version_info.get('type', 'embedding') + + # Extract tags and description if available + tags = [] + description = "" + if 'model' in version_info: + if 'tags' in version_info['model']: + tags = version_info['model']['tags'] + if 'description' in version_info['model']: + description = version_info['model']['description'] + + return cls( + file_name=os.path.splitext(file_name)[0], + model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]), + file_path=save_path.replace(os.sep, '/'), + size=file_info.get('sizeKB', 0) * 1024, + modified=datetime.now().timestamp(), + sha256=file_info['hashes'].get('SHA256', '').lower(), + base_model=base_model, + preview_url=None, # Will be updated after preview download + preview_nsfw_level=0, + from_civitai=True, + civitai=version_info, + model_type=model_type, + tags=tags, + modelDescription=description + ) + diff --git a/templates/checkpoints.html b/templates/checkpoints.html index f65803ca..2e00b7a1 100644 --- a/templates/checkpoints.html +++ b/templates/checkpoints.html @@ -14,8 +14,6 @@ {% block additional_components %}