diff --git a/__init__.py b/__init__.py index d6d8021e..94b9611d 100644 --- a/__init__.py +++ b/__init__.py @@ -1,3 +1,4 @@ +from .lora_manager import LoraManager from .nodes.lora_gateway import LoRAGateway NODE_CLASS_MAPPINGS = { @@ -10,4 +11,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { WEB_DIRECTORY = "./js" +# Register routes on import +LoraManager.add_routes() __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', 'WEB_DIRECTORY'] \ No newline at end of file diff --git a/config.py b/config.py new file mode 100644 index 00000000..684a4f05 --- /dev/null +++ b/config.py @@ -0,0 +1,37 @@ +import os +import folder_paths # type: ignore +from typing import List + +class Config: + """Global configuration for LoRA Manager""" + + def __init__(self): + self.loras_roots = self._init_lora_paths() + self.templates_path = os.path.join(os.path.dirname(__file__), 'templates') + self.static_path = os.path.join(os.path.dirname(__file__), 'static') + + def _init_lora_paths(self) -> List[str]: + """Initialize and validate LoRA paths from ComfyUI settings""" + paths = [path.replace(os.sep, "/") + for path in folder_paths.get_folder_paths("loras") + if os.path.exists(path)] + + if not paths: + raise ValueError("No valid loras folders found in ComfyUI configuration") + + return paths + + def get_preview_static_url(self, preview_path: str) -> str: + """Convert local preview path to static URL""" + if not preview_path: + return "" + + for idx, root in enumerate(self.loras_roots, start=1): + if preview_path.startswith(root): + relative_path = os.path.relpath(preview_path, root) + return f'/loras_static/root{idx}/preview/{relative_path.replace(os.sep, "/")}' + + return "" + +# Global config instance +config = Config() diff --git a/lora_manager.py b/lora_manager.py index b1741314..a27511fd 100644 --- a/lora_manager.py +++ b/lora_manager.py @@ -1,379 +1,24 @@ -import os -import json -import time -from pathlib import Path -from aiohttp import web -from server import PromptServer -import jinja2 -from safetensors import safe_open -from .utils.file_utils import get_file_info, save_metadata, load_metadata, update_civitai_metadata -from .utils.lora_metadata import extract_lora_metadata -from typing import Dict, Optional -from .services.civitai_client import CivitaiClient -import folder_paths -import logging - -class LorasEndpoint: - def __init__(self): - self.template_env = jinja2.Environment( - loader=jinja2.FileSystemLoader( - os.path.join(os.path.dirname(__file__), 'templates') - ), - autoescape=True - ) - # Configure Loras root directories (from ComfyUI folder paths settings) - self.loras_roots = [path.replace(os.sep, "/") for path in folder_paths.get_folder_paths("loras") if os.path.exists(path)] - if not self.loras_roots: - raise ValueError("No valid loras folders found") - print(f"Loras roots: {self.loras_roots}") # debug log - - self.server = PromptServer.instance +from server import PromptServer # type: ignore +from .config import config +from .routes.lora_routes import LoraRoutes +from .routes.api_routes import ApiRoutes +class LoraManager: + """Main entry point for LoRA Manager plugin""" + @classmethod def add_routes(cls): - instance = cls() + """Initialize and register all routes""" app = PromptServer.instance.app - static_path = os.path.join(os.path.dirname(__file__), 'static') - - # Generate multiple static paths based on the number of folders in instance.loras_roots - for idx, root in enumerate(instance.loras_roots, start=1): - # Create different static paths for each folder, like /loras_static/root1/preview + + # Add static routes for each lora root + for idx, root in enumerate(config.loras_roots, start=1): preview_path = f'/loras_static/root{idx}/preview' - app.add_routes([web.static(preview_path, root)]) - - app.add_routes([ - web.get('/loras', instance.handle_loras_request), - # web.static('/loras_static/previews', instance.loras_root), - web.static('/loras_static', static_path), - web.post('/api/delete_model', instance.delete_model), - web.post('/api/fetch-civitai', instance.fetch_civitai), - web.post('/api/replace_preview', instance.replace_preview), - ]) - - - async def scan_loras(self): - loras = [] - for loras_root in self.loras_roots: - for root, _, files in os.walk(loras_root): - safetensors_files = [f for f in files if f.endswith('.safetensors')] - - for filename in safetensors_files: - file_path = os.path.join(root, filename).replace(os.sep, "/") - - # Try to load existing metadata first - metadata = await load_metadata(file_path) - - if metadata is None: - # Only get file info and extract metadata if no existing metadata - metadata = await get_file_info(file_path) - base_model_info = await extract_lora_metadata(file_path) - metadata.base_model = base_model_info['base_model'] - await save_metadata(file_path, metadata) - - # Convert to dict for API response - lora_data = metadata.to_dict() - # Get relative path and remove filename to get just the folder structure - rel_path = os.path.relpath(file_path, loras_root) - folder = os.path.dirname(rel_path) - # Ensure forward slashes for consistency across platforms - lora_data['folder'] = folder.replace(os.path.sep, '/') - - loras.append(lora_data) - - return loras - - - def clean_description(self, desc): - """清理HTML格式的描述""" - return desc.replace("

", "").replace("

", "\n").strip() - - async def handle_loras_request(self, request): - """处理Loras请求并渲染模板""" - try: - scan_start = time.time() - data = await self.scan_loras() - print(f"Lora Manager: Scanned {len(data)} loras in {time.time()-scan_start:.2f}s") - - # Format the data for the template - formatted_loras = [self.format_lora(l) for l in data] - folders = sorted(list(set(l['folder'] for l in data))) - - context = { - "loras": formatted_loras, - "folders": folders - } - - template = self.template_env.get_template('loras.html') - rendered = template.render(**context) - return web.Response( - text=rendered, - content_type='text/html' - ) - except Exception as e: - print(f"Error handling loras request: {str(e)}") - import traceback - print(traceback.format_exc()) # Print full stack trace - return web.Response( - text="Error loading loras page", - content_type='text/html', - status=500 - ) + app.router.add_static(preview_path, root) - def filter_civitai_data(self, civitai_data): - if not civitai_data: - return {} - - required_fields = [ - "id", "modelId", "name", "createdAt", "updatedAt", - "publishedAt", "trainedWords", "baseModel", "description", - "model", "images" - ] - - return {k: civitai_data[k] for k in required_fields if k in civitai_data} - - def format_lora(self, lora): - """格式化前端需要的数据结构""" - return { - "model_name": lora["model_name"], - "file_name": lora["file_name"], - "preview_url": self.get_static_url_for_preview(lora["preview_url"]), - "base_model": lora["base_model"], - "folder": lora["folder"], - "sha256": lora["sha256"], - "file_path": lora["file_path"].replace(os.sep, "/"), - "modified": lora["modified"], - "from_civitai": lora.get("from_civitai", True), - "civitai": self.filter_civitai_data(lora.get("civitai", {})) - } - - def get_static_url_for_preview(self, preview_url): - """ - Determines which loras_root the preview_url belongs to and - returns the corresponding static URL. - """ - for idx, root in enumerate(self.loras_roots, start=1): - # Check if preview_url belongs to current root - if preview_url.startswith(root): - # Get relative path and generate static URL - relative_path = os.path.relpath(preview_url, root) - static_url = f'/loras_static/root{idx}/preview/{relative_path.replace(os.sep, "/")}' - return static_url + # Add static route for plugin assets + app.router.add_static('/loras_static', config.static_path) - # If no matching root found, return empty string - return "" - - - async def delete_model(self, request): - try: - data = await request.json() - file_path = data.get('file_path') # 从请求中获取file_path信息 - if not file_path: - return web.Response(text='Model full path is required', status=400) - - # 构建完整的目录路径 - target_dir = os.path.dirname(file_path) - file_name = os.path.splitext(os.path.basename(file_path))[0] - - # List of file patterns to delete - required_file = f"{file_name}.safetensors" # 主文件必须存在 - optional_files = [ # 这些文件可能不存在 - f"{file_name}.metadata.json", - f"{file_name}.preview.png", - f"{file_name}.preview.jpg", - f"{file_name}.preview.jpeg", - f"{file_name}.preview.webp", - f"{file_name}.preview.mp4", - f"{file_name}.png", - f"{file_name}.jpg", - f"{file_name}.jpeg", - f"{file_name}.webp", - f"{file_name}.mp4" - ] - - deleted_files = [] - - # Try to delete the main safetensors file - main_file_path = os.path.join(target_dir, required_file) - if os.path.exists(main_file_path): - try: - os.remove(main_file_path) - deleted_files.append(required_file) - except Exception as e: - print(f"Error deleting {main_file_path}: {str(e)}") - return web.Response(text=f"Failed to delete main model file: {str(e)}", status=500) - - # Only try to delete optional files if main file was deleted - for pattern in optional_files: - file_path = os.path.join(target_dir, pattern) - if os.path.exists(file_path): - try: - os.remove(file_path) - deleted_files.append(pattern) - except Exception as e: - print(f"Error deleting optional file {file_path}: {str(e)}") - else: - return web.Response(text=f"Model file {required_file} not found in {target_dir}", status=404) - - return web.json_response({ - 'success': True, - 'deleted_files': deleted_files - }) - - except Exception as e: - return web.Response(text=str(e), status=500) - - async def update_civitai_info(self, file_path: str, civitai_data: Dict, preview_url: Optional[str] = None): - """Update Civitai metadata and download preview image""" - # Update metadata file - await update_civitai_metadata(file_path, civitai_data) - - # Download and save preview image if URL is provided - if preview_url: - preview_path = f"{os.path.splitext(file_path)[0]}.preview.png" - try: - # Add your image download logic here - # Example: - # await download_image(preview_url, preview_path) - pass - except Exception as e: - print(f"Error downloading preview image: {str(e)}") - - async def fetch_civitai(self, request): - try: - data = await request.json() - client = CivitaiClient() - - try: - metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json' - - local_metadata = {} - if os.path.exists(metadata_path): - with open(metadata_path, 'r', encoding='utf-8') as f: - local_metadata = json.load(f) - - - if not local_metadata.get('from_civitai', True): - return web.json_response( - {"success": True, "Notice": "Not from CivitAI"}, - status=200 - ) - - # 1. 获取CivitAI元数据 - civitai_metadata = await client.get_model_by_hash(data["sha256"]) - if not civitai_metadata: - local_metadata['from_civitai'] = False - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(local_metadata, f, indent=2, ensure_ascii=False) - return web.json_response( - {"success": False, "error": "Not found on CivitAI"}, - status=404 - ) - - local_metadata['civitai']=civitai_metadata - - # 更新模型名称(优先使用CivitAI名称) - if 'model' in civitai_metadata: - local_metadata['model_name'] = civitai_metadata['model'].get('name', local_metadata.get('model_name')) - # update base model - local_metadata['base_model'] = civitai_metadata.get('baseModel') - - # 4. 下载预览图 - # Check if existing preview is valid and the file exists - if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']): - first_preview = next((img for img in civitai_metadata.get('images', [])), None) - if first_preview: - - preview_extension = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1] # Get the file extension - preview_filename = os.path.splitext(os.path.basename(data['file_path']))[0] + '.preview' + preview_extension - preview_path = os.path.join(os.path.dirname(data['file_path']), preview_filename) - await client.download_preview_image(first_preview['url'], preview_path) - # 存储相对路径,使用正斜杠格式 - local_metadata['preview_url'] = preview_path.replace(os.sep, '/') - - # 5. 保存更新后的元数据 - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(local_metadata, f, indent=2, ensure_ascii=False) - - return web.json_response({ - "success": True - }) - - except Exception as e: - print(f"Error in fetch_civitai: {str(e)}") # Debug log - return web.json_response({ - "success": False, - "error": str(e) - }, status=500) - finally: - await client.close() - - except Exception as e: - print(f"Error processing request: {str(e)}") # Debug log - return web.json_response({ - "success": False, - "error": f"Request processing error: {str(e)}" - }, status=400) - - async def replace_preview(self, request): - try: - reader = await request.multipart() - - # Get the preview_file field first - file_field = await reader.next() - if file_field.name != 'preview_file': - raise ValueError("Expected 'preview_file' field first") - preview_data = await file_field.read() - - # Get the file model_path field - name_field = await reader.next() - if name_field.name != 'model_path': - raise ValueError("Expected 'model_path' field second") - model_path = (await name_field.read()).decode() - - # Get the content type from the file field headers - content_type = file_field.headers.get('Content-Type', '') - - print(f"Received preview file: {model_path} ({content_type})") # Debug log - - # Determine file extension based on content type - if content_type.startswith('video/'): - extension = '.preview.mp4' - else: - extension = '.preview.png' - - # Construct the preview file path - base_name = os.path.splitext(os.path.basename(model_path))[0] # Remove original extension - preview_name = base_name + extension - # Get the folder path from the model_path - folder = os.path.dirname(model_path) - preview_path = os.path.join(folder, preview_name).replace(os.sep, '/') - - # Save the preview file - with open(preview_path, 'wb') as f: - f.write(preview_data) - - # Update metadata if it exists - metadata_path = os.path.join(folder, base_name + '.metadata.json') - if os.path.exists(metadata_path): - try: - with open(metadata_path, 'r', encoding='utf-8') as f: - metadata = json.load(f) - # Update the preview_url to match the new file name - metadata['preview_url'] = preview_path - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(metadata, f, indent=2, ensure_ascii=False) - except Exception as e: - print(f"Error updating metadata: {str(e)}") - # Continue even if metadata update fails - - return web.json_response({ - "success": True, - "preview_url": self.get_static_url_for_preview(preview_path) - }) - - except Exception as e: - print(f"Error replacing preview: {str(e)}") - return web.Response(text=str(e), status=500) - -# 注册路由 -LorasEndpoint.add_routes() \ No newline at end of file + # Setup feature routes + LoraRoutes.setup_routes(app) + ApiRoutes.setup_routes(app) \ No newline at end of file diff --git a/nodes/lora_gateway.py b/nodes/lora_gateway.py index 970e944c..94a80d93 100644 --- a/nodes/lora_gateway.py +++ b/nodes/lora_gateway.py @@ -1,6 +1,3 @@ -from ..lora_manager import LorasEndpoint - - class LoRAGateway: """ LoRA Gateway Node @@ -15,10 +12,4 @@ class LoRAGateway: RETURN_TYPES = () FUNCTION = "register_services" - CATEGORY = "LoRA Management" - - @classmethod - def register_services(cls): - # Service registration logic - LorasEndpoint.add_routes() - return () \ No newline at end of file + CATEGORY = "LoRA Management" \ No newline at end of file diff --git a/routes/api_routes.py b/routes/api_routes.py new file mode 100644 index 00000000..79b84370 --- /dev/null +++ b/routes/api_routes.py @@ -0,0 +1,228 @@ +import os +import json +import logging +from aiohttp import web +from typing import Dict, List +from ..services.civitai_client import CivitaiClient +from ..utils.file_utils import update_civitai_metadata, load_metadata +from ..config import config + +logger = logging.getLogger(__name__) + +class ApiRoutes: + """API route handlers for LoRA management""" + + @classmethod + def setup_routes(cls, app: web.Application): + """Register API routes""" + routes = cls() + app.router.add_post('/api/delete_model', routes.delete_model) + app.router.add_post('/api/fetch-civitai', routes.fetch_civitai) + app.router.add_post('/api/replace_preview', routes.replace_preview) + + async def delete_model(self, request: web.Request) -> web.Response: + """Handle model deletion request""" + try: + data = await request.json() + file_path = data.get('file_path') + if not file_path: + return web.Response(text='Model path is required', status=400) + + target_dir = os.path.dirname(file_path) + file_name = os.path.splitext(os.path.basename(file_path))[0] + + deleted_files = await self._delete_model_files(target_dir, file_name) + + return web.json_response({ + 'success': True, + 'deleted_files': deleted_files + }) + + except Exception as e: + logger.error(f"Error deleting model: {e}", exc_info=True) + return web.Response(text=str(e), status=500) + + async def fetch_civitai(self, request: web.Request) -> web.Response: + """Handle CivitAI metadata fetch request""" + client = CivitaiClient() + try: + data = await request.json() + metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json' + + # Check if model is from CivitAI + local_metadata = await self._load_local_metadata(metadata_path) + if not local_metadata.get('from_civitai', True): + return web.json_response({"success": True, "notice": "Not from CivitAI"}) + + # Fetch and update metadata + civitai_metadata = await client.get_model_by_hash(data["sha256"]) + if not civitai_metadata: + return await self._handle_not_found_on_civitai(metadata_path, local_metadata) + + await self._update_model_metadata(metadata_path, local_metadata, civitai_metadata, client) + + return web.json_response({"success": True}) + + except Exception as e: + logger.error(f"Error fetching from CivitAI: {e}", exc_info=True) + return web.json_response({"success": False, "error": str(e)}, status=500) + finally: + await client.close() + + async def replace_preview(self, request: web.Request) -> web.Response: + """Handle preview image replacement request""" + try: + reader = await request.multipart() + preview_data, content_type = await self._read_preview_file(reader) + model_path = await self._read_model_path(reader) + + preview_path = await self._save_preview_file(model_path, preview_data, content_type) + await self._update_preview_metadata(model_path, preview_path) + + return web.json_response({ + "success": True, + "preview_url": config.get_preview_static_url(preview_path) + }) + + except Exception as e: + logger.error(f"Error replacing preview: {e}", exc_info=True) + return web.Response(text=str(e), status=500) + + # Private helper methods + async def _delete_model_files(self, target_dir: str, file_name: str) -> List[str]: + """Delete model and associated files""" + patterns = [ + f"{file_name}.safetensors", # Required + f"{file_name}.metadata.json", + f"{file_name}.preview.png", + f"{file_name}.preview.jpg", + f"{file_name}.preview.jpeg", + f"{file_name}.preview.webp", + f"{file_name}.preview.mp4", + f"{file_name}.png", + f"{file_name}.jpg", + f"{file_name}.jpeg", + f"{file_name}.webp", + f"{file_name}.mp4" + ] + + deleted = [] + main_file = patterns[0] + main_path = os.path.join(target_dir, main_file) + + if not os.path.exists(main_path): + raise web.HTTPNotFound(text=f"Model file not found: {main_file}") + + # Delete main file first + os.remove(main_path) + deleted.append(main_file) + + # Delete optional files + for pattern in patterns[1:]: + path = os.path.join(target_dir, pattern) + if os.path.exists(path): + try: + os.remove(path) + deleted.append(pattern) + except Exception as e: + logger.warning(f"Failed to delete {pattern}: {e}") + + return deleted + + async def _read_preview_file(self, reader) -> tuple[bytes, str]: + """Read preview file and content type from multipart request""" + field = await reader.next() + if field.name != 'preview_file': + raise ValueError("Expected 'preview_file' field") + content_type = field.headers.get('Content-Type', 'image/png') + return await field.read(), content_type + + async def _read_model_path(self, reader) -> str: + """Read model path from multipart request""" + field = await reader.next() + if field.name != 'model_path': + raise ValueError("Expected 'model_path' field") + return (await field.read()).decode() + + async def _save_preview_file(self, model_path: str, preview_data: bytes, content_type: str) -> str: + """Save preview file and return its path""" + # Determine file extension based on content type + if content_type.startswith('video/'): + extension = '.preview.mp4' + else: + extension = '.preview.png' + + base_name = os.path.splitext(os.path.basename(model_path))[0] + folder = os.path.dirname(model_path) + preview_path = os.path.join(folder, base_name + extension).replace(os.sep, '/') + + with open(preview_path, 'wb') as f: + f.write(preview_data) + + return preview_path + + async def _update_preview_metadata(self, model_path: str, preview_path: str): + """Update preview path in metadata""" + metadata_path = os.path.splitext(model_path)[0] + '.metadata.json' + if os.path.exists(metadata_path): + try: + with open(metadata_path, 'r', encoding='utf-8') as f: + metadata = json.load(f) + + # Update preview_url directly in the metadata dict + metadata['preview_url'] = preview_path + + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + except Exception as e: + logger.error(f"Error updating metadata: {e}") + + async def _load_local_metadata(self, metadata_path: str) -> Dict: + """Load local metadata file""" + if os.path.exists(metadata_path): + try: + with open(metadata_path, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + logger.error(f"Error loading metadata from {metadata_path}: {e}") + return {} + + async def _handle_not_found_on_civitai(self, metadata_path: str, local_metadata: Dict) -> web.Response: + """Handle case when model is not found on CivitAI""" + local_metadata['from_civitai'] = False + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(local_metadata, f, indent=2, ensure_ascii=False) + return web.json_response( + {"success": False, "error": "Not found on CivitAI"}, + status=404 + ) + + async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict, + civitai_metadata: Dict, client: CivitaiClient) -> None: + """Update local metadata with CivitAI data""" + local_metadata['civitai'] = civitai_metadata + + # Update model name if available + if 'model' in civitai_metadata: + local_metadata['model_name'] = civitai_metadata['model'].get('name', + local_metadata.get('model_name')) + + # Update base model + local_metadata['base_model'] = civitai_metadata.get('baseModel') + + # Update preview if needed + if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']): + first_preview = next((img for img in civitai_metadata.get('images', [])), None) + if first_preview: + preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1] + # Fix: Get base name without .metadata.json + base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] + preview_filename = base_name + '.preview' + preview_ext + preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename) + + if await client.download_preview_image(first_preview['url'], preview_path): + local_metadata['preview_url'] = preview_path.replace(os.sep, '/') + + # Save updated metadata + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(local_metadata, f, indent=2, ensure_ascii=False) diff --git a/routes/lora_routes.py b/routes/lora_routes.py new file mode 100644 index 00000000..cd2158e4 --- /dev/null +++ b/routes/lora_routes.py @@ -0,0 +1,81 @@ +import os +from aiohttp import web +import jinja2 +from typing import Dict, List +import logging +from ..services.lora_scanner import LoraScanner +from ..config import config + +logger = logging.getLogger(__name__) + +class LoraRoutes: + """Route handlers for LoRA management endpoints""" + + def __init__(self): + self.scanner = LoraScanner() + self.template_env = jinja2.Environment( + loader=jinja2.FileSystemLoader(config.templates_path), + autoescape=True + ) + + def format_lora_data(self, lora: Dict) -> Dict: + """Format LoRA data for template rendering""" + return { + "model_name": lora["model_name"], + "file_name": lora["file_name"], + "preview_url": config.get_preview_static_url(lora["preview_url"]), + "base_model": lora["base_model"], + "folder": lora["folder"], + "sha256": lora["sha256"], + "file_path": lora["file_path"].replace(os.sep, "/"), + "modified": lora["modified"], + "from_civitai": lora.get("from_civitai", True), + "civitai": self._filter_civitai_data(lora.get("civitai", {})) + } + + def _filter_civitai_data(self, data: Dict) -> Dict: + """Filter relevant fields from CivitAI data""" + if not data: + return {} + + fields = [ + "id", "modelId", "name", "createdAt", "updatedAt", + "publishedAt", "trainedWords", "baseModel", "description", + "model", "images" + ] + return {k: data[k] for k in fields if k in data} + + async def handle_loras_page(self, request: web.Request) -> web.Response: + """Handle GET /loras request""" + try: + # Scan for loras + loras = await self.scanner.scan_all_loras() + + # Format data for template + formatted_loras = [self.format_lora_data(l) for l in loras] + folders = sorted(list(set(l['folder'] for l in loras))) + + # Render template + template = self.template_env.get_template('loras.html') + rendered = template.render( + loras=formatted_loras, + folders=folders + ) + + return web.Response( + text=rendered, + content_type='text/html' + ) + + except Exception as e: + logger.error(f"Error handling loras request: {e}", exc_info=True) + return web.Response( + text="Error loading loras page", + status=500 + ) + + @classmethod + def setup_routes(cls, app: web.Application): + """Register routes with the application""" + routes = cls() + app.router.add_get('/loras', routes.handle_loras_page) diff --git a/services/lora_scanner.py b/services/lora_scanner.py new file mode 100644 index 00000000..a4575743 --- /dev/null +++ b/services/lora_scanner.py @@ -0,0 +1,60 @@ +import os +import logging +from typing import List, Dict +from ..config import config +from ..utils.file_utils import load_metadata, get_file_info, save_metadata +from ..utils.lora_metadata import extract_lora_metadata + +logger = logging.getLogger(__name__) + +class LoraScanner: + """Service for scanning and managing LoRA files""" + + async def scan_all_loras(self) -> List[Dict]: + """Scan all LoRA directories and return metadata""" + all_loras = [] + + for loras_root in config.loras_roots: + try: + loras = await self._scan_directory(loras_root) + all_loras.extend(loras) + except Exception as e: + logger.error(f"Error scanning directory {loras_root}: {e}") + + return all_loras + + async def _scan_directory(self, root_path: str) -> List[Dict]: + """Scan a single directory for LoRA files""" + loras = [] + + for root, _, files in os.walk(root_path): + for filename in (f for f in files if f.endswith('.safetensors')): + try: + file_path = os.path.join(root, filename).replace(os.sep, "/") + lora_data = await self._process_lora_file(file_path, root_path) + if lora_data: + loras.append(lora_data) + except Exception as e: + logger.error(f"Error processing {filename}: {e}") + + return loras + + async def _process_lora_file(self, file_path: str, root_path: str) -> Dict: + """Process a single LoRA file and return its metadata""" + # Try loading existing metadata + metadata = await load_metadata(file_path) + + if metadata is None: + # Create new metadata if none exists + metadata = await get_file_info(file_path) + base_model_info = await extract_lora_metadata(file_path) + metadata.base_model = base_model_info['base_model'] + await save_metadata(file_path, metadata) + + # Convert to dict and add folder info + lora_data = metadata.to_dict() + rel_path = os.path.relpath(file_path, root_path) + folder = os.path.dirname(rel_path) + lora_data['folder'] = folder.replace(os.path.sep, '/') + + return lora_data diff --git a/utils/file_utils.py b/utils/file_utils.py index c0ef5906..e2cc9380 100644 --- a/utils/file_utils.py +++ b/utils/file_utils.py @@ -28,7 +28,7 @@ def _find_preview_file(base_name: str, dir_path: str) -> str: for pattern in preview_patterns: full_pattern = os.path.join(dir_path, pattern) if os.path.exists(full_pattern): - return full_pattern.replace("\\", "/") + return full_pattern.replace(os.sep, "/") return "" async def get_file_info(file_path: str) -> LoraMetadata: