refactor: Remove unused service imports and add new route for scanning LoRA files

This commit is contained in:
Will Miao
2025-04-12 07:49:11 +08:00
parent 9277d8d8f8
commit a8d21fb1d6
5 changed files with 146 additions and 16 deletions

View File

@@ -6,17 +6,11 @@ from typing import Dict
from ..utils.routes_common import ModelRouteUtils from ..utils.routes_common import ModelRouteUtils
from ..services.file_monitor import LoraFileMonitor
from ..services.download_manager import DownloadManager
from ..services.civitai_client import CivitaiClient
from ..config import config from ..config import config
from ..services.lora_scanner import LoraScanner
from operator import itemgetter
from ..services.websocket_manager import ws_manager from ..services.websocket_manager import ws_manager
from ..services.settings_manager import settings from ..services.settings_manager import settings
import asyncio import asyncio
from .update_routes import UpdateRoutes from .update_routes import UpdateRoutes
from ..services.recipe_scanner import RecipeScanner
from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
@@ -68,6 +62,7 @@ class ApiRoutes:
app.router.add_get('/api/loras/base-models', routes.get_base_models) # Add new route for base models app.router.add_get('/api/loras/base-models', routes.get_base_models) # Add new route for base models
app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url) # Add new route for Civitai URL app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url) # Add new route for Civitai URL
app.router.add_post('/api/rename_lora', routes.rename_lora) # Add new route for renaming LoRA files app.router.add_post('/api/rename_lora', routes.rename_lora) # Add new route for renaming LoRA files
app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files
# Add update check routes # Add update check routes
UpdateRoutes.setup_routes(app) UpdateRoutes.setup_routes(app)
@@ -89,6 +84,15 @@ class ApiRoutes:
if self.scanner is None: if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner() self.scanner = await ServiceRegistry.get_lora_scanner()
return await ModelRouteUtils.handle_replace_preview(request, self.scanner) return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
async def scan_loras(self, request: web.Request) -> web.Response:
"""Force a rescan of LoRA files"""
try:
await self.scanner.get_cached_data(force_refresh=True)
return web.json_response({"status": "success", "message": "LoRA scan completed"})
except Exception as e:
logger.error(f"Error in scan_loras: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_loras(self, request: web.Request) -> web.Response: async def get_loras(self, request: web.Request) -> web.Response:
"""Handle paginated LoRA data request""" """Handle paginated LoRA data request"""

View File

@@ -7,10 +7,7 @@ import asyncio
from ..utils.routes_common import ModelRouteUtils from ..utils.routes_common import ModelRouteUtils
from ..utils.constants import NSFW_LEVELS from ..utils.constants import NSFW_LEVELS
from ..services.civitai_client import CivitaiClient
from ..services.websocket_manager import ws_manager from ..services.websocket_manager import ws_manager
from ..services.checkpoint_scanner import CheckpointScanner
from ..services.download_manager import DownloadManager
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..config import config from ..config import config
from ..services.settings_manager import settings from ..services.settings_manager import settings

View File

@@ -1,10 +1,8 @@
import os import os
from aiohttp import web from aiohttp import web
import jinja2 import jinja2
from typing import Dict, List from typing import Dict
import logging import logging
from ..services.lora_scanner import LoraScanner
from ..services.recipe_scanner import RecipeScanner
from ..config import config from ..config import config
from ..services.settings_manager import settings from ..services.settings_manager import settings
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import

View File

@@ -8,11 +8,8 @@ import json
import asyncio import asyncio
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
from ..utils.recipe_parsers import RecipeParserFactory from ..utils.recipe_parsers import RecipeParserFactory
from ..services.civitai_client import CivitaiClient
from ..utils.constants import CARD_PREVIEW_WIDTH from ..utils.constants import CARD_PREVIEW_WIDTH
from ..services.recipe_scanner import RecipeScanner
from ..services.lora_scanner import LoraScanner
from ..config import config from ..config import config
from ..workflow.parser import WorkflowParser from ..workflow.parser import WorkflowParser
from ..utils.utils import download_civitai_image from ..utils.utils import download_civitai_image

View File

@@ -125,7 +125,12 @@ class ModelScanner:
# If force refresh is requested, initialize the cache directly # If force refresh is requested, initialize the cache directly
if force_refresh: if force_refresh:
await self._initialize_cache() if self._cache is None:
# For initial creation, do a full initialization
await self._initialize_cache()
else:
# For subsequent refreshes, use fast reconciliation
await self._reconcile_cache()
return self._cache return self._cache
@@ -173,6 +178,135 @@ class ModelScanner:
folders=[] folders=[]
) )
async def _reconcile_cache(self) -> None:
"""Fast cache reconciliation - only process differences between cache and filesystem"""
try:
start_time = time.time()
logger.info(f"{self.model_type.capitalize()} Scanner: Starting fast cache reconciliation...")
# Get current cached file paths
cached_paths = {item['file_path'] for item in self._cache.raw_data}
path_to_item = {item['file_path']: item for item in self._cache.raw_data}
# Track found files and new files
found_paths = set()
new_files = []
# Scan all model roots
for root_path in self.get_model_roots():
if not os.path.exists(root_path):
continue
# Track visited real paths to avoid symlink loops
visited_real_paths = set()
# Recursively scan directory
for root, _, files in os.walk(root_path, followlinks=True):
real_root = os.path.realpath(root)
if real_root in visited_real_paths:
continue
visited_real_paths.add(real_root)
for file in files:
ext = os.path.splitext(file)[1].lower()
if ext in self.file_extensions:
# Construct paths exactly as they would be in cache
file_path = os.path.join(root, file).replace(os.sep, '/')
# Check if this file is already in cache
if file_path in cached_paths:
found_paths.add(file_path)
continue
# Try case-insensitive match on Windows
if os.name == 'nt':
lower_path = file_path.lower()
matched = False
for cached_path in cached_paths:
if cached_path.lower() == lower_path:
found_paths.add(cached_path)
matched = True
break
if matched:
continue
# This is a new file to process
new_files.append(file_path)
# Yield control periodically
await asyncio.sleep(0)
# Process new files in batches
total_added = 0
if new_files:
logger.info(f"{self.model_type.capitalize()} Scanner: Found {len(new_files)} new files to process")
batch_size = 50
for i in range(0, len(new_files), batch_size):
batch = new_files[i:i+batch_size]
for path in batch:
try:
model_data = await self.scan_single_model(path)
if model_data:
# Add to cache
self._cache.raw_data.append(model_data)
# Update hash index if available
if 'sha256' in model_data and 'file_path' in model_data:
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
# Update tags count
if 'tags' in model_data and model_data['tags']:
for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
total_added += 1
except Exception as e:
logger.error(f"Error adding {path} to cache: {e}")
# Yield control after each batch
await asyncio.sleep(0)
# Find missing files (in cache but not in filesystem)
missing_files = cached_paths - found_paths
total_removed = 0
if missing_files:
logger.info(f"{self.model_type.capitalize()} Scanner: Found {len(missing_files)} files to remove from cache")
# Process files to remove
for path in missing_files:
try:
model_to_remove = path_to_item[path]
# Update tags count
for tag in model_to_remove.get('tags', []):
if tag in self._tags_count:
self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
if self._tags_count[tag] == 0:
del self._tags_count[tag]
# Remove from hash index
self._hash_index.remove_by_path(path)
total_removed += 1
except Exception as e:
logger.error(f"Error removing {path} from cache: {e}")
# Update cache data
self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in missing_files]
# Resort cache if changes were made
if total_added > 0 or total_removed > 0:
# Update folders list
all_folders = set(item.get('folder', '') for item in self._cache.raw_data)
self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Resort cache
await self._cache.resort()
logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.")
except Exception as e:
logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True)
# These methods should be implemented in child classes # These methods should be implemented in child classes
async def scan_all_models(self) -> List[Dict]: async def scan_all_models(self) -> List[Dict]:
"""Scan all model directories and return metadata""" """Scan all model directories and return metadata"""