feat: add embedding management functionality with routes, services, and UI integration

This commit is contained in:
Will Miao
2025-07-25 21:14:56 +08:00
parent ea8a64fafc
commit c5a3af2399
10 changed files with 376 additions and 8 deletions

View File

@@ -24,7 +24,9 @@ class Config:
self.loras_roots = self._init_lora_paths()
self.checkpoints_roots = None
self.unet_roots = None
self.embeddings_roots = None
self.base_models_roots = self._init_checkpoint_paths()
self.embeddings_roots = self._init_embedding_paths()
# Scan symbolic links during initialization
self._scan_symbolic_links()
@@ -48,6 +50,7 @@ class Config:
'loras': self.loras_roots,
'checkpoints': self.checkpoints_roots,
'unet': self.unet_roots,
'embeddings': self.embeddings_roots,
}
# Add default roots if there's only one item and key doesn't exist
@@ -83,12 +86,15 @@ class Config:
return False
def _scan_symbolic_links(self):
"""Scan all symbolic links in LoRA and Checkpoint root directories"""
"""Scan all symbolic links in LoRA, Checkpoint, and Embedding root directories"""
for root in self.loras_roots:
self._scan_directory_links(root)
for root in self.base_models_roots:
self._scan_directory_links(root)
for root in self.embeddings_roots:
self._scan_directory_links(root)
def _scan_directory_links(self, root: str):
"""Recursively scan symbolic links in a directory"""
@@ -223,6 +229,36 @@ class Config:
logger.warning(f"Error initializing checkpoint paths: {e}")
return []
def _init_embedding_paths(self) -> List[str]:
"""Initialize and validate embedding paths from ComfyUI settings"""
try:
raw_paths = folder_paths.get_folder_paths("embeddings")
# Normalize and resolve symlinks, store mapping from resolved -> original
path_map = {}
for path in raw_paths:
if os.path.exists(path):
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
path_map[real_path] = path_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
# Now sort and use only the deduplicated real paths
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
logger.info("Found embedding roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
if not unique_paths:
logger.warning("No valid embeddings folders found in ComfyUI configuration")
return []
for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
if real_path != original_path:
self.add_path_mapping(original_path, real_path)
return unique_paths
except Exception as e:
logger.warning(f"Error initializing embedding paths: {e}")
return []
def get_preview_static_url(self, preview_path: str) -> str:
"""Convert local preview path to static URL"""
if not preview_path:

View File

@@ -94,21 +94,45 @@ class LoraManager:
config.add_route_mapping(real_root, preview_path)
added_targets.add(real_root)
# Add static routes for each embedding root
for idx, root in enumerate(config.embeddings_roots, start=1):
preview_path = f'/embeddings_static/root{idx}/preview'
real_root = root
if root in config._path_mappings.values():
for target, link in config._path_mappings.items():
if link == root:
real_root = target
break
# Add static route for original path
app.router.add_static(preview_path, real_root)
logger.info(f"Added static route {preview_path} -> {real_root}")
# Record route mapping
config.add_route_mapping(real_root, preview_path)
added_targets.add(real_root)
# Add static routes for symlink target paths
link_idx = {
'lora': 1,
'checkpoint': 1
'checkpoint': 1,
'embedding': 1
}
for target_path, link_path in config._path_mappings.items():
if target_path not in added_targets:
# Determine if this is a checkpoint or lora link based on path
# Determine if this is a checkpoint, lora, or embedding link based on path
is_checkpoint = any(cp_root in link_path for cp_root in config.base_models_roots)
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.base_models_roots)
is_embedding = any(emb_root in link_path for emb_root in config.embeddings_roots)
is_embedding = is_embedding or any(emb_root in target_path for emb_root in config.embeddings_roots)
if is_checkpoint:
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
link_idx["checkpoint"] += 1
elif is_embedding:
route_path = f'/embeddings_static/link_{link_idx["embedding"]}/preview'
link_idx["embedding"] += 1
else:
route_path = f'/loras_static/link_{link_idx["lora"]}/preview'
link_idx["lora"] += 1
@@ -168,6 +192,7 @@ class LoraManager:
# Initialize scanners in background
lora_scanner = await ServiceRegistry.get_lora_scanner()
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
# Initialize recipe scanner if needed
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
@@ -175,6 +200,7 @@ class LoraManager:
# Create low-priority initialization tasks
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init')
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
await ExampleImagesMigration.check_and_run_migrations()

