From a8d21fb1d652453e7de19880ce62c7a01c978bfd Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sat, 12 Apr 2025 07:49:11 +0800 Subject: [PATCH] refactor: Remove unused service imports and add new route for scanning LoRA files --- py/routes/api_routes.py | 16 ++-- py/routes/checkpoints_routes.py | 3 - py/routes/lora_routes.py | 4 +- py/routes/recipe_routes.py | 3 - py/services/model_scanner.py | 136 +++++++++++++++++++++++++++++++- 5 files changed, 146 insertions(+), 16 deletions(-) diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index b7506bb8..00b74bd2 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -6,17 +6,11 @@ from typing import Dict from ..utils.routes_common import ModelRouteUtils -from ..services.file_monitor import LoraFileMonitor -from ..services.download_manager import DownloadManager -from ..services.civitai_client import CivitaiClient from ..config import config -from ..services.lora_scanner import LoraScanner -from operator import itemgetter from ..services.websocket_manager import ws_manager from ..services.settings_manager import settings import asyncio from .update_routes import UpdateRoutes -from ..services.recipe_scanner import RecipeScanner from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH from ..utils.exif_utils import ExifUtils from ..services.service_registry import ServiceRegistry @@ -68,6 +62,7 @@ class ApiRoutes: app.router.add_get('/api/loras/base-models', routes.get_base_models) # Add new route for base models app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url) # Add new route for Civitai URL app.router.add_post('/api/rename_lora', routes.rename_lora) # Add new route for renaming LoRA files + app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files # Add update check routes UpdateRoutes.setup_routes(app) @@ -89,6 +84,15 @@ class ApiRoutes: if self.scanner is None: self.scanner = await ServiceRegistry.get_lora_scanner() return await ModelRouteUtils.handle_replace_preview(request, self.scanner) + + async def scan_loras(self, request: web.Request) -> web.Response: + """Force a rescan of LoRA files""" + try: + await self.scanner.get_cached_data(force_refresh=True) + return web.json_response({"status": "success", "message": "LoRA scan completed"}) + except Exception as e: + logger.error(f"Error in scan_loras: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) async def get_loras(self, request: web.Request) -> web.Response: """Handle paginated LoRA data request""" diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index e50ff042..53c94a84 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -7,10 +7,7 @@ import asyncio from ..utils.routes_common import ModelRouteUtils from ..utils.constants import NSFW_LEVELS -from ..services.civitai_client import CivitaiClient from ..services.websocket_manager import ws_manager -from ..services.checkpoint_scanner import CheckpointScanner -from ..services.download_manager import DownloadManager from ..services.service_registry import ServiceRegistry from ..config import config from ..services.settings_manager import settings diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 91e1d2eb..4db5bddb 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -1,10 +1,8 @@ import os from aiohttp import web import jinja2 -from typing import Dict, List +from typing import Dict import logging -from ..services.lora_scanner import LoraScanner -from ..services.recipe_scanner import RecipeScanner from ..config import config from ..services.settings_manager import settings from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 76dd7467..78ba27cb 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -8,11 +8,8 @@ import json import asyncio from ..utils.exif_utils import ExifUtils from ..utils.recipe_parsers import RecipeParserFactory -from ..services.civitai_client import CivitaiClient from ..utils.constants import CARD_PREVIEW_WIDTH -from ..services.recipe_scanner import RecipeScanner -from ..services.lora_scanner import LoraScanner from ..config import config from ..workflow.parser import WorkflowParser from ..utils.utils import download_civitai_image diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index edba07f6..0ada59b9 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -125,7 +125,12 @@ class ModelScanner: # If force refresh is requested, initialize the cache directly if force_refresh: - await self._initialize_cache() + if self._cache is None: + # For initial creation, do a full initialization + await self._initialize_cache() + else: + # For subsequent refreshes, use fast reconciliation + await self._reconcile_cache() return self._cache @@ -173,6 +178,135 @@ class ModelScanner: folders=[] ) + async def _reconcile_cache(self) -> None: + """Fast cache reconciliation - only process differences between cache and filesystem""" + try: + start_time = time.time() + logger.info(f"{self.model_type.capitalize()} Scanner: Starting fast cache reconciliation...") + + # Get current cached file paths + cached_paths = {item['file_path'] for item in self._cache.raw_data} + path_to_item = {item['file_path']: item for item in self._cache.raw_data} + + # Track found files and new files + found_paths = set() + new_files = [] + + # Scan all model roots + for root_path in self.get_model_roots(): + if not os.path.exists(root_path): + continue + + # Track visited real paths to avoid symlink loops + visited_real_paths = set() + + # Recursively scan directory + for root, _, files in os.walk(root_path, followlinks=True): + real_root = os.path.realpath(root) + if real_root in visited_real_paths: + continue + visited_real_paths.add(real_root) + + for file in files: + ext = os.path.splitext(file)[1].lower() + if ext in self.file_extensions: + # Construct paths exactly as they would be in cache + file_path = os.path.join(root, file).replace(os.sep, '/') + + # Check if this file is already in cache + if file_path in cached_paths: + found_paths.add(file_path) + continue + + # Try case-insensitive match on Windows + if os.name == 'nt': + lower_path = file_path.lower() + matched = False + for cached_path in cached_paths: + if cached_path.lower() == lower_path: + found_paths.add(cached_path) + matched = True + break + if matched: + continue + + # This is a new file to process + new_files.append(file_path) + + # Yield control periodically + await asyncio.sleep(0) + + # Process new files in batches + total_added = 0 + if new_files: + logger.info(f"{self.model_type.capitalize()} Scanner: Found {len(new_files)} new files to process") + batch_size = 50 + for i in range(0, len(new_files), batch_size): + batch = new_files[i:i+batch_size] + for path in batch: + try: + model_data = await self.scan_single_model(path) + if model_data: + # Add to cache + self._cache.raw_data.append(model_data) + + # Update hash index if available + if 'sha256' in model_data and 'file_path' in model_data: + self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path']) + + # Update tags count + if 'tags' in model_data and model_data['tags']: + for tag in model_data['tags']: + self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 + + total_added += 1 + except Exception as e: + logger.error(f"Error adding {path} to cache: {e}") + + # Yield control after each batch + await asyncio.sleep(0) + + # Find missing files (in cache but not in filesystem) + missing_files = cached_paths - found_paths + total_removed = 0 + + if missing_files: + logger.info(f"{self.model_type.capitalize()} Scanner: Found {len(missing_files)} files to remove from cache") + + # Process files to remove + for path in missing_files: + try: + model_to_remove = path_to_item[path] + + # Update tags count + for tag in model_to_remove.get('tags', []): + if tag in self._tags_count: + self._tags_count[tag] = max(0, self._tags_count[tag] - 1) + if self._tags_count[tag] == 0: + del self._tags_count[tag] + + # Remove from hash index + self._hash_index.remove_by_path(path) + total_removed += 1 + except Exception as e: + logger.error(f"Error removing {path} from cache: {e}") + + # Update cache data + self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in missing_files] + + # Resort cache if changes were made + if total_added > 0 or total_removed > 0: + # Update folders list + all_folders = set(item.get('folder', '') for item in self._cache.raw_data) + self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) + + # Resort cache + await self._cache.resort() + + logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.") + except Exception as e: + logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True) + # These methods should be implemented in child classes async def scan_all_models(self) -> List[Dict]: """Scan all model directories and return metadata"""