checkpoint

This commit is contained in:
Will Miao
2025-04-11 20:22:12 +08:00
parent 1db49a4dd4
commit 0618541527
13 changed files with 793 additions and 276 deletions

View File

@@ -5,10 +5,7 @@ from .routes.lora_routes import LoraRoutes
from .routes.api_routes import ApiRoutes
from .routes.recipe_routes import RecipeRoutes
from .routes.checkpoints_routes import CheckpointsRoutes
from .services.lora_scanner import LoraScanner
from .services.checkpoint_scanner import CheckpointScanner
from .services.recipe_scanner import RecipeScanner
from .services.file_monitor import LoraFileMonitor, CheckpointFileMonitor
from .services.service_registry import ServiceRegistry
import logging
logger = logging.getLogger(__name__)
@@ -41,8 +38,10 @@ class LoraManager:
config.add_route_mapping(real_root, preview_path)
added_targets.add(real_root)
# Get checkpoint scanner instance
checkpoint_scanner = asyncio.run(ServiceRegistry.get_checkpoint_scanner())
# Add static routes for each checkpoint root
checkpoint_scanner = CheckpointScanner()
for idx, root in enumerate(checkpoint_scanner.get_model_roots(), start=1):
preview_path = f'/checkpoints_static/root{idx}/preview'
@@ -79,51 +78,89 @@ class LoraManager:
lora_routes = LoraRoutes()
checkpoints_routes = CheckpointsRoutes()
# Setup file monitoring
lora_monitor = LoraFileMonitor(lora_routes.scanner, config.loras_roots)
lora_monitor.start()
checkpoint_monitor = CheckpointFileMonitor(checkpoints_routes.scanner, checkpoints_routes.scanner.get_model_roots())
checkpoint_monitor.start()
# Initialize routes
lora_routes.setup_routes(app)
checkpoints_routes.setup_routes(app)
ApiRoutes.setup_routes(app, lora_monitor)
ApiRoutes.setup_routes(app)
RecipeRoutes.setup_routes(app)
# Store monitors in app for cleanup
app['lora_monitor'] = lora_monitor
app['checkpoint_monitor'] = checkpoint_monitor
logger.info("PromptServer app: ", app)
# Schedule cache initialization using the application's startup handler
app.on_startup.append(lambda app: cls._schedule_cache_init(
lora_routes.scanner,
checkpoints_routes.scanner,
lora_routes.recipe_scanner
))
# Schedule service initialization
app.on_startup.append(lambda app: cls._initialize_services())
# Add cleanup
app.on_shutdown.append(cls._cleanup)
app.on_shutdown.append(ApiRoutes.cleanup)
@classmethod
async def _schedule_cache_init(cls, lora_scanner, checkpoint_scanner, recipe_scanner):
"""Schedule cache initialization in the running event loop"""
async def _initialize_services(cls):
"""Initialize all services using the ServiceRegistry"""
try:
logger.info("LoRA Manager: Initializing services via ServiceRegistry")
# Initialize CivitaiClient first to ensure it's ready for other services
civitai_client = await ServiceRegistry.get_civitai_client()
logger.info("CivitaiClient registered in ServiceRegistry")
# Get file monitors through ServiceRegistry
lora_monitor = await ServiceRegistry.get_lora_monitor()
checkpoint_monitor = await ServiceRegistry.get_checkpoint_monitor()
# Start monitors
lora_monitor.start()
logger.info("Lora monitor started")
# Make sure checkpoint monitor has paths before starting
await checkpoint_monitor.initialize_paths()
checkpoint_monitor.start()
logger.info("Checkpoint monitor started")
# Register DownloadManager with ServiceRegistry
download_manager = await ServiceRegistry.get_download_manager()
logger.info("DownloadManager registered in ServiceRegistry")
# Initialize WebSocket manager
ws_manager = await ServiceRegistry.get_websocket_manager()
logger.info("WebSocketManager registered in ServiceRegistry")
# Initialize scanners in background
lora_scanner = await ServiceRegistry.get_lora_scanner()
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
# Initialize recipe scanner if needed
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
# Create low-priority initialization tasks
lora_task = asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
checkpoint_task = asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
recipe_task = asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
logger.info("LoRA Manager: All services initialized and background tasks scheduled")
except Exception as e:
logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}")
logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True)
@classmethod
async def _cleanup(cls, app):
"""Cleanup resources"""
if 'lora_monitor' in app:
app['lora_monitor'].stop()
"""Cleanup resources using ServiceRegistry"""
try:
logger.info("LoRA Manager: Cleaning up services")
if 'checkpoint_monitor' in app:
app['checkpoint_monitor'].stop()
# Get monitors from ServiceRegistry
lora_monitor = await ServiceRegistry.get_service("lora_monitor")
if lora_monitor:
lora_monitor.stop()
logger.info("Stopped LoRA monitor")
checkpoint_monitor = await ServiceRegistry.get_service("checkpoint_monitor")
if checkpoint_monitor:
checkpoint_monitor.stop()
logger.info("Stopped checkpoint monitor")
# Close CivitaiClient gracefully
civitai_client = await ServiceRegistry.get_service("civitai_client")
if civitai_client:
await civitai_client.close()
logger.info("Closed CivitaiClient connection")
except Exception as e:
logger.error(f"Error during cleanup: {e}", exc_info=True)

View File