View File

@@ -0,0 +1,105 @@
import logging
from aiohttp import web
from .base_model_routes import BaseModelRoutes
from ..services.embedding_service import EmbeddingService
from ..services.service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
class EmbeddingRoutes(BaseModelRoutes):
"""Embedding-specific route controller"""
def __init__(self):
"""Initialize Embedding routes with Embedding service"""
# Service will be initialized later via setup_routes
self.service = None
self.civitai_client = None
self.template_name = "embeddings.html"
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
self.service = EmbeddingService(embedding_scanner)
self.civitai_client = await ServiceRegistry.get_civitai_client()
# Initialize parent with the service
super().__init__(self.service)
def setup_routes(self, app: web.Application):
"""Setup Embedding routes"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
# Setup common routes with 'embeddings' prefix (includes page route)
super().setup_routes(app, 'embeddings')
def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup Embedding-specific routes"""
# Embedding-specific CivitAI integration
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_embedding)
# Embedding info by name
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_embedding_info)
async def get_embedding_info(self, request: web.Request) -> web.Response:
"""Get detailed information for a specific embedding by name"""
try:
name = request.match_info.get('name', '')
embedding_info = await self.service.get_model_info_by_name(name)
if embedding_info:
return web.json_response(embedding_info)
else:
return web.json_response({"error": "Embedding not found"}, status=404)
except Exception as e:
logger.error(f"Error in get_embedding_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_civitai_versions_embedding(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai embedding model with local availability info"""
try:
model_id = request.match_info['model_id']
response = await self.civitai_client.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be TextualInversion (Embedding)
if model_type.lower() not in ['textualinversion', 'embedding']:
return web.json_response({
'error': f"Model type mismatch. Expected TextualInversion/Embedding, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the primary model file (type="Model" and primary=true) in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model' and file.get('primary') == True), None)
# If no primary file found, try to find any model file
if not model_file:
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching embedding model versions: {e}")
return web.Response(status=500, text=str(e))

View File

@@ -0,0 +1,26 @@
import logging
from typing import List
from ..utils.models import EmbeddingMetadata
from ..config import config
from .model_scanner import ModelScanner
from .model_hash_index import ModelHashIndex
logger = logging.getLogger(__name__)
class EmbeddingScanner(ModelScanner):
"""Service for scanning and managing embedding files"""
def __init__(self):
# Define supported file extensions
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
super().__init__(
model_type="embedding",
model_class=EmbeddingMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex()
)
def get_model_roots(self) -> List[str]:
"""Get embedding root directories"""
return config.embeddings_roots

View File

@@ -0,0 +1,51 @@
import os
import logging
from typing import Dict, List, Optional
from .base_model_service import BaseModelService
from ..utils.models import EmbeddingMetadata
from ..config import config
from ..utils.routes_common import ModelRouteUtils
logger = logging.getLogger(__name__)
class EmbeddingService(BaseModelService):
"""Embedding-specific service implementation"""
def __init__(self, scanner):
"""Initialize Embedding service
Args:
scanner: Embedding scanner instance
"""
super().__init__("embedding", scanner, EmbeddingMetadata)
async def format_response(self, embedding_data: Dict) -> Dict:
"""Format Embedding data for API response"""
return {
"model_name": embedding_data["model_name"],
"file_name": embedding_data["file_name"],
"preview_url": config.get_preview_static_url(embedding_data.get("preview_url", "")),
"preview_nsfw_level": embedding_data.get("preview_nsfw_level", 0),
"base_model": embedding_data.get("base_model", ""),
"folder": embedding_data["folder"],
"sha256": embedding_data.get("sha256", ""),
"file_path": embedding_data["file_path"].replace(os.sep, "/"),
"file_size": embedding_data.get("size", 0),
"modified": embedding_data.get("modified", ""),
"tags": embedding_data.get("tags", []),
"modelDescription": embedding_data.get("modelDescription", ""),
"from_civitai": embedding_data.get("from_civitai", True),
"notes": embedding_data.get("notes", ""),
"model_type": embedding_data.get("model_type", "embedding"),
"favorite": embedding_data.get("favorite", False),
"civitai": ModelRouteUtils.filter_civitai_data(embedding_data.get("civitai", {}))
}
def find_duplicate_hashes(self) -> Dict:
"""Find Embeddings with duplicate SHA256 hashes"""
return self.scanner._hash_index.get_duplicate_hashes()
def find_duplicate_filenames(self) -> Dict:
"""Find Embeddings with conflicting filenames"""
return self.scanner._hash_index.get_duplicate_filenames()

View File

@@ -122,11 +122,13 @@ class ModelServiceFactory:
def register_default_model_types():
"""Register the default model types (LoRA and Checkpoint)"""
"""Register the default model types (LoRA, Checkpoint, and Embedding)"""
from ..services.lora_service import LoraService
from ..services.checkpoint_service import CheckpointService
from ..services.embedding_service import EmbeddingService
from ..routes.lora_routes import LoraRoutes
from ..routes.checkpoint_routes import CheckpointRoutes
from ..routes.embedding_routes import EmbeddingRoutes
# Register LoRA model type
ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes)
@@ -134,4 +136,7 @@ def register_default_model_types():
# Register Checkpoint model type
ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes)
logger.info("Registered default model types: lora, checkpoint")
# Register Embedding model type
ModelServiceFactory.register_model_type('embedding', EmbeddingService, EmbeddingRoutes)
logger.info("Registered default model types: lora, checkpoint, embedding")

View File

@@ -174,6 +174,27 @@ class ServiceRegistry:
logger.debug(f"Registered {service_name}")
return ws_manager
@classmethod
async def get_embedding_scanner(cls):
"""Get or create Embedding scanner instance"""
service_name = "embedding_scanner"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .embedding_scanner import EmbeddingScanner
scanner = await EmbeddingScanner.get_instance()
cls._services[service_name] = scanner
logger.debug(f"Created and registered {service_name}")
return scanner
@classmethod
def clear_services(cls):
"""Clear all registered services - mainly for testing"""

View File

@@ -123,7 +123,7 @@ class LoraMetadata(BaseModelMetadata):
@dataclass
class CheckpointMetadata(BaseModelMetadata):
"""Represents the metadata structure for a Checkpoint model"""
model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.)
model_type: str = "checkpoint" # Model type (checkpoint, diffusion_model, etc.)
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata':
@@ -158,3 +158,41 @@ class CheckpointMetadata(BaseModelMetadata):
modelDescription=description
)
@dataclass
class EmbeddingMetadata(BaseModelMetadata):
"""Represents the metadata structure for an Embedding model"""
model_type: str = "embedding" # Model type (embedding, textual_inversion, etc.)
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'EmbeddingMetadata':
"""Create EmbeddingMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
model_type = version_info.get('type', 'embedding')
# Extract tags and description if available
tags = []
description = ""
if 'model' in version_info:
if 'tags' in version_info['model']:
tags = version_info['model']['tags']
if 'description' in version_info['model']:
description = version_info['model']['description']
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', '').lower(),
base_model=base_model,
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0,
from_civitai=True,
civitai=version_info,
model_type=model_type,
tags=tags,
modelDescription=description
)