mirror of https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
refactor: Remove unused service imports and add new route for scanning LoRA files
@@ -6,17 +6,11 @@ from typing import Dict
from ..utils.routes_common import ModelRouteUtils
from ..services.file_monitor import LoraFileMonitor
from ..services.download_manager import DownloadManager
from ..services.civitai_client import CivitaiClient
from ..config import config
from ..services.lora_scanner import LoraScanner
from operator import itemgetter
from ..services.websocket_manager import ws_manager
from ..services.settings_manager import settings
import asyncio
from .update_routes import UpdateRoutes
from ..services.recipe_scanner import RecipeScanner
from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
from ..utils.exif_utils import ExifUtils
from ..services.service_registry import ServiceRegistry
@@ -68,6 +62,7 @@ class ApiRoutes:
        app.router.add_get('/api/loras/base-models', routes.get_base_models)  # Add new route for base models
        app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url)  # Add new route for Civitai URL
        app.router.add_post('/api/rename_lora', routes.rename_lora)  # Add new route for renaming LoRA files
        app.router.add_get('/api/loras/scan', routes.scan_loras)  # Add new route for scanning LoRA files

        # Add update check routes
        UpdateRoutes.setup_routes(app)
@@ -89,6 +84,15 @@ class ApiRoutes:
        if self.scanner is None:
            self.scanner = await ServiceRegistry.get_lora_scanner()
        return await ModelRouteUtils.handle_replace_preview(request, self.scanner)

    async def scan_loras(self, request: web.Request) -> web.Response:
        """Force a rescan of LoRA files"""
        try:
            # Resolve the scanner lazily, matching the other handlers, so a
            # scan requested before first use does not hit a None scanner.
            if self.scanner is None:
                self.scanner = await ServiceRegistry.get_lora_scanner()
            await self.scanner.get_cached_data(force_refresh=True)
            return web.json_response({"status": "success", "message": "LoRA scan completed"})
        except Exception as e:
            logger.error(f"Error in scan_loras: {e}", exc_info=True)
            return web.json_response({"error": str(e)}, status=500)

    async def get_loras(self, request: web.Request) -> web.Response:
        """Handle paginated LoRA data request"""
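For reference, the new endpoint can be exercised with a plain GET request. A minimal client sketch (the host and port are assumptions for a default local ComfyUI instance; adjust to wherever your server listens):

    import asyncio
    import aiohttp

    async def trigger_lora_scan():
        # Assumed local address; LoRA Manager mounts its routes on the ComfyUI server.
        url = "http://127.0.0.1:8188/api/loras/scan"
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as resp:
                print(resp.status, await resp.json())

    asyncio.run(trigger_lora_scan())

On success the handler above returns {"status": "success", "message": "LoRA scan completed"}; on failure it returns the error text with status 500.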
@@ -7,10 +7,7 @@ import asyncio
from ..utils.routes_common import ModelRouteUtils
from ..utils.constants import NSFW_LEVELS
from ..services.civitai_client import CivitaiClient
from ..services.websocket_manager import ws_manager
from ..services.checkpoint_scanner import CheckpointScanner
from ..services.download_manager import DownloadManager
from ..services.service_registry import ServiceRegistry
from ..config import config
from ..services.settings_manager import settings
@@ -1,10 +1,8 @@
import os
from aiohttp import web
import jinja2
from typing import Dict, List
from typing import Dict
import logging
from ..services.lora_scanner import LoraScanner
from ..services.recipe_scanner import RecipeScanner
from ..config import config
from ..services.settings_manager import settings
from ..services.service_registry import ServiceRegistry  # Add ServiceRegistry import
@@ -8,11 +8,8 @@ import json
import asyncio
from ..utils.exif_utils import ExifUtils
from ..utils.recipe_parsers import RecipeParserFactory
from ..services.civitai_client import CivitaiClient
from ..utils.constants import CARD_PREVIEW_WIDTH

from ..services.recipe_scanner import RecipeScanner
from ..services.lora_scanner import LoraScanner
from ..config import config
from ..workflow.parser import WorkflowParser
from ..utils.utils import download_civitai_image
@@ -125,7 +125,12 @@ class ModelScanner:

         # If force refresh is requested, initialize the cache directly
         if force_refresh:
-            await self._initialize_cache()
+            if self._cache is None:
+                # For initial creation, do a full initialization
+                await self._initialize_cache()
+            else:
+                # For subsequent refreshes, use fast reconciliation
+                await self._reconcile_cache()

         return self._cache
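The net effect of the new branch: the first forced refresh pays the full initialization cost, and every later one only reconciles differences. As a usage sketch (the scanner variable is a placeholder for however the instance is obtained, e.g. via ServiceRegistry):

    # First forced refresh: no cache yet, so a full initialization runs.
    cache = await scanner.get_cached_data(force_refresh=True)

    # Subsequent forced refreshes: the cache exists, so only the
    # cache/filesystem differences are reconciled (much faster).
    cache = await scanner.get_cached_data(force_refresh=True)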
@@ -173,6 +178,135 @@ class ModelScanner:
            folders=[]
        )

    async def _reconcile_cache(self) -> None:
        """Fast cache reconciliation - only process differences between cache and filesystem"""
        try:
            start_time = time.time()
            logger.info(f"{self.model_type.capitalize()} Scanner: Starting fast cache reconciliation...")

            # Get current cached file paths
            cached_paths = {item['file_path'] for item in self._cache.raw_data}
            path_to_item = {item['file_path']: item for item in self._cache.raw_data}

            # Track found files and new files
            found_paths = set()
            new_files = []

            # Scan all model roots
            for root_path in self.get_model_roots():
                if not os.path.exists(root_path):
                    continue

                # Track visited real paths to avoid symlink loops
                visited_real_paths = set()

                # Recursively scan directory
                for root, _, files in os.walk(root_path, followlinks=True):
                    real_root = os.path.realpath(root)
                    if real_root in visited_real_paths:
                        continue
                    visited_real_paths.add(real_root)

                    for file in files:
                        ext = os.path.splitext(file)[1].lower()
                        if ext in self.file_extensions:
                            # Construct paths exactly as they would be in cache
                            file_path = os.path.join(root, file).replace(os.sep, '/')

                            # Check if this file is already in cache
                            if file_path in cached_paths:
                                found_paths.add(file_path)
                                continue

                            # Try case-insensitive match on Windows
                            if os.name == 'nt':
                                lower_path = file_path.lower()
                                matched = False
                                for cached_path in cached_paths:
                                    if cached_path.lower() == lower_path:
                                        found_paths.add(cached_path)
                                        matched = True
                                        break
                                if matched:
                                    continue

                            # This is a new file to process
                            new_files.append(file_path)

                    # Yield control periodically
                    await asyncio.sleep(0)

            # Process new files in batches
            total_added = 0
            if new_files:
                logger.info(f"{self.model_type.capitalize()} Scanner: Found {len(new_files)} new files to process")
                batch_size = 50
                for i in range(0, len(new_files), batch_size):
                    batch = new_files[i:i+batch_size]
                    for path in batch:
                        try:
                            model_data = await self.scan_single_model(path)
                            if model_data:
                                # Add to cache
                                self._cache.raw_data.append(model_data)

                                # Update hash index if available
                                if 'sha256' in model_data and 'file_path' in model_data:
                                    self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])

                                # Update tags count
                                if 'tags' in model_data and model_data['tags']:
                                    for tag in model_data['tags']:
                                        self._tags_count[tag] = self._tags_count.get(tag, 0) + 1

                                total_added += 1
                        except Exception as e:
                            logger.error(f"Error adding {path} to cache: {e}")

                    # Yield control after each batch
                    await asyncio.sleep(0)

            # Find missing files (in cache but not in filesystem)
            missing_files = cached_paths - found_paths
            total_removed = 0

            if missing_files:
                logger.info(f"{self.model_type.capitalize()} Scanner: Found {len(missing_files)} files to remove from cache")

                # Process files to remove
                for path in missing_files:
                    try:
                        model_to_remove = path_to_item[path]

                        # Update tags count
                        for tag in model_to_remove.get('tags', []):
                            if tag in self._tags_count:
                                self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
                                if self._tags_count[tag] == 0:
                                    del self._tags_count[tag]

                        # Remove from hash index
                        self._hash_index.remove_by_path(path)
                        total_removed += 1
                    except Exception as e:
                        logger.error(f"Error removing {path} from cache: {e}")

                # Update cache data
                self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in missing_files]

            # Resort cache if changes were made
            if total_added > 0 or total_removed > 0:
                # Update folders list
                all_folders = set(item.get('folder', '') for item in self._cache.raw_data)
                self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())

                # Resort cache
                await self._cache.resort()

            logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.")
        except Exception as e:
            logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True)
    # These methods should be implemented in child classes
    async def scan_all_models(self) -> List[Dict]:
        """Scan all model directories and return metadata"""
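For orientation, the subclass contract sketched as standalone code (BaseScanner and DemoLoraScanner are simplified stand-ins for illustration, not the repository's actual classes):

    from typing import Dict, List

    class BaseScanner:
        # Stand-in for ModelScanner's template-method surface.
        async def scan_all_models(self) -> List[Dict]:
            raise NotImplementedError

    class DemoLoraScanner(BaseScanner):
        async def scan_all_models(self) -> List[Dict]:
            # A real child walks its model roots; this stub returns one record.
            return [{"file_path": "loras/example.safetensors"}]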