@@ -19,22 +19,33 @@ from .update_routes import UpdateRoutes
from ..services.recipe_scanner import RecipeScanner
from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
from ..utils.exif_utils import ExifUtils
from ..services.service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
class ApiRoutes:
"""API route handlers for LoRA management"""
def __init__(self, file_monitor: LoraFileMonitor):
self.scanner = LoraScanner()
self.civitai_client = CivitaiClient()
self.download_manager = DownloadManager(file_monitor)
def __init__(self):
self.scanner = None # Will be initialized in setup_routes
self.civitai_client = None # Will be initialized in setup_routes
self.download_manager = None # Will be initialized in setup_routes
self._download_lock = asyncio.Lock()
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
self.scanner = await ServiceRegistry.get_lora_scanner()
self.civitai_client = await ServiceRegistry.get_civitai_client()
self.download_manager = await ServiceRegistry.get_download_manager()
@classmethod
def setup_routes(cls, app: web.Application, monitor: LoraFileMonitor):
def setup_routes(cls, app: web.Application):
"""Register API routes"""
routes = cls(monitor)
routes = cls()
# Schedule service initialization on app startup
app.on_startup.append(lambda _: routes.initialize_services())
app.router.add_post('/api/delete_model', routes.delete_model)
app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
app.router.add_post('/api/replace_preview', routes.replace_preview)
@@ -63,19 +74,28 @@ class ApiRoutes:
async def delete_model(self, request: web.Request) -> web.Response:
"""Handle model deletion request"""
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
async def fetch_civitai(self, request: web.Request) -> web.Response:
"""Handle CivitAI metadata fetch request"""
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
async def replace_preview(self, request: web.Request) -> web.Response:
"""Handle preview image replacement request"""
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
async def get_loras(self, request: web.Request) -> web.Response:
"""Handle paginated LoRA data request"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Parse query parameters
page = int(request.query.get('page', '1'))
page_size = int(request.query.get('page_size', '20'))
@@ -231,6 +251,9 @@ class ApiRoutes:
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all loras in the background"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
cache = await self.scanner.get_cached_data()
total = len(cache.raw_data)
processed = 0
@@ -312,6 +335,9 @@ class ApiRoutes:
async def get_folders(self, request: web.Request) -> web.Response:
"""Get all folders in the cache"""
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
cache = await self.scanner.get_cached_data()
return web.json_response({
'folders': cache.folders
@@ -320,6 +346,12 @@ class ApiRoutes:
async def get_civitai_versions(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai model with local availability info"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
model_id = request.match_info['model_id']
versions = await self.civitai_client.get_model_versions(model_id)
if not versions:
@@ -353,9 +385,12 @@ class ApiRoutes:
async def get_civitai_model(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by model version ID or hash"""
try:
model_version_id = request.match_info['modelVersionId']
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
model_version_id = request.match_info.get('modelVersionId')
if not model_version_id:
hash = request.match_info['hash']
hash = request.match_info.get('hash')
model = await self.civitai_client.get_model_by_hash(hash)
return web.json_response(model)
@@ -370,6 +405,9 @@ class ApiRoutes:
async def download_lora(self, request: web.Request) -> web.Response:
async with self._download_lock:
try:
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
data = await request.json()
# Create progress callback
@@ -447,6 +485,9 @@ class ApiRoutes:
async def move_model(self, request: web.Request) -> web.Response:
"""Handle model move request"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
data = await request.json()
file_path = data.get('file_path') # full path of the model file, e.g. /path/to/model.safetensors
target_path = data.get('target_path') # folder path to move the model to, e.g. /path/to/target_folder
@@ -485,12 +526,17 @@ class ApiRoutes:
@classmethod
async def cleanup(cls):
"""Add cleanup method for application shutdown"""
if hasattr(cls, '_instance'):
await cls._instance.civitai_client.close()
# Now we don't need to store an instance, as services are managed by ServiceRegistry
civitai_client = await ServiceRegistry.get_civitai_client()
if civitai_client:
await civitai_client.close()
async def save_metadata(self, request: web.Request) -> web.Response:
"""Handle saving metadata updates"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
data = await request.json()
file_path = data.get('file_path')
if not file_path:
@@ -536,6 +582,9 @@ class ApiRoutes:
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
"""Get the static preview URL for a LoRA file"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Get lora file name from query parameters
lora_name = request.query.get('name')
if not lora_name:
@@ -574,6 +623,9 @@ class ApiRoutes:
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
"""Get the Civitai URL for a LoRA file"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Get lora file name from query parameters
lora_name = request.query.get('name')
if not lora_name:
@@ -619,6 +671,9 @@ class ApiRoutes:
async def move_models_bulk(self, request: web.Request) -> web.Response:
"""Handle bulk model move request"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
data = await request.json()
file_paths = data.get('file_paths', []) # list of full paths of the model files, e.g. ["/path/to/model1.safetensors", "/path/to/model2.safetensors"]
target_path = data.get('target_path') # folder path to move the models to, e.g. "/path/to/target_folder"
@@ -677,6 +732,9 @@ class ApiRoutes:
async def get_lora_model_description(self, request: web.Request) -> web.Response:
"""Get model description for a Lora model"""
try:
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
# Get parameters
model_id = request.query.get('model_id')
file_path = request.query.get('file_path')
@@ -736,6 +794,9 @@ class ApiRoutes:
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Handle request for top tags sorted by frequency"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Parse query parameters
limit = int(request.query.get('limit', '20'))
@@ -761,6 +822,9 @@ class ApiRoutes:
async def get_base_models(self, request: web.Request) -> web.Response:
"""Get base models used in loras"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Parse query parameters
limit = int(request.query.get('limit', '20'))
@@ -785,6 +849,12 @@ class ApiRoutes:
async def rename_lora(self, request: web.Request) -> web.Response:
"""Handle renaming a LoRA file and its associated files"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
data = await request.json()
file_path = data.get('file_path')
new_file_name = data.get('new_file_name')
@@ -891,7 +961,7 @@ class ApiRoutes:
# Update recipe files and cache if hash is available
if hash_value:
recipe_scanner = RecipeScanner(self.scanner)
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
recipes_updated, cache_updated = await recipe_scanner.update_lora_filename_by_hash(hash_value, new_file_name)
logger.info(f"Updated {recipes_updated} recipe files and {cache_updated} cache entries for renamed LoRA")

View File

@@ -11,6 +11,7 @@ from ..services.civitai_client import CivitaiClient
from ..services.websocket_manager import ws_manager
from ..services.checkpoint_scanner import CheckpointScanner
from ..services.download_manager import DownloadManager
from ..services.service_registry import ServiceRegistry
from ..config import config
from ..services.settings_manager import settings
from ..utils.utils import fuzzy_match
@@ -21,16 +22,24 @@ class CheckpointsRoutes:
"""API routes for checkpoint management"""
def __init__(self):
self.scanner = CheckpointScanner()
self.scanner = None # Will be initialized in setup_routes
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
self.download_manager = DownloadManager()
self.download_manager = None # Will be initialized in setup_routes
self._download_lock = asyncio.Lock()
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
self.download_manager = await ServiceRegistry.get_download_manager()
def setup_routes(self, app):
"""Register routes with the aiohttp app"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
app.router.add_get('/api/checkpoints', self.get_checkpoints)
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
@@ -488,10 +497,9 @@ class CheckpointsRoutes:
async def download_checkpoint(self, request: web.Request) -> web.Response:
"""Handle checkpoint download request"""
async with self._download_lock:
# Initialize DownloadManager with the file monitor if the scanner has one
if not hasattr(self, 'download_manager') or self.download_manager is None:
file_monitor = getattr(self.scanner, 'file_monitor', None)
self.download_manager = DownloadManager(file_monitor)
# Get the download manager from service registry if not already initialized
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
# Use the common download handler with model_type="checkpoint"
return await ModelRouteUtils.handle_download_model(
@@ -503,6 +511,9 @@ class CheckpointsRoutes:
async def get_checkpoint_roots(self, request):
"""Return the checkpoint root directories"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
roots = self.scanner.get_model_roots()
return web.json_response({
"success": True,

View File

@@ -6,7 +6,8 @@ import logging
from ..services.lora_scanner import LoraScanner
from ..services.recipe_scanner import RecipeScanner
from ..config import config
from ..services.settings_manager import settings # Add this import
from ..services.settings_manager import settings
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
logger = logging.getLogger(__name__)
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
@@ -15,13 +16,24 @@ class LoraRoutes:
"""Route handlers for LoRA management endpoints"""
def __init__(self):
self.scanner = LoraScanner()
self.recipe_scanner = RecipeScanner(self.scanner)
# Initialize service references as None, will be set during async init
self.scanner = None
self.recipe_scanner = None
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
async def init_services(self):
"""Initialize services from ServiceRegistry"""
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
logger.info("LoraRoutes: Retrieved LoraScanner from ServiceRegistry")
if self.recipe_scanner is None:
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
logger.info("LoraRoutes: Retrieved RecipeScanner from ServiceRegistry")
def format_lora_data(self, lora: Dict) -> Dict:
"""Format LoRA data for template rendering"""
return {
@@ -58,7 +70,10 @@ class LoraRoutes:
async def handle_loras_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras request"""
try:
# 检查缓存初始化状态根据initialize_in_background的工作方式调整判断逻辑
# Ensure services are initialized
await self.init_services()
# Check if the LoraScanner is initializing
is_initializing = (
self.scanner._cache is None or
len(self.scanner._cache.raw_data) == 0 or
@@ -66,30 +81,29 @@ class LoraRoutes:
)
if is_initializing:
# 如果正在初始化,返回一个只包含加载提示的页面
# If still initializing, return loading page
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=[], # 空文件夹列表
is_initializing=True, # 新增标志
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Loras page is initializing, returning loading page")
else:
# 正常流程 - 获取已经初始化好的缓存数据
# Normal flow - get data from initialized cache
try:
cache = await self.scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=cache.folders,
is_initializing=False,
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
settings=settings,
request=request
)
except Exception as cache_error:
logger.error(f"Error loading cache data: {cache_error}")
# 如果获取缓存失败,也显示初始化页面
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=[],
@@ -114,7 +128,10 @@ class LoraRoutes:
async def handle_recipes_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras/recipes request"""
try:
# 检查缓存初始化状态与handle_loras_page保持一致的逻辑
# Ensure services are initialized
await self.init_services()
# Check if the RecipeScanner is initializing
is_initializing = (
self.recipe_scanner._cache is None or
len(self.recipe_scanner._cache.raw_data) == 0 or
@@ -183,5 +200,13 @@ class LoraRoutes:
def setup_routes(self, app: web.Application):
"""Register routes with the application"""
# Add an app startup handler to initialize services
app.on_startup.append(self._on_startup)
# Register routes
app.router.add_get('/loras', self.handle_loras_page)
app.router.add_get('/loras/recipes', self.handle_recipes_page)
async def _on_startup(self, app):
"""Initialize services when the app starts"""
await self.init_services()

View File

@@ -16,6 +16,7 @@ from ..services.lora_scanner import LoraScanner
from ..config import config
from ..workflow.parser import WorkflowParser
from ..utils.utils import download_civitai_image
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
logger = logging.getLogger(__name__)
@@ -23,13 +24,24 @@ class RecipeRoutes:
"""API route handlers for Recipe management"""
def __init__(self):
self.recipe_scanner = RecipeScanner(LoraScanner())
self.civitai_client = CivitaiClient()
# Initialize service references as None, will be set during async init
self.recipe_scanner = None
self.civitai_client = None
self.parser = WorkflowParser()
# Pre-warm the cache
self._init_cache_task = None
async def init_services(self):
"""Initialize services from ServiceRegistry"""
if self.recipe_scanner is None:
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
logger.info("RecipeRoutes: Retrieved RecipeScanner from ServiceRegistry")
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
logger.info("RecipeRoutes: Retrieved CivitaiClient from ServiceRegistry")
@classmethod
def setup_routes(cls, app: web.Application):
"""Register API routes"""
@@ -68,7 +80,10 @@ class RecipeRoutes:
async def _init_cache(self, app):
"""Initialize cache on startup"""
try:
# First, ensure the lora scanner is fully initialized
# Initialize services first
await self.init_services()
# Now that services are initialized, get the lora scanner
lora_scanner = self.recipe_scanner._lora_scanner
# Get lora cache to ensure it's initialized
@@ -86,6 +101,9 @@ class RecipeRoutes:
async def get_recipes(self, request: web.Request) -> web.Response:
"""API endpoint for getting paginated recipes"""
try:
# Ensure services are initialized
await self.init_services()
# Get query parameters with defaults
page = int(request.query.get('page', '1'))
page_size = int(request.query.get('page_size', '20'))
@@ -155,6 +173,9 @@ class RecipeRoutes:
async def get_recipe_detail(self, request: web.Request) -> web.Response:
"""Get detailed information about a specific recipe"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
# Use the new get_recipe_by_id method from recipe_scanner
@@ -208,6 +229,9 @@ class RecipeRoutes:
"""Analyze an uploaded image or URL for recipe metadata"""
temp_path = None
try:
# Ensure services are initialized
await self.init_services()
# Check if request contains multipart data (image) or JSON data (url)
content_type = request.headers.get('Content-Type', '')
@@ -326,6 +350,9 @@ class RecipeRoutes:
async def save_recipe(self, request: web.Request) -> web.Response:
"""Save a recipe to the recipes folder"""
try:
# Ensure services are initialized
await self.init_services()
reader = await request.multipart()
# Process form data
@@ -527,6 +554,9 @@ class RecipeRoutes:
async def delete_recipe(self, request: web.Request) -> web.Response:
"""Delete a recipe by ID"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
# Get recipes directory
@@ -574,6 +604,9 @@ class RecipeRoutes:
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Get top tags used in recipes"""
try:
# Ensure services are initialized
await self.init_services()
# Get limit parameter with default
limit = int(request.query.get('limit', '20'))
@@ -606,6 +639,9 @@ class RecipeRoutes:
async def get_base_models(self, request: web.Request) -> web.Response:
"""Get base models used in recipes"""
try:
# Ensure services are initialized
await self.init_services()
# Get all recipes from cache
cache = await self.recipe_scanner.get_cached_data()
@@ -634,6 +670,9 @@ class RecipeRoutes:
async def share_recipe(self, request: web.Request) -> web.Response:
"""Process a recipe image for sharing by adding metadata to EXIF"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
# Get all recipes from cache
@@ -693,6 +732,9 @@ class RecipeRoutes:
async def download_shared_recipe(self, request: web.Request) -> web.Response:
"""Serve a processed recipe image for download"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
# Check if we have this shared recipe
@@ -749,6 +791,9 @@ class RecipeRoutes:
async def save_recipe_from_widget(self, request: web.Request) -> web.Response:
"""Save a recipe from the LoRAs widget"""
try:
# Ensure services are initialized
await self.init_services()
reader = await request.multipart()
# Process form data
@@ -923,6 +968,9 @@ class RecipeRoutes:
async def get_recipe_syntax(self, request: web.Request) -> web.Response:
"""Generate recipe syntax for LoRAs in the recipe, looking up proper file names using hash_index"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
# Get all recipes from cache
@@ -1003,6 +1051,9 @@ class RecipeRoutes:
async def update_recipe(self, request: web.Request) -> web.Response:
"""Update recipe metadata (name and tags)"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
data = await request.json()
@@ -1030,6 +1081,9 @@ class RecipeRoutes:
async def reconnect_lora(self, request: web.Request) -> web.Response:
"""Reconnect a deleted LoRA in a recipe to a local LoRA file"""
try:
# Ensure services are initialized
await self.init_services()
# Parse request data
data = await request.json()
@@ -1140,6 +1194,9 @@ class RecipeRoutes:
async def get_recipes_for_lora(self, request: web.Request) -> web.Response:
"""Get recipes that use a specific Lora"""
try:
# Ensure services are initialized
await self.init_services()
lora_hash = request.query.get('hash')
# Hash is required

