mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
checkpoint
This commit is contained in:
@@ -5,10 +5,7 @@ from .routes.lora_routes import LoraRoutes
|
||||
from .routes.api_routes import ApiRoutes
|
||||
from .routes.recipe_routes import RecipeRoutes
|
||||
from .routes.checkpoints_routes import CheckpointsRoutes
|
||||
from .services.lora_scanner import LoraScanner
|
||||
from .services.checkpoint_scanner import CheckpointScanner
|
||||
from .services.recipe_scanner import RecipeScanner
|
||||
from .services.file_monitor import LoraFileMonitor, CheckpointFileMonitor
|
||||
from .services.service_registry import ServiceRegistry
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -41,8 +38,10 @@ class LoraManager:
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Get checkpoint scanner instance
|
||||
checkpoint_scanner = asyncio.run(ServiceRegistry.get_checkpoint_scanner())
|
||||
|
||||
# Add static routes for each checkpoint root
|
||||
checkpoint_scanner = CheckpointScanner()
|
||||
for idx, root in enumerate(checkpoint_scanner.get_model_roots(), start=1):
|
||||
preview_path = f'/checkpoints_static/root{idx}/preview'
|
||||
|
||||
@@ -79,51 +78,89 @@ class LoraManager:
|
||||
lora_routes = LoraRoutes()
|
||||
checkpoints_routes = CheckpointsRoutes()
|
||||
|
||||
# Setup file monitoring
|
||||
lora_monitor = LoraFileMonitor(lora_routes.scanner, config.loras_roots)
|
||||
lora_monitor.start()
|
||||
|
||||
checkpoint_monitor = CheckpointFileMonitor(checkpoints_routes.scanner, checkpoints_routes.scanner.get_model_roots())
|
||||
checkpoint_monitor.start()
|
||||
|
||||
# Initialize routes
|
||||
lora_routes.setup_routes(app)
|
||||
checkpoints_routes.setup_routes(app)
|
||||
ApiRoutes.setup_routes(app, lora_monitor)
|
||||
ApiRoutes.setup_routes(app)
|
||||
RecipeRoutes.setup_routes(app)
|
||||
|
||||
# Store monitors in app for cleanup
|
||||
app['lora_monitor'] = lora_monitor
|
||||
app['checkpoint_monitor'] = checkpoint_monitor
|
||||
|
||||
logger.info("PromptServer app: ", app)
|
||||
|
||||
# Schedule cache initialization using the application's startup handler
|
||||
app.on_startup.append(lambda app: cls._schedule_cache_init(
|
||||
lora_routes.scanner,
|
||||
checkpoints_routes.scanner,
|
||||
lora_routes.recipe_scanner
|
||||
))
|
||||
# Schedule service initialization
|
||||
app.on_startup.append(lambda app: cls._initialize_services())
|
||||
|
||||
# Add cleanup
|
||||
app.on_shutdown.append(cls._cleanup)
|
||||
app.on_shutdown.append(ApiRoutes.cleanup)
|
||||
|
||||
@classmethod
|
||||
async def _schedule_cache_init(cls, lora_scanner, checkpoint_scanner, recipe_scanner):
|
||||
"""Schedule cache initialization in the running event loop"""
|
||||
async def _initialize_services(cls):
|
||||
"""Initialize all services using the ServiceRegistry"""
|
||||
try:
|
||||
logger.info("LoRA Manager: Initializing services via ServiceRegistry")
|
||||
|
||||
# Initialize CivitaiClient first to ensure it's ready for other services
|
||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
logger.info("CivitaiClient registered in ServiceRegistry")
|
||||
|
||||
# Get file monitors through ServiceRegistry
|
||||
lora_monitor = await ServiceRegistry.get_lora_monitor()
|
||||
checkpoint_monitor = await ServiceRegistry.get_checkpoint_monitor()
|
||||
|
||||
# Start monitors
|
||||
lora_monitor.start()
|
||||
logger.info("Lora monitor started")
|
||||
|
||||
# Make sure checkpoint monitor has paths before starting
|
||||
await checkpoint_monitor.initialize_paths()
|
||||
checkpoint_monitor.start()
|
||||
logger.info("Checkpoint monitor started")
|
||||
|
||||
# Register DownloadManager with ServiceRegistry
|
||||
download_manager = await ServiceRegistry.get_download_manager()
|
||||
logger.info("DownloadManager registered in ServiceRegistry")
|
||||
|
||||
# Initialize WebSocket manager
|
||||
ws_manager = await ServiceRegistry.get_websocket_manager()
|
||||
logger.info("WebSocketManager registered in ServiceRegistry")
|
||||
|
||||
# Initialize scanners in background
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# Initialize recipe scanner if needed
|
||||
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
|
||||
# Create low-priority initialization tasks
|
||||
lora_task = asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
|
||||
checkpoint_task = asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
|
||||
recipe_task = asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
|
||||
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
|
||||
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
|
||||
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
|
||||
|
||||
logger.info("LoRA Manager: All services initialized and background tasks scheduled")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}")
|
||||
logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
async def _cleanup(cls, app):
|
||||
"""Cleanup resources"""
|
||||
if 'lora_monitor' in app:
|
||||
app['lora_monitor'].stop()
|
||||
"""Cleanup resources using ServiceRegistry"""
|
||||
try:
|
||||
logger.info("LoRA Manager: Cleaning up services")
|
||||
|
||||
if 'checkpoint_monitor' in app:
|
||||
app['checkpoint_monitor'].stop()
|
||||
# Get monitors from ServiceRegistry
|
||||
lora_monitor = await ServiceRegistry.get_service("lora_monitor")
|
||||
if lora_monitor:
|
||||
lora_monitor.stop()
|
||||
logger.info("Stopped LoRA monitor")
|
||||
|
||||
checkpoint_monitor = await ServiceRegistry.get_service("checkpoint_monitor")
|
||||
if checkpoint_monitor:
|
||||
checkpoint_monitor.stop()
|
||||
logger.info("Stopped checkpoint monitor")
|
||||
|
||||
# Close CivitaiClient gracefully
|
||||
civitai_client = await ServiceRegistry.get_service("civitai_client")
|
||||
if civitai_client:
|
||||
await civitai_client.close()
|
||||
logger.info("Closed CivitaiClient connection")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during cleanup: {e}", exc_info=True)
|
||||
|
||||
@@ -19,22 +19,33 @@ from .update_routes import UpdateRoutes
|
||||
from ..services.recipe_scanner import RecipeScanner
|
||||
from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ApiRoutes:
|
||||
"""API route handlers for LoRA management"""
|
||||
|
||||
def __init__(self, file_monitor: LoraFileMonitor):
|
||||
self.scanner = LoraScanner()
|
||||
self.civitai_client = CivitaiClient()
|
||||
self.download_manager = DownloadManager(file_monitor)
|
||||
def __init__(self):
|
||||
self.scanner = None # Will be initialized in setup_routes
|
||||
self.civitai_client = None # Will be initialized in setup_routes
|
||||
self.download_manager = None # Will be initialized in setup_routes
|
||||
self._download_lock = asyncio.Lock()
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
@classmethod
|
||||
def setup_routes(cls, app: web.Application, monitor: LoraFileMonitor):
|
||||
def setup_routes(cls, app: web.Application):
|
||||
"""Register API routes"""
|
||||
routes = cls(monitor)
|
||||
routes = cls()
|
||||
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: routes.initialize_services())
|
||||
|
||||
app.router.add_post('/api/delete_model', routes.delete_model)
|
||||
app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
|
||||
app.router.add_post('/api/replace_preview', routes.replace_preview)
|
||||
@@ -63,19 +74,28 @@ class ApiRoutes:
|
||||
|
||||
async def delete_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model deletion request"""
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
|
||||
|
||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata fetch request"""
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
|
||||
|
||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||
"""Handle preview image replacement request"""
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
|
||||
|
||||
async def get_loras(self, request: web.Request) -> web.Response:
|
||||
"""Handle paginated LoRA data request"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Parse query parameters
|
||||
page = int(request.query.get('page', '1'))
|
||||
page_size = int(request.query.get('page_size', '20'))
|
||||
@@ -231,6 +251,9 @@ class ApiRoutes:
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Fetch CivitAI metadata for all loras in the background"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
cache = await self.scanner.get_cached_data()
|
||||
total = len(cache.raw_data)
|
||||
processed = 0
|
||||
@@ -312,6 +335,9 @@ class ApiRoutes:
|
||||
|
||||
async def get_folders(self, request: web.Request) -> web.Response:
|
||||
"""Get all folders in the cache"""
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
cache = await self.scanner.get_cached_data()
|
||||
return web.json_response({
|
||||
'folders': cache.folders
|
||||
@@ -320,6 +346,12 @@ class ApiRoutes:
|
||||
async def get_civitai_versions(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai model with local availability info"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
if self.civitai_client is None:
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
model_id = request.match_info['model_id']
|
||||
versions = await self.civitai_client.get_model_versions(model_id)
|
||||
if not versions:
|
||||
@@ -353,9 +385,12 @@ class ApiRoutes:
|
||||
async def get_civitai_model(self, request: web.Request) -> web.Response:
|
||||
"""Get CivitAI model details by model version ID or hash"""
|
||||
try:
|
||||
model_version_id = request.match_info['modelVersionId']
|
||||
if self.civitai_client is None:
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
model_version_id = request.match_info.get('modelVersionId')
|
||||
if not model_version_id:
|
||||
hash = request.match_info['hash']
|
||||
hash = request.match_info.get('hash')
|
||||
model = await self.civitai_client.get_model_by_hash(hash)
|
||||
return web.json_response(model)
|
||||
|
||||
@@ -370,6 +405,9 @@ class ApiRoutes:
|
||||
async def download_lora(self, request: web.Request) -> web.Response:
|
||||
async with self._download_lock:
|
||||
try:
|
||||
if self.download_manager is None:
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
data = await request.json()
|
||||
|
||||
# Create progress callback
|
||||
@@ -447,6 +485,9 @@ class ApiRoutes:
|
||||
async def move_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model move request"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path') # full path of the model file, e.g. /path/to/model.safetensors
|
||||
target_path = data.get('target_path') # folder path to move the model to, e.g. /path/to/target_folder
|
||||
@@ -485,12 +526,17 @@ class ApiRoutes:
|
||||
@classmethod
|
||||
async def cleanup(cls):
|
||||
"""Add cleanup method for application shutdown"""
|
||||
if hasattr(cls, '_instance'):
|
||||
await cls._instance.civitai_client.close()
|
||||
# Now we don't need to store an instance, as services are managed by ServiceRegistry
|
||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
if civitai_client:
|
||||
await civitai_client.close()
|
||||
|
||||
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||
"""Handle saving metadata updates"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
if not file_path:
|
||||
@@ -536,6 +582,9 @@ class ApiRoutes:
|
||||
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the static preview URL for a LoRA file"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Get lora file name from query parameters
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
@@ -574,6 +623,9 @@ class ApiRoutes:
|
||||
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the Civitai URL for a LoRA file"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Get lora file name from query parameters
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
@@ -619,6 +671,9 @@ class ApiRoutes:
|
||||
async def move_models_bulk(self, request: web.Request) -> web.Response:
|
||||
"""Handle bulk model move request"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
data = await request.json()
|
||||
file_paths = data.get('file_paths', []) # list of full paths of the model files, e.g. ["/path/to/model1.safetensors", "/path/to/model2.safetensors"]
|
||||
target_path = data.get('target_path') # folder path to move the models to, e.g. "/path/to/target_folder"
|
||||
@@ -677,6 +732,9 @@ class ApiRoutes:
|
||||
async def get_lora_model_description(self, request: web.Request) -> web.Response:
|
||||
"""Get model description for a Lora model"""
|
||||
try:
|
||||
if self.civitai_client is None:
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Get parameters
|
||||
model_id = request.query.get('model_id')
|
||||
file_path = request.query.get('file_path')
|
||||
@@ -736,6 +794,9 @@ class ApiRoutes:
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
"""Handle request for top tags sorted by frequency"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
@@ -761,6 +822,9 @@ class ApiRoutes:
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
"""Get base models used in loras"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
@@ -785,6 +849,12 @@ class ApiRoutes:
|
||||
async def rename_lora(self, request: web.Request) -> web.Response:
|
||||
"""Handle renaming a LoRA file and its associated files"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
if self.download_manager is None:
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
new_file_name = data.get('new_file_name')
|
||||
@@ -891,7 +961,7 @@ class ApiRoutes:
|
||||
|
||||
# Update recipe files and cache if hash is available
|
||||
if hash_value:
|
||||
recipe_scanner = RecipeScanner(self.scanner)
|
||||
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
recipes_updated, cache_updated = await recipe_scanner.update_lora_filename_by_hash(hash_value, new_file_name)
|
||||
logger.info(f"Updated {recipes_updated} recipe files and {cache_updated} cache entries for renamed LoRA")
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from ..services.civitai_client import CivitaiClient
|
||||
from ..services.websocket_manager import ws_manager
|
||||
from ..services.checkpoint_scanner import CheckpointScanner
|
||||
from ..services.download_manager import DownloadManager
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..utils.utils import fuzzy_match
|
||||
@@ -21,16 +22,24 @@ class CheckpointsRoutes:
|
||||
"""API routes for checkpoint management"""
|
||||
|
||||
def __init__(self):
|
||||
self.scanner = CheckpointScanner()
|
||||
self.scanner = None # Will be initialized in setup_routes
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
self.download_manager = DownloadManager()
|
||||
self.download_manager = None # Will be initialized in setup_routes
|
||||
self._download_lock = asyncio.Lock()
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
def setup_routes(self, app):
|
||||
"""Register routes with the aiohttp app"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
|
||||
app.router.add_get('/api/checkpoints', self.get_checkpoints)
|
||||
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
|
||||
@@ -488,10 +497,9 @@ class CheckpointsRoutes:
|
||||
async def download_checkpoint(self, request: web.Request) -> web.Response:
|
||||
"""Handle checkpoint download request"""
|
||||
async with self._download_lock:
|
||||
# Initialize DownloadManager with the file monitor if the scanner has one
|
||||
if not hasattr(self, 'download_manager') or self.download_manager is None:
|
||||
file_monitor = getattr(self.scanner, 'file_monitor', None)
|
||||
self.download_manager = DownloadManager(file_monitor)
|
||||
# Get the download manager from service registry if not already initialized
|
||||
if self.download_manager is None:
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
# Use the common download handler with model_type="checkpoint"
|
||||
return await ModelRouteUtils.handle_download_model(
|
||||
@@ -503,6 +511,9 @@ class CheckpointsRoutes:
|
||||
async def get_checkpoint_roots(self, request):
|
||||
"""Return the checkpoint root directories"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
roots = self.scanner.get_model_roots()
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
|
||||
@@ -6,7 +6,8 @@ import logging
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..services.recipe_scanner import RecipeScanner
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings # Add this import
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
|
||||
@@ -15,13 +16,24 @@ class LoraRoutes:
|
||||
"""Route handlers for LoRA management endpoints"""
|
||||
|
||||
def __init__(self):
|
||||
self.scanner = LoraScanner()
|
||||
self.recipe_scanner = RecipeScanner(self.scanner)
|
||||
# Initialize service references as None, will be set during async init
|
||||
self.scanner = None
|
||||
self.recipe_scanner = None
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
|
||||
async def init_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
logger.info("LoraRoutes: Retrieved LoraScanner from ServiceRegistry")
|
||||
|
||||
if self.recipe_scanner is None:
|
||||
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
logger.info("LoraRoutes: Retrieved RecipeScanner from ServiceRegistry")
|
||||
|
||||
def format_lora_data(self, lora: Dict) -> Dict:
|
||||
"""Format LoRA data for template rendering"""
|
||||
return {
|
||||
@@ -58,7 +70,10 @@ class LoraRoutes:
|
||||
async def handle_loras_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras request"""
|
||||
try:
|
||||
# 检查缓存初始化状态,根据initialize_in_background的工作方式调整判断逻辑
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Check if the LoraScanner is initializing
|
||||
is_initializing = (
|
||||
self.scanner._cache is None or
|
||||
len(self.scanner._cache.raw_data) == 0 or
|
||||
@@ -66,30 +81,29 @@ class LoraRoutes:
|
||||
)
|
||||
|
||||
if is_initializing:
|
||||
# 如果正在初始化,返回一个只包含加载提示的页面
|
||||
# If still initializing, return loading page
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=[], # 空文件夹列表
|
||||
is_initializing=True, # 新增标志
|
||||
settings=settings, # Pass settings to template
|
||||
request=request # Pass the request object to the template
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
|
||||
logger.info("Loras page is initializing, returning loading page")
|
||||
else:
|
||||
# 正常流程 - 获取已经初始化好的缓存数据
|
||||
# Normal flow - get data from initialized cache
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=cache.folders,
|
||||
is_initializing=False,
|
||||
settings=settings, # Pass settings to template
|
||||
request=request # Pass the request object to the template
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading cache data: {cache_error}")
|
||||
# 如果获取缓存失败,也显示初始化页面
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=[],
|
||||
@@ -114,7 +128,10 @@ class LoraRoutes:
|
||||
async def handle_recipes_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras/recipes request"""
|
||||
try:
|
||||
# 检查缓存初始化状态,与handle_loras_page保持一致的逻辑
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Check if the RecipeScanner is initializing
|
||||
is_initializing = (
|
||||
self.recipe_scanner._cache is None or
|
||||
len(self.recipe_scanner._cache.raw_data) == 0 or
|
||||
@@ -183,5 +200,13 @@ class LoraRoutes:
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Register routes with the application"""
|
||||
# Add an app startup handler to initialize services
|
||||
app.on_startup.append(self._on_startup)
|
||||
|
||||
# Register routes
|
||||
app.router.add_get('/loras', self.handle_loras_page)
|
||||
app.router.add_get('/loras/recipes', self.handle_recipes_page)
|
||||
|
||||
async def _on_startup(self, app):
|
||||
"""Initialize services when the app starts"""
|
||||
await self.init_services()
|
||||
|
||||
@@ -16,6 +16,7 @@ from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
from ..workflow.parser import WorkflowParser
|
||||
from ..utils.utils import download_civitai_image
|
||||
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,13 +24,24 @@ class RecipeRoutes:
|
||||
"""API route handlers for Recipe management"""
|
||||
|
||||
def __init__(self):
|
||||
self.recipe_scanner = RecipeScanner(LoraScanner())
|
||||
self.civitai_client = CivitaiClient()
|
||||
# Initialize service references as None, will be set during async init
|
||||
self.recipe_scanner = None
|
||||
self.civitai_client = None
|
||||
self.parser = WorkflowParser()
|
||||
|
||||
# Pre-warm the cache
|
||||
self._init_cache_task = None
|
||||
|
||||
async def init_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
if self.recipe_scanner is None:
|
||||
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
logger.info("RecipeRoutes: Retrieved RecipeScanner from ServiceRegistry")
|
||||
|
||||
if self.civitai_client is None:
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
logger.info("RecipeRoutes: Retrieved CivitaiClient from ServiceRegistry")
|
||||
|
||||
@classmethod
|
||||
def setup_routes(cls, app: web.Application):
|
||||
"""Register API routes"""
|
||||
@@ -68,7 +80,10 @@ class RecipeRoutes:
|
||||
async def _init_cache(self, app):
|
||||
"""Initialize cache on startup"""
|
||||
try:
|
||||
# First, ensure the lora scanner is fully initialized
|
||||
# Initialize services first
|
||||
await self.init_services()
|
||||
|
||||
# Now that services are initialized, get the lora scanner
|
||||
lora_scanner = self.recipe_scanner._lora_scanner
|
||||
|
||||
# Get lora cache to ensure it's initialized
|
||||
@@ -86,6 +101,9 @@ class RecipeRoutes:
|
||||
async def get_recipes(self, request: web.Request) -> web.Response:
|
||||
"""API endpoint for getting paginated recipes"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Get query parameters with defaults
|
||||
page = int(request.query.get('page', '1'))
|
||||
page_size = int(request.query.get('page_size', '20'))
|
||||
@@ -155,6 +173,9 @@ class RecipeRoutes:
|
||||
async def get_recipe_detail(self, request: web.Request) -> web.Response:
|
||||
"""Get detailed information about a specific recipe"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
recipe_id = request.match_info['recipe_id']
|
||||
|
||||
# Use the new get_recipe_by_id method from recipe_scanner
|
||||
@@ -208,6 +229,9 @@ class RecipeRoutes:
|
||||
"""Analyze an uploaded image or URL for recipe metadata"""
|
||||
temp_path = None
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Check if request contains multipart data (image) or JSON data (url)
|
||||
content_type = request.headers.get('Content-Type', '')
|
||||
|
||||
@@ -326,6 +350,9 @@ class RecipeRoutes:
|
||||
async def save_recipe(self, request: web.Request) -> web.Response:
|
||||
"""Save a recipe to the recipes folder"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
reader = await request.multipart()
|
||||
|
||||
# Process form data
|
||||
@@ -527,6 +554,9 @@ class RecipeRoutes:
|
||||
async def delete_recipe(self, request: web.Request) -> web.Response:
|
||||
"""Delete a recipe by ID"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
recipe_id = request.match_info['recipe_id']
|
||||
|
||||
# Get recipes directory
|
||||
@@ -574,6 +604,9 @@ class RecipeRoutes:
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
"""Get top tags used in recipes"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Get limit parameter with default
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
@@ -606,6 +639,9 @@ class RecipeRoutes:
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
"""Get base models used in recipes"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Get all recipes from cache
|
||||
cache = await self.recipe_scanner.get_cached_data()
|
||||
|
||||
@@ -634,6 +670,9 @@ class RecipeRoutes:
|
||||
async def share_recipe(self, request: web.Request) -> web.Response:
|
||||
"""Process a recipe image for sharing by adding metadata to EXIF"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
recipe_id = request.match_info['recipe_id']
|
||||
|
||||
# Get all recipes from cache
|
||||
@@ -693,6 +732,9 @@ class RecipeRoutes:
|
||||
async def download_shared_recipe(self, request: web.Request) -> web.Response:
|
||||
"""Serve a processed recipe image for download"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
recipe_id = request.match_info['recipe_id']
|
||||
|
||||
# Check if we have this shared recipe
|
||||
@@ -749,6 +791,9 @@ class RecipeRoutes:
|
||||
async def save_recipe_from_widget(self, request: web.Request) -> web.Response:
|
||||
"""Save a recipe from the LoRAs widget"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
reader = await request.multipart()
|
||||
|
||||
# Process form data
|
||||
@@ -923,6 +968,9 @@ class RecipeRoutes:
|
||||
async def get_recipe_syntax(self, request: web.Request) -> web.Response:
|
||||
"""Generate recipe syntax for LoRAs in the recipe, looking up proper file names using hash_index"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
recipe_id = request.match_info['recipe_id']
|
||||
|
||||
# Get all recipes from cache
|
||||
@@ -1003,6 +1051,9 @@ class RecipeRoutes:
|
||||
async def update_recipe(self, request: web.Request) -> web.Response:
|
||||
"""Update recipe metadata (name and tags)"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
recipe_id = request.match_info['recipe_id']
|
||||
data = await request.json()
|
||||
|
||||
@@ -1030,6 +1081,9 @@ class RecipeRoutes:
|
||||
async def reconnect_lora(self, request: web.Request) -> web.Response:
|
||||
"""Reconnect a deleted LoRA in a recipe to a local LoRA file"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Parse request data
|
||||
data = await request.json()
|
||||
|
||||
@@ -1140,6 +1194,9 @@ class RecipeRoutes:
|
||||
async def get_recipes_for_lora(self, request: web.Request) -> web.Response:
|
||||
"""Get recipes that use a specific Lora"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
lora_hash = request.query.get('hash')
|
||||
|
||||
# Hash is required
|
||||
|
||||
@@ -8,6 +8,7 @@ from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex
|
||||
from .service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import aiohttp
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from email.parser import Parser
|
||||
from typing import Optional, Dict, Tuple, List
|
||||
from urllib.parse import unquote
|
||||
@@ -11,7 +12,23 @@ from ..utils.models import LoraMetadata
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CivitaiClient:
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance of CivitaiClient"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
# Check if already initialized for singleton pattern
|
||||
if hasattr(self, '_initialized'):
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
self.base_url = "https://civitai.com/api/v1"
|
||||
self.headers = {
|
||||
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from typing import Optional, Dict
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any
|
||||
from .civitai_client import CivitaiClient
|
||||
from .file_monitor import LoraFileMonitor
|
||||
from ..utils.models import LoraMetadata, CheckpointMetadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from .service_registry import ServiceRegistry
|
||||
|
||||
# Download to temporary file first
|
||||
import tempfile
|
||||
@@ -14,9 +15,46 @@ import tempfile
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DownloadManager:
|
||||
def __init__(self, file_monitor: Optional[LoraFileMonitor] = None):
|
||||
self.civitai_client = CivitaiClient()
|
||||
self.file_monitor = file_monitor
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance of DownloadManager"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
# Check if already initialized for singleton pattern
|
||||
if hasattr(self, '_initialized'):
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
self._civitai_client = None # Will be lazily initialized
|
||||
|
||||
async def _get_civitai_client(self):
|
||||
"""Lazily initialize CivitaiClient from registry"""
|
||||
if self._civitai_client is None:
|
||||
self._civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
return self._civitai_client
|
||||
|
||||
async def _get_lora_monitor(self):
|
||||
"""Get the lora file monitor from registry"""
|
||||
return await ServiceRegistry.get_lora_monitor()
|
||||
|
||||
async def _get_checkpoint_monitor(self):
|
||||
"""Get the checkpoint file monitor from registry"""
|
||||
return await ServiceRegistry.get_checkpoint_monitor()
|
||||
|
||||
async def _get_lora_scanner(self):
|
||||
"""Get the lora scanner from registry"""
|
||||
return await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
async def _get_checkpoint_scanner(self):
|
||||
"""Get the checkpoint scanner from registry"""
|
||||
return await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
async def download_from_civitai(self, download_url: str = None, model_hash: str = None,
|
||||
model_version_id: str = None, save_dir: str = None,
|
||||
@@ -43,19 +81,22 @@ class DownloadManager:
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Get civitai client
|
||||
civitai_client = await self._get_civitai_client()
|
||||
|
||||
# Get version info based on the provided identifier
|
||||
version_info = None
|
||||
|
||||
if download_url:
|
||||
# Extract version ID from download URL
|
||||
version_id = download_url.split('/')[-1]
|
||||
version_info = await self.civitai_client.get_model_version_info(version_id)
|
||||
version_info = await civitai_client.get_model_version_info(version_id)
|
||||
elif model_version_id:
|
||||
# Use model version ID directly
|
||||
version_info = await self.civitai_client.get_model_version_info(model_version_id)
|
||||
version_info = await civitai_client.get_model_version_info(model_version_id)
|
||||
elif model_hash:
|
||||
# Get model by hash
|
||||
version_info = await self.civitai_client.get_model_by_hash(model_hash)
|
||||
version_info = await civitai_client.get_model_by_hash(model_hash)
|
||||
|
||||
|
||||
if not version_info:
|
||||
@@ -95,8 +136,9 @@ class DownloadManager:
|
||||
file_size = file_info.get('sizeKB', 0) * 1024
|
||||
|
||||
# 4. Notify file monitor - use normalized path and file size
|
||||
if self.file_monitor and self.file_monitor.handler:
|
||||
self.file_monitor.handler.add_ignore_path(
|
||||
file_monitor = await self._get_lora_monitor() if model_type == "lora" else await self._get_checkpoint_monitor()
|
||||
if file_monitor and file_monitor.handler:
|
||||
file_monitor.handler.add_ignore_path(
|
||||
save_path.replace(os.sep, '/'),
|
||||
file_size
|
||||
)
|
||||
@@ -112,7 +154,7 @@ class DownloadManager:
|
||||
# 5.1 Get and update model tags and description
|
||||
model_id = version_info.get('modelId')
|
||||
if model_id:
|
||||
model_metadata, _ = await self.civitai_client.get_model_metadata(str(model_id))
|
||||
model_metadata, _ = await civitai_client.get_model_metadata(str(model_id))
|
||||
if model_metadata:
|
||||
if model_metadata.get("tags"):
|
||||
metadata.tags = model_metadata.get("tags", [])
|
||||
@@ -146,6 +188,7 @@ class DownloadManager:
|
||||
model_type: str = "lora") -> Dict:
|
||||
"""Execute the actual download process including preview images and model files"""
|
||||
try:
|
||||
civitai_client = await self._get_civitai_client()
|
||||
save_path = metadata.file_path
|
||||
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
||||
|
||||
@@ -165,7 +208,7 @@ class DownloadManager:
|
||||
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
||||
|
||||
# Download video directly
|
||||
if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
|
||||
if await civitai_client.download_preview_image(images[0]['url'], preview_path):
|
||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
@@ -176,7 +219,7 @@ class DownloadManager:
|
||||
temp_path = temp_file.name
|
||||
|
||||
# Download the original image to temp path
|
||||
if await self.civitai_client.download_preview_image(images[0]['url'], temp_path):
|
||||
if await civitai_client.download_preview_image(images[0]['url'], temp_path):
|
||||
# Optimize and convert to WebP
|
||||
preview_path = os.path.splitext(save_path)[0] + '.webp'
|
||||
|
||||
@@ -210,7 +253,7 @@ class DownloadManager:
|
||||
await progress_callback(3) # 3% progress after preview download
|
||||
|
||||
# Download model file with progress tracking
|
||||
success, result = await self.civitai_client._download_file(
|
||||
success, result = await civitai_client._download_file(
|
||||
download_url,
|
||||
save_dir,
|
||||
os.path.basename(save_path),
|
||||
@@ -232,13 +275,14 @@ class DownloadManager:
|
||||
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 6. Update cache based on model type
|
||||
if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"):
|
||||
cache = await self.file_monitor.checkpoint_scanner.get_cached_data()
|
||||
if model_type == "checkpoint":
|
||||
scanner = await self._get_checkpoint_scanner()
|
||||
logger.info(f"Updating checkpoint cache for {save_path}")
|
||||
else:
|
||||
cache = await self.file_monitor.scanner.get_cached_data()
|
||||
scanner = await self._get_lora_scanner()
|
||||
logger.info(f"Updating lora cache for {save_path}")
|
||||
|
||||
cache = await scanner.get_cached_data()
|
||||
metadata_dict = metadata.to_dict()
|
||||
metadata_dict['folder'] = relative_path
|
||||
cache.raw_data.append(metadata_dict)
|
||||
@@ -248,10 +292,7 @@ class DownloadManager:
|
||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
# Update the hash index with the new model entry
|
||||
if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"):
|
||||
self.file_monitor.checkpoint_scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
||||
else:
|
||||
self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
||||
scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
||||
|
||||
# Report 100% completion
|
||||
if progress_callback:
|
||||
|
||||
@@ -1,39 +1,39 @@
|
||||
from operator import itemgetter
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
import time
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
from typing import List, Dict, Set
|
||||
from typing import List, Dict, Set, Optional
|
||||
from threading import Lock
|
||||
|
||||
from .checkpoint_scanner import CheckpointScanner
|
||||
from .lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
from .service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraFileHandler(FileSystemEventHandler):
|
||||
"""Handler for LoRA file system events"""
|
||||
class BaseFileHandler(FileSystemEventHandler):
|
||||
"""Base handler for file system events"""
|
||||
|
||||
def __init__(self, scanner: LoraScanner, loop: asyncio.AbstractEventLoop):
|
||||
self.scanner = scanner
|
||||
self.loop = loop # 存储事件循环引用
|
||||
self.pending_changes = set() # 待处理的变更
|
||||
self.lock = Lock() # 线程安全锁
|
||||
self.update_task = None # 异步更新任务
|
||||
self._ignore_paths = set() # Add ignore paths set
|
||||
self._min_ignore_timeout = 5 # minimum timeout in seconds
|
||||
self._download_speed = 1024 * 1024 # assume 1MB/s as base speed
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
||||
self.loop = loop # Store event loop reference
|
||||
self.pending_changes = set() # Pending changes
|
||||
self.lock = Lock() # Thread-safe lock
|
||||
self.update_task = None # Async update task
|
||||
self._ignore_paths = set() # Paths to ignore
|
||||
self._min_ignore_timeout = 5 # Minimum timeout in seconds
|
||||
self._download_speed = 1024 * 1024 # Assume 1MB/s as base speed
|
||||
|
||||
# Track modified files with timestamps for debouncing
|
||||
self.modified_files: Dict[str, float] = {}
|
||||
self.debounce_timer = None
|
||||
self.debounce_delay = 3.0 # seconds to wait after last modification
|
||||
self.debounce_delay = 3.0 # Seconds to wait after last modification
|
||||
|
||||
# Track files that are already scheduled for processing
|
||||
# Track files already scheduled for processing
|
||||
self.scheduled_files: Set[str] = set()
|
||||
|
||||
# File extensions to monitor - should be overridden by subclasses
|
||||
self.file_extensions = set()
|
||||
|
||||
def _should_ignore(self, path: str) -> bool:
|
||||
"""Check if path should be ignored"""
|
||||
@@ -58,35 +58,33 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
if event.is_directory:
|
||||
return
|
||||
|
||||
# Handle safetensors files directly
|
||||
if event.src_path.endswith('.safetensors'):
|
||||
# Handle appropriate files based on extensions
|
||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
if file_ext in self.file_extensions:
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
|
||||
# We'll process this file directly and ignore subsequent modifications
|
||||
# to prevent duplicate processing
|
||||
# Process this file directly and ignore subsequent modifications
|
||||
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
||||
if normalized_path not in self.scheduled_files:
|
||||
logger.info(f"LoRA file created: {event.src_path}")
|
||||
logger.info(f"File created: {event.src_path}")
|
||||
self.scheduled_files.add(normalized_path)
|
||||
self._schedule_update('add', event.src_path)
|
||||
|
||||
# Ignore modifications for a short period after creation
|
||||
# This helps avoid duplicate processing
|
||||
self.loop.call_later(
|
||||
self.debounce_delay * 2,
|
||||
self.scheduled_files.discard,
|
||||
normalized_path
|
||||
)
|
||||
|
||||
# For browser downloads, we'll catch them when they're renamed to .safetensors
|
||||
|
||||
def on_modified(self, event):
|
||||
if event.is_directory:
|
||||
return
|
||||
|
||||
# Only process safetensors files
|
||||
if event.src_path.endswith('.safetensors'):
|
||||
# Only process files with supported extensions
|
||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
if file_ext in self.file_extensions:
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
|
||||
@@ -134,12 +132,17 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
|
||||
# Process stable files
|
||||
for file_path in files_to_process:
|
||||
logger.info(f"Processing modified LoRA file: {file_path}")
|
||||
logger.info(f"Processing modified file: {file_path}")
|
||||
self._schedule_update('add', file_path)
|
||||
|
||||
def on_deleted(self, event):
|
||||
if event.is_directory or not event.src_path.endswith('.safetensors'):
|
||||
if event.is_directory:
|
||||
return
|
||||
|
||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
if file_ext not in self.file_extensions:
|
||||
return
|
||||
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
|
||||
@@ -147,14 +150,17 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
||||
self.scheduled_files.discard(normalized_path)
|
||||
|
||||
logger.info(f"LoRA file deleted: {event.src_path}")
|
||||
logger.info(f"File deleted: {event.src_path}")
|
||||
self._schedule_update('remove', event.src_path)
|
||||
|
||||
def on_moved(self, event):
|
||||
"""Handle file move/rename events"""
|
||||
|
||||
# If destination is a safetensors file, treat it as a new file
|
||||
if event.dest_path.endswith('.safetensors'):
|
||||
src_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
dest_ext = os.path.splitext(event.dest_path)[1].lower()
|
||||
|
||||
# If destination has supported extension, treat as new file
|
||||
if dest_ext in self.file_extensions:
|
||||
if self._should_ignore(event.dest_path):
|
||||
return
|
||||
|
||||
@@ -162,7 +168,7 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
|
||||
# Only process if not already scheduled
|
||||
if normalized_path not in self.scheduled_files:
|
||||
logger.info(f"LoRA file renamed/moved to: {event.dest_path}")
|
||||
logger.info(f"File renamed/moved to: {event.dest_path}")
|
||||
self.scheduled_files.add(normalized_path)
|
||||
self._schedule_update('add', event.dest_path)
|
||||
|
||||
@@ -173,21 +179,21 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
normalized_path
|
||||
)
|
||||
|
||||
# If source was a safetensors file, treat it as deleted
|
||||
if event.src_path.endswith('.safetensors'):
|
||||
# If source was a supported file, treat it as deleted
|
||||
if src_ext in self.file_extensions:
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
|
||||
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
||||
self.scheduled_files.discard(normalized_path)
|
||||
|
||||
logger.info(f"LoRA file moved/renamed from: {event.src_path}")
|
||||
logger.info(f"File moved/renamed from: {event.src_path}")
|
||||
self._schedule_update('remove', event.src_path)
|
||||
|
||||
def _schedule_update(self, action: str, file_path: str): #file_path is a real path
|
||||
def _schedule_update(self, action: str, file_path: str):
|
||||
"""Schedule a cache update"""
|
||||
with self.lock:
|
||||
# 使用 config 中的方法映射路径
|
||||
# Use config method to map path
|
||||
mapped_path = config.map_path_to_link(file_path)
|
||||
normalized_path = mapped_path.replace(os.sep, '/')
|
||||
self.pending_changes.add((action, normalized_path))
|
||||
@@ -198,7 +204,20 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
"""Create update task in the event loop"""
|
||||
if self.update_task is None or self.update_task.done():
|
||||
self.update_task = asyncio.create_task(self._process_changes())
|
||||
|
||||
async def _process_changes(self, delay: float = 2.0):
|
||||
"""Process pending changes with debouncing - should be implemented by subclasses"""
|
||||
raise NotImplementedError("Subclasses must implement _process_changes")
|
||||
|
||||
|
||||
class LoraFileHandler(BaseFileHandler):
|
||||
"""Handler for LoRA file system events"""
|
||||
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
||||
super().__init__(loop)
|
||||
# Set supported file extensions for LoRAs
|
||||
self.file_extensions = {'.safetensors'}
|
||||
|
||||
async def _process_changes(self, delay: float = 2.0):
|
||||
"""Process pending changes with debouncing"""
|
||||
await asyncio.sleep(delay)
|
||||
@@ -211,9 +230,11 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
if not changes:
|
||||
return
|
||||
|
||||
logger.info(f"Processing {len(changes)} file changes")
|
||||
logger.info(f"Processing {len(changes)} LoRA file changes")
|
||||
|
||||
cache = await self.scanner.get_cached_data()
|
||||
# Get scanner through ServiceRegistry
|
||||
scanner = await ServiceRegistry.get_lora_scanner()
|
||||
cache = await scanner.get_cached_data()
|
||||
needs_resort = False
|
||||
new_folders = set()
|
||||
|
||||
@@ -227,36 +248,36 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
continue
|
||||
|
||||
# Scan new file
|
||||
lora_data = await self.scanner.scan_single_lora(file_path)
|
||||
if lora_data:
|
||||
model_data = await scanner.scan_single_model(file_path)
|
||||
if model_data:
|
||||
# Update tags count
|
||||
for tag in lora_data.get('tags', []):
|
||||
self.scanner._tags_count[tag] = self.scanner._tags_count.get(tag, 0) + 1
|
||||
for tag in model_data.get('tags', []):
|
||||
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
|
||||
|
||||
cache.raw_data.append(lora_data)
|
||||
new_folders.add(lora_data['folder'])
|
||||
cache.raw_data.append(model_data)
|
||||
new_folders.add(model_data['folder'])
|
||||
# Update hash index
|
||||
if 'sha256' in lora_data:
|
||||
self.scanner._hash_index.add_entry(
|
||||
lora_data['sha256'],
|
||||
lora_data['file_path']
|
||||
if 'sha256' in model_data:
|
||||
scanner._hash_index.add_entry(
|
||||
model_data['sha256'],
|
||||
model_data['file_path']
|
||||
)
|
||||
needs_resort = True
|
||||
|
||||
elif action == 'remove':
|
||||
# Find the lora to remove so we can update tags count
|
||||
lora_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
||||
if lora_to_remove:
|
||||
# Find the model to remove so we can update tags count
|
||||
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
||||
if model_to_remove:
|
||||
# Update tags count by reducing counts
|
||||
for tag in lora_to_remove.get('tags', []):
|
||||
if tag in self.scanner._tags_count:
|
||||
self.scanner._tags_count[tag] = max(0, self.scanner._tags_count[tag] - 1)
|
||||
if self.scanner._tags_count[tag] == 0:
|
||||
del self.scanner._tags_count[tag]
|
||||
for tag in model_to_remove.get('tags', []):
|
||||
if tag in scanner._tags_count:
|
||||
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
|
||||
if scanner._tags_count[tag] == 0:
|
||||
del scanner._tags_count[tag]
|
||||
|
||||
# Remove from cache and hash index
|
||||
logger.info(f"Removing {file_path} from cache")
|
||||
self.scanner._hash_index.remove_by_path(file_path)
|
||||
scanner._hash_index.remove_by_path(file_path)
|
||||
cache.raw_data = [
|
||||
item for item in cache.raw_data
|
||||
if item['file_path'] != file_path
|
||||
@@ -274,59 +295,140 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in process_changes: {e}")
|
||||
logger.error(f"Error in process_changes for LoRA: {e}")
|
||||
|
||||
|
||||
class LoraFileMonitor:
|
||||
"""Monitor for LoRA file changes"""
|
||||
class CheckpointFileHandler(BaseFileHandler):
|
||||
"""Handler for checkpoint file system events"""
|
||||
|
||||
def __init__(self, scanner: LoraScanner, roots: List[str]):
|
||||
self.scanner = scanner
|
||||
scanner.set_file_monitor(self)
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
||||
super().__init__(loop)
|
||||
# Set supported file extensions for checkpoints
|
||||
self.file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
|
||||
|
||||
async def _process_changes(self, delay: float = 2.0):
|
||||
"""Process pending changes with debouncing for checkpoint files"""
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
try:
|
||||
with self.lock:
|
||||
changes = self.pending_changes.copy()
|
||||
self.pending_changes.clear()
|
||||
|
||||
if not changes:
|
||||
return
|
||||
|
||||
logger.info(f"Processing {len(changes)} checkpoint file changes")
|
||||
|
||||
# Get scanner through ServiceRegistry
|
||||
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
cache = await scanner.get_cached_data()
|
||||
needs_resort = False
|
||||
new_folders = set()
|
||||
|
||||
for action, file_path in changes:
|
||||
try:
|
||||
if action == 'add':
|
||||
# Check if file already exists in cache
|
||||
existing = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
||||
if existing:
|
||||
logger.info(f"File {file_path} already in cache, skipping")
|
||||
continue
|
||||
|
||||
# Scan new file
|
||||
model_data = await scanner.scan_single_model(file_path)
|
||||
if model_data:
|
||||
# Update tags count if applicable
|
||||
for tag in model_data.get('tags', []):
|
||||
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
|
||||
|
||||
cache.raw_data.append(model_data)
|
||||
new_folders.add(model_data['folder'])
|
||||
# Update hash index
|
||||
if 'sha256' in model_data:
|
||||
scanner._hash_index.add_entry(
|
||||
model_data['sha256'],
|
||||
model_data['file_path']
|
||||
)
|
||||
needs_resort = True
|
||||
|
||||
elif action == 'remove':
|
||||
# Find the model to remove so we can update tags count
|
||||
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
||||
if model_to_remove:
|
||||
# Update tags count by reducing counts
|
||||
for tag in model_to_remove.get('tags', []):
|
||||
if tag in scanner._tags_count:
|
||||
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
|
||||
if scanner._tags_count[tag] == 0:
|
||||
del scanner._tags_count[tag]
|
||||
|
||||
# Remove from cache and hash index
|
||||
logger.info(f"Removing {file_path} from checkpoint cache")
|
||||
scanner._hash_index.remove_by_path(file_path)
|
||||
cache.raw_data = [
|
||||
item for item in cache.raw_data
|
||||
if item['file_path'] != file_path
|
||||
]
|
||||
needs_resort = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing checkpoint {action} for {file_path}: {e}")
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort()
|
||||
|
||||
# Update folder list
|
||||
all_folders = set(cache.folders) | new_folders
|
||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in process_changes for checkpoint: {e}")
|
||||
|
||||
|
||||
class BaseFileMonitor:
|
||||
"""Base class for file monitoring"""
|
||||
|
||||
def __init__(self, monitor_paths: List[str]):
|
||||
self.observer = Observer()
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.handler = LoraFileHandler(scanner, self.loop)
|
||||
|
||||
# 使用已存在的路径映射
|
||||
self.monitor_paths = set()
|
||||
for root in roots:
|
||||
self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/'))
|
||||
|
||||
# Process monitor paths
|
||||
for path in monitor_paths:
|
||||
self.monitor_paths.add(os.path.realpath(path).replace(os.sep, '/'))
|
||||
|
||||
# 添加所有已映射的目标路径
|
||||
# Add mapped paths from config
|
||||
for target_path in config._path_mappings.keys():
|
||||
self.monitor_paths.add(target_path)
|
||||
|
||||
|
||||
def start(self):
|
||||
"""Start monitoring"""
|
||||
for path_info in self.monitor_paths:
|
||||
"""Start file monitoring"""
|
||||
for path in self.monitor_paths:
|
||||
try:
|
||||
if isinstance(path_info, tuple):
|
||||
# 对于链接,监控目标路径
|
||||
_, target_path = path_info
|
||||
self.observer.schedule(self.handler, target_path, recursive=True)
|
||||
logger.info(f"Started monitoring target path: {target_path}")
|
||||
else:
|
||||
# 对于普通路径,直接监控
|
||||
self.observer.schedule(self.handler, path_info, recursive=True)
|
||||
logger.info(f"Started monitoring: {path_info}")
|
||||
self.observer.schedule(self.handler, path, recursive=True)
|
||||
logger.info(f"Started monitoring: {path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error monitoring {path_info}: {e}")
|
||||
logger.error(f"Error monitoring {path}: {e}")
|
||||
|
||||
self.observer.start()
|
||||
|
||||
|
||||
def stop(self):
|
||||
"""Stop monitoring"""
|
||||
"""Stop file monitoring"""
|
||||
self.observer.stop()
|
||||
self.observer.join()
|
||||
|
||||
|
||||
def rescan_links(self):
|
||||
"""重新扫描链接(当添加新的链接时调用)"""
|
||||
"""Rescan links when new ones are added"""
|
||||
# Find new paths not yet being monitored
|
||||
new_paths = set()
|
||||
for path in self.monitor_paths.copy():
|
||||
self._add_link_targets(path)
|
||||
for path in config._path_mappings.keys():
|
||||
real_path = os.path.realpath(path).replace(os.sep, '/')
|
||||
if real_path not in self.monitor_paths:
|
||||
new_paths.add(real_path)
|
||||
self.monitor_paths.add(real_path)
|
||||
|
||||
# 添加新发现的路径到监控
|
||||
new_paths = self.monitor_paths - set(self.observer.watches.keys())
|
||||
# Add new paths to monitoring
|
||||
for path in new_paths:
|
||||
try:
|
||||
self.observer.schedule(self.handler, path, recursive=True)
|
||||
@@ -334,88 +436,86 @@ class LoraFileMonitor:
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding new monitor for {path}: {e}")
|
||||
|
||||
# Add CheckpointFileMonitor class
|
||||
|
||||
class CheckpointFileMonitor(LoraFileMonitor):
|
||||
class LoraFileMonitor(BaseFileMonitor):
|
||||
"""Monitor for LoRA file changes"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls, monitor_paths=None):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, monitor_paths=None):
|
||||
if not hasattr(self, '_initialized'):
|
||||
if monitor_paths is None:
|
||||
from ..config import config
|
||||
monitor_paths = config.loras_roots
|
||||
|
||||
super().__init__(monitor_paths)
|
||||
self.handler = LoraFileHandler(self.loop)
|
||||
self._initialized = True
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
from ..config import config
|
||||
cls._instance = cls(config.loras_roots)
|
||||
return cls._instance
|
||||
|
||||
|
||||
class CheckpointFileMonitor(BaseFileMonitor):
|
||||
"""Monitor for checkpoint file changes"""
|
||||
|
||||
def __init__(self, scanner: CheckpointScanner, roots: List[str]):
|
||||
# Reuse most of the LoraFileMonitor functionality, but with a different handler
|
||||
self.scanner = scanner
|
||||
scanner.set_file_monitor(self)
|
||||
self.observer = Observer()
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.handler = CheckpointFileHandler(scanner, self.loop)
|
||||
|
||||
# Use existing path mappings
|
||||
self.monitor_paths = set()
|
||||
for root in roots:
|
||||
self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/'))
|
||||
|
||||
# Add all mapped target paths
|
||||
for target_path in config._path_mappings.keys():
|
||||
self.monitor_paths.add(target_path)
|
||||
|
||||
class CheckpointFileHandler(LoraFileHandler):
|
||||
"""Handler for checkpoint file system events"""
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __init__(self, scanner: CheckpointScanner, loop: asyncio.AbstractEventLoop):
|
||||
super().__init__(scanner, loop)
|
||||
# Configure supported file extensions
|
||||
self.supported_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft'}
|
||||
|
||||
def on_created(self, event):
|
||||
if event.is_directory:
|
||||
return
|
||||
|
||||
# Handle supported file extensions directly
|
||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
if file_ext in self.supported_extensions:
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
|
||||
# Process this file directly
|
||||
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
||||
if normalized_path not in self.scheduled_files:
|
||||
logger.info(f"Checkpoint file created: {event.src_path}")
|
||||
self.scheduled_files.add(normalized_path)
|
||||
self._schedule_update('add', event.src_path)
|
||||
|
||||
# Ignore modifications for a short period after creation
|
||||
self.loop.call_later(
|
||||
self.debounce_delay * 2,
|
||||
self.scheduled_files.discard,
|
||||
normalized_path
|
||||
)
|
||||
def __new__(cls, monitor_paths=None):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def on_modified(self, event):
|
||||
if event.is_directory:
|
||||
return
|
||||
|
||||
# Only process supported file types
|
||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
if file_ext in self.supported_extensions:
|
||||
super().on_modified(event)
|
||||
def __init__(self, monitor_paths=None):
|
||||
if not hasattr(self, '_initialized'):
|
||||
if monitor_paths is None:
|
||||
# Get checkpoint roots from scanner
|
||||
monitor_paths = []
|
||||
# We'll initialize monitor paths later when scanner is available
|
||||
|
||||
def on_deleted(self, event):
|
||||
if event.is_directory:
|
||||
return
|
||||
super().__init__(monitor_paths or [])
|
||||
self.handler = CheckpointFileHandler(self.loop)
|
||||
self._initialized = True
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls([])
|
||||
|
||||
# Now get checkpoint roots from scanner
|
||||
from .checkpoint_scanner import CheckpointScanner
|
||||
scanner = await CheckpointScanner.get_instance()
|
||||
monitor_paths = scanner.get_model_roots()
|
||||
|
||||
# Update monitor paths
|
||||
for path in monitor_paths:
|
||||
real_path = os.path.realpath(path).replace(os.sep, '/')
|
||||
cls._instance.monitor_paths.add(real_path)
|
||||
|
||||
return cls._instance
|
||||
|
||||
async def initialize_paths(self):
|
||||
"""Initialize monitor paths from scanner"""
|
||||
if not self.monitor_paths:
|
||||
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
monitor_paths = scanner.get_model_roots()
|
||||
|
||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
if file_ext not in self.supported_extensions:
|
||||
return
|
||||
|
||||
super().on_deleted(event)
|
||||
|
||||
def on_moved(self, event):
|
||||
"""Handle file move/rename events"""
|
||||
src_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
dest_ext = os.path.splitext(event.dest_path)[1].lower()
|
||||
|
||||
# If destination has supported extension, treat as new file
|
||||
if dest_ext in self.supported_extensions:
|
||||
super().on_moved(event)
|
||||
|
||||
# If source was supported extension, treat as deleted
|
||||
elif src_ext in self.supported_extensions:
|
||||
super().on_moved(event)
|
||||
# Update monitor paths
|
||||
for path in monitor_paths:
|
||||
real_path = os.path.realpath(path).replace(os.sep, '/')
|
||||
self.monitor_paths.add(real_path)
|
||||
@@ -13,6 +13,7 @@ from .lora_hash_index import LoraHashIndex
|
||||
from .settings_manager import settings
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..utils.utils import fuzzy_match
|
||||
from .service_registry import ServiceRegistry
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -12,13 +12,13 @@ from ..utils.file_utils import load_metadata, get_file_info, find_preview_file,
|
||||
from .model_cache import ModelCache
|
||||
from .model_hash_index import ModelHashIndex
|
||||
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||
from .service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelScanner:
|
||||
"""Base service for scanning and managing model files"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
|
||||
@@ -35,14 +35,17 @@ class ModelScanner:
|
||||
self.file_extensions = file_extensions
|
||||
self._cache = None
|
||||
self._hash_index = hash_index or ModelHashIndex()
|
||||
self.file_monitor = None
|
||||
self._tags_count = {} # Dictionary to store tag counts
|
||||
self._is_initializing = False # Flag to track initialization state
|
||||
|
||||
# Register this service
|
||||
asyncio.create_task(self._register_service())
|
||||
|
||||
async def _register_service(self):
|
||||
"""Register this instance with the ServiceRegistry"""
|
||||
service_name = f"{self.model_type}_scanner"
|
||||
await ServiceRegistry.register_service(service_name, self)
|
||||
|
||||
def set_file_monitor(self, monitor):
|
||||
"""Set file monitor instance"""
|
||||
self.file_monitor = monitor
|
||||
|
||||
async def initialize_in_background(self) -> None:
|
||||
"""Initialize cache in background using thread pool"""
|
||||
try:
|
||||
@@ -366,12 +369,20 @@ class ModelScanner:
|
||||
|
||||
file_size = os.path.getsize(real_source)
|
||||
|
||||
if self.file_monitor:
|
||||
self.file_monitor.handler.add_ignore_path(
|
||||
# Get the appropriate file monitor through ServiceRegistry
|
||||
if self.model_type == "lora":
|
||||
monitor = await ServiceRegistry.get_lora_monitor()
|
||||
elif self.model_type == "checkpoint":
|
||||
monitor = await ServiceRegistry.get_checkpoint_monitor()
|
||||
else:
|
||||
monitor = None
|
||||
|
||||
if monitor:
|
||||
monitor.handler.add_ignore_path(
|
||||
real_source,
|
||||
file_size
|
||||
)
|
||||
self.file_monitor.handler.add_ignore_path(
|
||||
monitor.handler.add_ignore_path(
|
||||
real_target,
|
||||
file_size
|
||||
)
|
||||
|
||||
@@ -5,8 +5,8 @@ import json
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from ..config import config
|
||||
from .recipe_cache import RecipeCache
|
||||
from .service_registry import ServiceRegistry
|
||||
from .lora_scanner import LoraScanner
|
||||
from .civitai_client import CivitaiClient
|
||||
from ..utils.utils import fuzzy_match
|
||||
import sys
|
||||
|
||||
@@ -18,11 +18,22 @@ class RecipeScanner:
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls, lora_scanner: Optional[LoraScanner] = None):
|
||||
"""Get singleton instance of RecipeScanner"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
if not lora_scanner:
|
||||
# Get lora scanner from service registry if not provided
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
cls._instance = cls(lora_scanner)
|
||||
return cls._instance
|
||||
|
||||
def __new__(cls, lora_scanner: Optional[LoraScanner] = None):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._lora_scanner = lora_scanner
|
||||
cls._instance._civitai_client = CivitaiClient()
|
||||
cls._instance._civitai_client = None # Will be lazily initialized
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, lora_scanner: Optional[LoraScanner] = None):
|
||||
@@ -36,6 +47,12 @@ class RecipeScanner:
|
||||
self._lora_scanner = lora_scanner
|
||||
self._initialized = True
|
||||
|
||||
async def _get_civitai_client(self):
|
||||
"""Lazily initialize CivitaiClient from registry"""
|
||||
if self._civitai_client is None:
|
||||
self._civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
return self._civitai_client
|
||||
|
||||
async def initialize_in_background(self) -> None:
|
||||
"""Initialize cache in background using thread pool"""
|
||||
try:
|
||||
@@ -306,10 +323,13 @@ class RecipeScanner:
|
||||
async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]:
|
||||
"""Get hash from Civitai API"""
|
||||
try:
|
||||
if not self._civitai_client:
|
||||
# Get CivitaiClient from ServiceRegistry
|
||||
civitai_client = await self._get_civitai_client()
|
||||
if not civitai_client:
|
||||
logger.error("Failed to get CivitaiClient from ServiceRegistry")
|
||||
return None
|
||||
|
||||
version_info = await self._civitai_client.get_model_version_info(model_version_id)
|
||||
version_info = await civitai_client.get_model_version_info(model_version_id)
|
||||
|
||||
if not version_info or not version_info.get('files'):
|
||||
logger.debug(f"No files found in version info for ID: {model_version_id}")
|
||||
@@ -329,10 +349,12 @@ class RecipeScanner:
|
||||
async def _get_model_version_name(self, model_version_id: str) -> Optional[str]:
|
||||
"""Get model version name from Civitai API"""
|
||||
try:
|
||||
if not self._civitai_client:
|
||||
# Get CivitaiClient from ServiceRegistry
|
||||
civitai_client = await self._get_civitai_client()
|
||||
if not civitai_client:
|
||||
return None
|
||||
|
||||
version_info = await self._civitai_client.get_model_version_info(model_version_id)
|
||||
version_info = await civitai_client.get_model_version_info(model_version_id)
|
||||
|
||||
if version_info and 'name' in version_info:
|
||||
return version_info['name']
|
||||
|
||||
124
py/services/service_registry.py
Normal file
124
py/services/service_registry.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, TypeVar, Type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar('T') # Define a type variable for service types
|
||||
|
||||
class ServiceRegistry:
|
||||
"""Centralized registry for service singletons"""
|
||||
|
||||
_instance = None
|
||||
_services: Dict[str, Any] = {}
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
"""Get singleton instance of the registry"""
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
async def register_service(cls, service_name: str, service_instance: Any) -> None:
|
||||
"""Register a service instance with the registry"""
|
||||
registry = cls.get_instance()
|
||||
async with cls._lock:
|
||||
registry._services[service_name] = service_instance
|
||||
logger.debug(f"Registered service: {service_name}")
|
||||
|
||||
@classmethod
|
||||
async def get_service(cls, service_name: str) -> Any:
|
||||
"""Get a service instance by name"""
|
||||
registry = cls.get_instance()
|
||||
async with cls._lock:
|
||||
if service_name not in registry._services:
|
||||
logger.warning(f"Service {service_name} not found in registry")
|
||||
return None
|
||||
return registry._services[service_name]
|
||||
|
||||
# Convenience methods for common services
|
||||
@classmethod
|
||||
async def get_lora_scanner(cls):
|
||||
"""Get the LoraScanner instance"""
|
||||
from .lora_scanner import LoraScanner
|
||||
scanner = await cls.get_service("lora_scanner")
|
||||
if scanner is None:
|
||||
scanner = await LoraScanner.get_instance()
|
||||
await cls.register_service("lora_scanner", scanner)
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_checkpoint_scanner(cls):
|
||||
"""Get the CheckpointScanner instance"""
|
||||
from .checkpoint_scanner import CheckpointScanner
|
||||
scanner = await cls.get_service("checkpoint_scanner")
|
||||
if scanner is None:
|
||||
scanner = await CheckpointScanner.get_instance()
|
||||
await cls.register_service("checkpoint_scanner", scanner)
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_lora_monitor(cls):
|
||||
"""Get the LoraFileMonitor instance"""
|
||||
from .file_monitor import LoraFileMonitor
|
||||
monitor = await cls.get_service("lora_monitor")
|
||||
if monitor is None:
|
||||
monitor = await LoraFileMonitor.get_instance()
|
||||
await cls.register_service("lora_monitor", monitor)
|
||||
return monitor
|
||||
|
||||
@classmethod
|
||||
async def get_checkpoint_monitor(cls):
|
||||
"""Get the CheckpointFileMonitor instance"""
|
||||
from .file_monitor import CheckpointFileMonitor
|
||||
monitor = await cls.get_service("checkpoint_monitor")
|
||||
if monitor is None:
|
||||
monitor = await CheckpointFileMonitor.get_instance()
|
||||
await cls.register_service("checkpoint_monitor", monitor)
|
||||
return monitor
|
||||
|
||||
@classmethod
|
||||
async def get_civitai_client(cls):
|
||||
"""Get the CivitaiClient instance"""
|
||||
from .civitai_client import CivitaiClient
|
||||
client = await cls.get_service("civitai_client")
|
||||
if client is None:
|
||||
client = await CivitaiClient.get_instance()
|
||||
await cls.register_service("civitai_client", client)
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
async def get_download_manager(cls):
|
||||
"""Get the DownloadManager instance"""
|
||||
from .download_manager import DownloadManager
|
||||
manager = await cls.get_service("download_manager")
|
||||
if manager is None:
|
||||
# We'll let DownloadManager.get_instance handle file_monitor parameter
|
||||
manager = await DownloadManager.get_instance()
|
||||
await cls.register_service("download_manager", manager)
|
||||
return manager
|
||||
|
||||
@classmethod
|
||||
async def get_recipe_scanner(cls):
|
||||
"""Get the RecipeScanner instance"""
|
||||
from .recipe_scanner import RecipeScanner
|
||||
scanner = await cls.get_service("recipe_scanner")
|
||||
if scanner is None:
|
||||
lora_scanner = await cls.get_lora_scanner()
|
||||
scanner = RecipeScanner(lora_scanner)
|
||||
await cls.register_service("recipe_scanner", scanner)
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_websocket_manager(cls):
|
||||
"""Get the WebSocketManager instance"""
|
||||
from .websocket_manager import ws_manager
|
||||
manager = await cls.get_service("websocket_manager")
|
||||
if manager is None:
|
||||
# ws_manager is already a global instance in websocket_manager.py
|
||||
from .websocket_manager import ws_manager
|
||||
await cls.register_service("websocket_manager", ws_manager)
|
||||
manager = ws_manager
|
||||
return manager
|
||||
Reference in New Issue
Block a user