View File

@@ -8,6 +8,7 @@ from ..utils.models import CheckpointMetadata
from ..config import config
from .model_scanner import ModelScanner
from .model_hash_index import ModelHashIndex
from .service_registry import ServiceRegistry
logger = logging.getLogger(__name__)

View File

@@ -3,6 +3,7 @@ import aiohttp
import os
import json
import logging
import asyncio
from email.parser import Parser
from typing import Optional, Dict, Tuple, List
from urllib.parse import unquote
@@ -11,7 +12,23 @@ from ..utils.models import LoraMetadata
logger = logging.getLogger(__name__)
class CivitaiClient:
_instance = None
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls):
"""Get singleton instance of CivitaiClient"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
# Check if already initialized for singleton pattern
if hasattr(self, '_initialized'):
return
self._initialized = True
self.base_url = "https://civitai.com/api/v1"
self.headers = {
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'

View File

@@ -1,12 +1,13 @@
import logging
import os
import json
from typing import Optional, Dict
import asyncio
from typing import Optional, Dict, Any
from .civitai_client import CivitaiClient
from .file_monitor import LoraFileMonitor
from ..utils.models import LoraMetadata, CheckpointMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH
from ..utils.exif_utils import ExifUtils
from .service_registry import ServiceRegistry
# Download to temporary file first
import tempfile
@@ -14,9 +15,46 @@ import tempfile
logger = logging.getLogger(__name__)
class DownloadManager:
def __init__(self, file_monitor: Optional[LoraFileMonitor] = None):
self.civitai_client = CivitaiClient()
self.file_monitor = file_monitor
_instance = None
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls):
"""Get singleton instance of DownloadManager"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
# Check if already initialized for singleton pattern
if hasattr(self, '_initialized'):
return
self._initialized = True
self._civitai_client = None # Will be lazily initialized
async def _get_civitai_client(self):
"""Lazily initialize CivitaiClient from registry"""
if self._civitai_client is None:
self._civitai_client = await ServiceRegistry.get_civitai_client()
return self._civitai_client
async def _get_lora_monitor(self):
"""Get the lora file monitor from registry"""
return await ServiceRegistry.get_lora_monitor()
async def _get_checkpoint_monitor(self):
"""Get the checkpoint file monitor from registry"""
return await ServiceRegistry.get_checkpoint_monitor()
async def _get_lora_scanner(self):
"""Get the lora scanner from registry"""
return await ServiceRegistry.get_lora_scanner()
async def _get_checkpoint_scanner(self):
"""Get the checkpoint scanner from registry"""
return await ServiceRegistry.get_checkpoint_scanner()
async def download_from_civitai(self, download_url: str = None, model_hash: str = None,
model_version_id: str = None, save_dir: str = None,
@@ -43,19 +81,22 @@ class DownloadManager:
# Create directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True)
# Get civitai client
civitai_client = await self._get_civitai_client()
# Get version info based on the provided identifier
version_info = None
if download_url:
# Extract version ID from download URL
version_id = download_url.split('/')[-1]
version_info = await self.civitai_client.get_model_version_info(version_id)
version_info = await civitai_client.get_model_version_info(version_id)
elif model_version_id:
# Use model version ID directly
version_info = await self.civitai_client.get_model_version_info(model_version_id)
version_info = await civitai_client.get_model_version_info(model_version_id)
elif model_hash:
# Get model by hash
version_info = await self.civitai_client.get_model_by_hash(model_hash)
version_info = await civitai_client.get_model_by_hash(model_hash)
if not version_info:
@@ -95,8 +136,9 @@ class DownloadManager:
file_size = file_info.get('sizeKB', 0) * 1024
# 4. Notify file monitor - use normalized path and file size
if self.file_monitor and self.file_monitor.handler:
self.file_monitor.handler.add_ignore_path(
file_monitor = await self._get_lora_monitor() if model_type == "lora" else await self._get_checkpoint_monitor()
if file_monitor and file_monitor.handler:
file_monitor.handler.add_ignore_path(
save_path.replace(os.sep, '/'),
file_size
)
@@ -112,7 +154,7 @@ class DownloadManager:
# 5.1 Get and update model tags and description
model_id = version_info.get('modelId')
if model_id:
model_metadata, _ = await self.civitai_client.get_model_metadata(str(model_id))
model_metadata, _ = await civitai_client.get_model_metadata(str(model_id))
if model_metadata:
if model_metadata.get("tags"):
metadata.tags = model_metadata.get("tags", [])
@@ -146,6 +188,7 @@ class DownloadManager:
model_type: str = "lora") -> Dict:
"""Execute the actual download process including preview images and model files"""
try:
civitai_client = await self._get_civitai_client()
save_path = metadata.file_path
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
@@ -165,7 +208,7 @@ class DownloadManager:
preview_path = os.path.splitext(save_path)[0] + preview_ext
# Download video directly
if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
if await civitai_client.download_preview_image(images[0]['url'], preview_path):
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
with open(metadata_path, 'w', encoding='utf-8') as f:
@@ -176,7 +219,7 @@ class DownloadManager:
temp_path = temp_file.name
# Download the original image to temp path
if await self.civitai_client.download_preview_image(images[0]['url'], temp_path):
if await civitai_client.download_preview_image(images[0]['url'], temp_path):
# Optimize and convert to WebP
preview_path = os.path.splitext(save_path)[0] + '.webp'
@@ -210,7 +253,7 @@ class DownloadManager:
await progress_callback(3) # 3% progress after preview download
# Download model file with progress tracking
success, result = await self.civitai_client._download_file(
success, result = await civitai_client._download_file(
download_url,
save_dir,
os.path.basename(save_path),
@@ -232,13 +275,14 @@ class DownloadManager:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
# 6. Update cache based on model type
if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"):
cache = await self.file_monitor.checkpoint_scanner.get_cached_data()
if model_type == "checkpoint":
scanner = await self._get_checkpoint_scanner()
logger.info(f"Updating checkpoint cache for {save_path}")
else:
cache = await self.file_monitor.scanner.get_cached_data()
scanner = await self._get_lora_scanner()
logger.info(f"Updating lora cache for {save_path}")
cache = await scanner.get_cached_data()
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = relative_path
cache.raw_data.append(metadata_dict)
@@ -248,10 +292,7 @@ class DownloadManager:
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Update the hash index with the new model entry
if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"):
self.file_monitor.checkpoint_scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
else:
self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
# Report 100% completion
if progress_callback:

View File

@@ -1,39 +1,39 @@
from operator import itemgetter
import os
import logging
import asyncio
import time
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from typing import List, Dict, Set
from typing import List, Dict, Set, Optional
from threading import Lock
from .checkpoint_scanner import CheckpointScanner
from .lora_scanner import LoraScanner
from ..config import config
from .service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
class LoraFileHandler(FileSystemEventHandler):
"""Handler for LoRA file system events"""
class BaseFileHandler(FileSystemEventHandler):
"""Base handler for file system events"""
def __init__(self, scanner: LoraScanner, loop: asyncio.AbstractEventLoop):
self.scanner = scanner
self.loop = loop # 存储事件循环引用
self.pending_changes = set() # 待处理的变更
self.lock = Lock() # 线程安全锁
self.update_task = None # 异步更新任务
self._ignore_paths = set() # Add ignore paths set
self._min_ignore_timeout = 5 # minimum timeout in seconds
self._download_speed = 1024 * 1024 # assume 1MB/s as base speed
def __init__(self, loop: asyncio.AbstractEventLoop):
self.loop = loop # Store event loop reference
self.pending_changes = set() # Pending changes
self.lock = Lock() # Thread-safe lock
self.update_task = None # Async update task
self._ignore_paths = set() # Paths to ignore
self._min_ignore_timeout = 5 # Minimum timeout in seconds
self._download_speed = 1024 * 1024 # Assume 1MB/s as base speed
# Track modified files with timestamps for debouncing
self.modified_files: Dict[str, float] = {}
self.debounce_timer = None
self.debounce_delay = 3.0 # seconds to wait after last modification
self.debounce_delay = 3.0 # Seconds to wait after last modification
# Track files that are already scheduled for processing
# Track files already scheduled for processing
self.scheduled_files: Set[str] = set()
# File extensions to monitor - should be overridden by subclasses
self.file_extensions = set()
def _should_ignore(self, path: str) -> bool:
"""Check if path should be ignored"""
@@ -58,35 +58,33 @@ class LoraFileHandler(FileSystemEventHandler):
if event.is_directory:
return
# Handle safetensors files directly
if event.src_path.endswith('.safetensors'):
# Handle appropriate files based on extensions
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext in self.file_extensions:
if self._should_ignore(event.src_path):
return
# We'll process this file directly and ignore subsequent modifications
# to prevent duplicate processing
# Process this file directly and ignore subsequent modifications
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
if normalized_path not in self.scheduled_files:
logger.info(f"LoRA file created: {event.src_path}")
logger.info(f"File created: {event.src_path}")
self.scheduled_files.add(normalized_path)
self._schedule_update('add', event.src_path)
# Ignore modifications for a short period after creation
# This helps avoid duplicate processing
self.loop.call_later(
self.debounce_delay * 2,
self.scheduled_files.discard,
normalized_path
)
# For browser downloads, we'll catch them when they're renamed to .safetensors
def on_modified(self, event):
if event.is_directory:
return
# Only process safetensors files
if event.src_path.endswith('.safetensors'):
# Only process files with supported extensions
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext in self.file_extensions:
if self._should_ignore(event.src_path):
return
@@ -134,12 +132,17 @@ class LoraFileHandler(FileSystemEventHandler):
# Process stable files
for file_path in files_to_process:
logger.info(f"Processing modified LoRA file: {file_path}")
logger.info(f"Processing modified file: {file_path}")
self._schedule_update('add', file_path)
def on_deleted(self, event):
if event.is_directory or not event.src_path.endswith('.safetensors'):
if event.is_directory:
return
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext not in self.file_extensions:
return
if self._should_ignore(event.src_path):
return
@@ -147,14 +150,17 @@ class LoraFileHandler(FileSystemEventHandler):
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
self.scheduled_files.discard(normalized_path)
logger.info(f"LoRA file deleted: {event.src_path}")
logger.info(f"File deleted: {event.src_path}")
self._schedule_update('remove', event.src_path)
def on_moved(self, event):
"""Handle file move/rename events"""
# If destination is a safetensors file, treat it as a new file
if event.dest_path.endswith('.safetensors'):
src_ext = os.path.splitext(event.src_path)[1].lower()
dest_ext = os.path.splitext(event.dest_path)[1].lower()
# If destination has supported extension, treat as new file
if dest_ext in self.file_extensions:
if self._should_ignore(event.dest_path):
return
@@ -162,7 +168,7 @@ class LoraFileHandler(FileSystemEventHandler):
# Only process if not already scheduled
if normalized_path not in self.scheduled_files:
logger.info(f"LoRA file renamed/moved to: {event.dest_path}")
logger.info(f"File renamed/moved to: {event.dest_path}")
self.scheduled_files.add(normalized_path)
self._schedule_update('add', event.dest_path)
@@ -173,21 +179,21 @@ class LoraFileHandler(FileSystemEventHandler):
normalized_path
)
# If source was a safetensors file, treat it as deleted
if event.src_path.endswith('.safetensors'):
# If source was a supported file, treat it as deleted
if src_ext in self.file_extensions:
if self._should_ignore(event.src_path):
return
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
self.scheduled_files.discard(normalized_path)
logger.info(f"LoRA file moved/renamed from: {event.src_path}")
logger.info(f"File moved/renamed from: {event.src_path}")
self._schedule_update('remove', event.src_path)
def _schedule_update(self, action: str, file_path: str): #file_path is a real path
def _schedule_update(self, action: str, file_path: str):
"""Schedule a cache update"""
with self.lock:
# 使用 config 中的方法映射路径
# Use config method to map path
mapped_path = config.map_path_to_link(file_path)
normalized_path = mapped_path.replace(os.sep, '/')
self.pending_changes.add((action, normalized_path))
@@ -198,7 +204,20 @@ class LoraFileHandler(FileSystemEventHandler):
"""Create update task in the event loop"""
if self.update_task is None or self.update_task.done():
self.update_task = asyncio.create_task(self._process_changes())
async def _process_changes(self, delay: float = 2.0):
"""Process pending changes with debouncing - should be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement _process_changes")
class LoraFileHandler(BaseFileHandler):
"""Handler for LoRA file system events"""
def __init__(self, loop: asyncio.AbstractEventLoop):
super().__init__(loop)
# Set supported file extensions for LoRAs
self.file_extensions = {'.safetensors'}
async def _process_changes(self, delay: float = 2.0):
"""Process pending changes with debouncing"""
await asyncio.sleep(delay)
@@ -211,9 +230,11 @@ class LoraFileHandler(FileSystemEventHandler):
if not changes:
return
logger.info(f"Processing {len(changes)} file changes")
logger.info(f"Processing {len(changes)} LoRA file changes")
cache = await self.scanner.get_cached_data()
# Get scanner through ServiceRegistry
scanner = await ServiceRegistry.get_lora_scanner()
cache = await scanner.get_cached_data()
needs_resort = False
new_folders = set()
@@ -227,36 +248,36 @@ class LoraFileHandler(FileSystemEventHandler):
continue
# Scan new file
lora_data = await self.scanner.scan_single_lora(file_path)
if lora_data:
model_data = await scanner.scan_single_model(file_path)
if model_data:
# Update tags count
for tag in lora_data.get('tags', []):
self.scanner._tags_count[tag] = self.scanner._tags_count.get(tag, 0) + 1
for tag in model_data.get('tags', []):
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
cache.raw_data.append(lora_data)
new_folders.add(lora_data['folder'])
cache.raw_data.append(model_data)
new_folders.add(model_data['folder'])
# Update hash index
if 'sha256' in lora_data:
self.scanner._hash_index.add_entry(
lora_data['sha256'],
lora_data['file_path']
if 'sha256' in model_data:
scanner._hash_index.add_entry(
model_data['sha256'],
model_data['file_path']
)
needs_resort = True
elif action == 'remove':
# Find the lora to remove so we can update tags count
lora_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if lora_to_remove:
# Find the model to remove so we can update tags count
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if model_to_remove:
# Update tags count by reducing counts
for tag in lora_to_remove.get('tags', []):
if tag in self.scanner._tags_count:
self.scanner._tags_count[tag] = max(0, self.scanner._tags_count[tag] - 1)
if self.scanner._tags_count[tag] == 0:
del self.scanner._tags_count[tag]
for tag in model_to_remove.get('tags', []):
if tag in scanner._tags_count:
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
if scanner._tags_count[tag] == 0:
del scanner._tags_count[tag]
# Remove from cache and hash index
logger.info(f"Removing {file_path} from cache")
self.scanner._hash_index.remove_by_path(file_path)
scanner._hash_index.remove_by_path(file_path)
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != file_path
@@ -274,59 +295,140 @@ class LoraFileHandler(FileSystemEventHandler):
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
except Exception as e:
logger.error(f"Error in process_changes: {e}")
logger.error(f"Error in process_changes for LoRA: {e}")
class LoraFileMonitor:
"""Monitor for LoRA file changes"""
class CheckpointFileHandler(BaseFileHandler):
"""Handler for checkpoint file system events"""
def __init__(self, scanner: LoraScanner, roots: List[str]):
self.scanner = scanner
scanner.set_file_monitor(self)
def __init__(self, loop: asyncio.AbstractEventLoop):
super().__init__(loop)
# Set supported file extensions for checkpoints
self.file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
async def _process_changes(self, delay: float = 2.0):
"""Process pending changes with debouncing for checkpoint files"""
await asyncio.sleep(delay)
try:
with self.lock:
changes = self.pending_changes.copy()
self.pending_changes.clear()
if not changes:
return
logger.info(f"Processing {len(changes)} checkpoint file changes")
# Get scanner through ServiceRegistry
scanner = await ServiceRegistry.get_checkpoint_scanner()
cache = await scanner.get_cached_data()
needs_resort = False
new_folders = set()
for action, file_path in changes:
try:
if action == 'add':
# Check if file already exists in cache
existing = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if existing:
logger.info(f"File {file_path} already in cache, skipping")
continue
# Scan new file
model_data = await scanner.scan_single_model(file_path)
if model_data:
# Update tags count if applicable
for tag in model_data.get('tags', []):
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
cache.raw_data.append(model_data)
new_folders.add(model_data['folder'])
# Update hash index
if 'sha256' in model_data:
scanner._hash_index.add_entry(
model_data['sha256'],
model_data['file_path']
)
needs_resort = True
elif action == 'remove':
# Find the model to remove so we can update tags count
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if model_to_remove:
# Update tags count by reducing counts
for tag in model_to_remove.get('tags', []):
if tag in scanner._tags_count:
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
if scanner._tags_count[tag] == 0:
del scanner._tags_count[tag]
# Remove from cache and hash index
logger.info(f"Removing {file_path} from checkpoint cache")
scanner._hash_index.remove_by_path(file_path)
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != file_path
]
needs_resort = True
except Exception as e:
logger.error(f"Error processing checkpoint {action} for {file_path}: {e}")
if needs_resort:
await cache.resort()
# Update folder list
all_folders = set(cache.folders) | new_folders
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
except Exception as e:
logger.error(f"Error in process_changes for checkpoint: {e}")
class BaseFileMonitor:
"""Base class for file monitoring"""
def __init__(self, monitor_paths: List[str]):
self.observer = Observer()
self.loop = asyncio.get_event_loop()
self.handler = LoraFileHandler(scanner, self.loop)
# 使用已存在的路径映射
self.monitor_paths = set()
for root in roots:
self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/'))
# Process monitor paths
for path in monitor_paths:
self.monitor_paths.add(os.path.realpath(path).replace(os.sep, '/'))
# 添加所有已映射的目标路径
# Add mapped paths from config
for target_path in config._path_mappings.keys():
self.monitor_paths.add(target_path)
def start(self):
"""Start monitoring"""
for path_info in self.monitor_paths:
"""Start file monitoring"""
for path in self.monitor_paths:
try:
if isinstance(path_info, tuple):
# 对于链接,监控目标路径
_, target_path = path_info
self.observer.schedule(self.handler, target_path, recursive=True)
logger.info(f"Started monitoring target path: {target_path}")
else:
# 对于普通路径,直接监控
self.observer.schedule(self.handler, path_info, recursive=True)
logger.info(f"Started monitoring: {path_info}")
self.observer.schedule(self.handler, path, recursive=True)
logger.info(f"Started monitoring: {path}")
except Exception as e:
logger.error(f"Error monitoring {path_info}: {e}")
logger.error(f"Error monitoring {path}: {e}")
self.observer.start()
def stop(self):
"""Stop monitoring"""
"""Stop file monitoring"""
self.observer.stop()
self.observer.join()
def rescan_links(self):
"""重新扫描链接(当添加新的链接时调用)"""
"""Rescan links when new ones are added"""
# Find new paths not yet being monitored
new_paths = set()
for path in self.monitor_paths.copy():
self._add_link_targets(path)
for path in config._path_mappings.keys():
real_path = os.path.realpath(path).replace(os.sep, '/')
if real_path not in self.monitor_paths:
new_paths.add(real_path)
self.monitor_paths.add(real_path)
# 添加新发现的路径到监控
new_paths = self.monitor_paths - set(self.observer.watches.keys())
# Add new paths to monitoring
for path in new_paths:
try:
self.observer.schedule(self.handler, path, recursive=True)
@@ -334,88 +436,86 @@ class LoraFileMonitor:
except Exception as e:
logger.error(f"Error adding new monitor for {path}: {e}")
# Add CheckpointFileMonitor class
class CheckpointFileMonitor(LoraFileMonitor):
class LoraFileMonitor(BaseFileMonitor):
"""Monitor for LoRA file changes"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls, monitor_paths=None):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, monitor_paths=None):
if not hasattr(self, '_initialized'):
if monitor_paths is None:
from ..config import config
monitor_paths = config.loras_roots
super().__init__(monitor_paths)
self.handler = LoraFileHandler(self.loop)
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
from ..config import config
cls._instance = cls(config.loras_roots)
return cls._instance
class CheckpointFileMonitor(BaseFileMonitor):
"""Monitor for checkpoint file changes"""
def __init__(self, scanner: CheckpointScanner, roots: List[str]):
# Reuse most of the LoraFileMonitor functionality, but with a different handler
self.scanner = scanner
scanner.set_file_monitor(self)
self.observer = Observer()
self.loop = asyncio.get_event_loop()
self.handler = CheckpointFileHandler(scanner, self.loop)
# Use existing path mappings
self.monitor_paths = set()
for root in roots:
self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/'))
# Add all mapped target paths
for target_path in config._path_mappings.keys():
self.monitor_paths.add(target_path)
class CheckpointFileHandler(LoraFileHandler):
"""Handler for checkpoint file system events"""
_instance = None
_lock = asyncio.Lock()
def __init__(self, scanner: CheckpointScanner, loop: asyncio.AbstractEventLoop):
super().__init__(scanner, loop)
# Configure supported file extensions
self.supported_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft'}
def on_created(self, event):
if event.is_directory:
return
# Handle supported file extensions directly
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext in self.supported_extensions:
if self._should_ignore(event.src_path):
return
# Process this file directly
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
if normalized_path not in self.scheduled_files:
logger.info(f"Checkpoint file created: {event.src_path}")
self.scheduled_files.add(normalized_path)
self._schedule_update('add', event.src_path)
# Ignore modifications for a short period after creation
self.loop.call_later(
self.debounce_delay * 2,
self.scheduled_files.discard,
normalized_path
)
def __new__(cls, monitor_paths=None):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def on_modified(self, event):
if event.is_directory:
return
# Only process supported file types
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext in self.supported_extensions:
super().on_modified(event)
def __init__(self, monitor_paths=None):
if not hasattr(self, '_initialized'):
if monitor_paths is None:
# Get checkpoint roots from scanner
monitor_paths = []
# We'll initialize monitor paths later when scanner is available
def on_deleted(self, event):
if event.is_directory:
return
super().__init__(monitor_paths or [])
self.handler = CheckpointFileHandler(self.loop)
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls([])
# Now get checkpoint roots from scanner
from .checkpoint_scanner import CheckpointScanner
scanner = await CheckpointScanner.get_instance()
monitor_paths = scanner.get_model_roots()
# Update monitor paths
for path in monitor_paths:
real_path = os.path.realpath(path).replace(os.sep, '/')
cls._instance.monitor_paths.add(real_path)
return cls._instance
async def initialize_paths(self):
"""Initialize monitor paths from scanner"""
if not self.monitor_paths:
scanner = await ServiceRegistry.get_checkpoint_scanner()
monitor_paths = scanner.get_model_roots()
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext not in self.supported_extensions:
return
super().on_deleted(event)
def on_moved(self, event):
"""Handle file move/rename events"""
src_ext = os.path.splitext(event.src_path)[1].lower()
dest_ext = os.path.splitext(event.dest_path)[1].lower()
# If destination has supported extension, treat as new file
if dest_ext in self.supported_extensions:
super().on_moved(event)
# If source was supported extension, treat as deleted
elif src_ext in self.supported_extensions:
super().on_moved(event)
# Update monitor paths
for path in monitor_paths:
real_path = os.path.realpath(path).replace(os.sep, '/')
self.monitor_paths.add(real_path)

View File

@@ -13,6 +13,7 @@ from .lora_hash_index import LoraHashIndex
from .settings_manager import settings
from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match
from .service_registry import ServiceRegistry
import sys
logger = logging.getLogger(__name__)

View File

@@ -12,13 +12,13 @@ from ..utils.file_utils import load_metadata, get_file_info, find_preview_file,
from .model_cache import ModelCache
from .model_hash_index import ModelHashIndex
from ..utils.constants import PREVIEW_EXTENSIONS
from .service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
class ModelScanner:
"""Base service for scanning and managing model files"""
_instance = None
_lock = asyncio.Lock()
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
@@ -35,14 +35,17 @@ class ModelScanner:
self.file_extensions = file_extensions
self._cache = None
self._hash_index = hash_index or ModelHashIndex()
self.file_monitor = None
self._tags_count = {} # Dictionary to store tag counts
self._is_initializing = False # Flag to track initialization state
# Register this service
asyncio.create_task(self._register_service())
async def _register_service(self):
"""Register this instance with the ServiceRegistry"""
service_name = f"{self.model_type}_scanner"
await ServiceRegistry.register_service(service_name, self)
def set_file_monitor(self, monitor):
"""Set file monitor instance"""
self.file_monitor = monitor
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
try:
@@ -366,12 +369,20 @@ class ModelScanner:
file_size = os.path.getsize(real_source)
if self.file_monitor:
self.file_monitor.handler.add_ignore_path(
# Get the appropriate file monitor through ServiceRegistry
if self.model_type == "lora":
monitor = await ServiceRegistry.get_lora_monitor()
elif self.model_type == "checkpoint":
monitor = await ServiceRegistry.get_checkpoint_monitor()
else:
monitor = None
if monitor:
monitor.handler.add_ignore_path(
real_source,
file_size
)
self.file_monitor.handler.add_ignore_path(
monitor.handler.add_ignore_path(
real_target,
file_size
)

View File

@@ -5,8 +5,8 @@ import json
from typing import List, Dict, Optional, Any, Tuple
from ..config import config
from .recipe_cache import RecipeCache
from .service_registry import ServiceRegistry
from .lora_scanner import LoraScanner
from .civitai_client import CivitaiClient
from ..utils.utils import fuzzy_match
import sys
@@ -18,11 +18,22 @@ class RecipeScanner:
_instance = None
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls, lora_scanner: Optional[LoraScanner] = None):
"""Get singleton instance of RecipeScanner"""
async with cls._lock:
if cls._instance is None:
if not lora_scanner:
# Get lora scanner from service registry if not provided
lora_scanner = await ServiceRegistry.get_lora_scanner()
cls._instance = cls(lora_scanner)
return cls._instance
def __new__(cls, lora_scanner: Optional[LoraScanner] = None):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._lora_scanner = lora_scanner
cls._instance._civitai_client = CivitaiClient()
cls._instance._civitai_client = None # Will be lazily initialized
return cls._instance
def __init__(self, lora_scanner: Optional[LoraScanner] = None):
@@ -36,6 +47,12 @@ class RecipeScanner:
self._lora_scanner = lora_scanner
self._initialized = True
async def _get_civitai_client(self):
"""Lazily initialize CivitaiClient from registry"""
if self._civitai_client is None:
self._civitai_client = await ServiceRegistry.get_civitai_client()
return self._civitai_client
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
try:
@@ -306,10 +323,13 @@ class RecipeScanner:
async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]:
"""Get hash from Civitai API"""
try:
if not self._civitai_client:
# Get CivitaiClient from ServiceRegistry
civitai_client = await self._get_civitai_client()
if not civitai_client:
logger.error("Failed to get CivitaiClient from ServiceRegistry")
return None
version_info = await self._civitai_client.get_model_version_info(model_version_id)
version_info = await civitai_client.get_model_version_info(model_version_id)
if not version_info or not version_info.get('files'):
logger.debug(f"No files found in version info for ID: {model_version_id}")
@@ -329,10 +349,12 @@ class RecipeScanner:
async def _get_model_version_name(self, model_version_id: str) -> Optional[str]:
"""Get model version name from Civitai API"""
try:
if not self._civitai_client:
# Get CivitaiClient from ServiceRegistry
civitai_client = await self._get_civitai_client()
if not civitai_client:
return None
version_info = await self._civitai_client.get_model_version_info(model_version_id)
version_info = await civitai_client.get_model_version_info(model_version_id)
if version_info and 'name' in version_info:
return version_info['name']

View File

@@ -0,0 +1,124 @@
import asyncio
import logging
from typing import Optional, Dict, Any, TypeVar, Type
logger = logging.getLogger(__name__)
T = TypeVar('T') # Define a type variable for service types
class ServiceRegistry:
"""Centralized registry for service singletons"""
_instance = None
_services: Dict[str, Any] = {}
_lock = asyncio.Lock()
@classmethod
def get_instance(cls):
"""Get singleton instance of the registry"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
async def register_service(cls, service_name: str, service_instance: Any) -> None:
"""Register a service instance with the registry"""
registry = cls.get_instance()
async with cls._lock:
registry._services[service_name] = service_instance
logger.debug(f"Registered service: {service_name}")
@classmethod
async def get_service(cls, service_name: str) -> Any:
"""Get a service instance by name"""
registry = cls.get_instance()
async with cls._lock:
if service_name not in registry._services:
logger.warning(f"Service {service_name} not found in registry")
return None
return registry._services[service_name]
# Convenience methods for common services
@classmethod
async def get_lora_scanner(cls):
"""Get the LoraScanner instance"""
from .lora_scanner import LoraScanner
scanner = await cls.get_service("lora_scanner")
if scanner is None:
scanner = await LoraScanner.get_instance()
await cls.register_service("lora_scanner", scanner)
return scanner
@classmethod
async def get_checkpoint_scanner(cls):
"""Get the CheckpointScanner instance"""
from .checkpoint_scanner import CheckpointScanner
scanner = await cls.get_service("checkpoint_scanner")
if scanner is None:
scanner = await CheckpointScanner.get_instance()
await cls.register_service("checkpoint_scanner", scanner)
return scanner
@classmethod
async def get_lora_monitor(cls):
"""Get the LoraFileMonitor instance"""
from .file_monitor import LoraFileMonitor
monitor = await cls.get_service("lora_monitor")
if monitor is None:
monitor = await LoraFileMonitor.get_instance()
await cls.register_service("lora_monitor", monitor)
return monitor
@classmethod
async def get_checkpoint_monitor(cls):
"""Get the CheckpointFileMonitor instance"""
from .file_monitor import CheckpointFileMonitor
monitor = await cls.get_service("checkpoint_monitor")
if monitor is None:
monitor = await CheckpointFileMonitor.get_instance()
await cls.register_service("checkpoint_monitor", monitor)
return monitor
@classmethod
async def get_civitai_client(cls):
"""Get the CivitaiClient instance"""
from .civitai_client import CivitaiClient
client = await cls.get_service("civitai_client")
if client is None:
client = await CivitaiClient.get_instance()
await cls.register_service("civitai_client", client)
return client
@classmethod
async def get_download_manager(cls):
"""Get the DownloadManager instance"""
from .download_manager import DownloadManager
manager = await cls.get_service("download_manager")
if manager is None:
# We'll let DownloadManager.get_instance handle file_monitor parameter
manager = await DownloadManager.get_instance()
await cls.register_service("download_manager", manager)
return manager
@classmethod
async def get_recipe_scanner(cls):
"""Get the RecipeScanner instance"""
from .recipe_scanner import RecipeScanner
scanner = await cls.get_service("recipe_scanner")
if scanner is None:
lora_scanner = await cls.get_lora_scanner()
scanner = RecipeScanner(lora_scanner)
await cls.register_service("recipe_scanner", scanner)
return scanner
@classmethod
async def get_websocket_manager(cls):
"""Get the WebSocketManager instance"""
from .websocket_manager import ws_manager
manager = await cls.get_service("websocket_manager")
if manager is None:
# ws_manager is already a global instance in websocket_manager.py
from .websocket_manager import ws_manager
await cls.register_service("websocket_manager", ws_manager)
manager = ws_manager
return manager