checkpoint

This commit is contained in:
Will Miao
2025-04-11 20:22:12 +08:00
parent 1db49a4dd4
commit 0618541527
13 changed files with 793 additions and 276 deletions

View File

@@ -19,22 +19,33 @@ from .update_routes import UpdateRoutes
from ..services.recipe_scanner import RecipeScanner
from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
from ..utils.exif_utils import ExifUtils
from ..services.service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
class ApiRoutes:
"""API route handlers for LoRA management"""
def __init__(self, file_monitor: LoraFileMonitor):
self.scanner = LoraScanner()
self.civitai_client = CivitaiClient()
self.download_manager = DownloadManager(file_monitor)
def __init__(self):
self.scanner = None # Will be initialized in setup_routes
self.civitai_client = None # Will be initialized in setup_routes
self.download_manager = None # Will be initialized in setup_routes
self._download_lock = asyncio.Lock()
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
self.scanner = await ServiceRegistry.get_lora_scanner()
self.civitai_client = await ServiceRegistry.get_civitai_client()
self.download_manager = await ServiceRegistry.get_download_manager()
@classmethod
def setup_routes(cls, app: web.Application, monitor: LoraFileMonitor):
def setup_routes(cls, app: web.Application):
"""Register API routes"""
routes = cls(monitor)
routes = cls()
# Schedule service initialization on app startup
app.on_startup.append(lambda _: routes.initialize_services())
app.router.add_post('/api/delete_model', routes.delete_model)
app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
app.router.add_post('/api/replace_preview', routes.replace_preview)
@@ -63,19 +74,28 @@ class ApiRoutes:
async def delete_model(self, request: web.Request) -> web.Response:
"""Handle model deletion request"""
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
async def fetch_civitai(self, request: web.Request) -> web.Response:
"""Handle CivitAI metadata fetch request"""
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
async def replace_preview(self, request: web.Request) -> web.Response:
"""Handle preview image replacement request"""
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
async def get_loras(self, request: web.Request) -> web.Response:
"""Handle paginated LoRA data request"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Parse query parameters
page = int(request.query.get('page', '1'))
page_size = int(request.query.get('page_size', '20'))
@@ -231,6 +251,9 @@ class ApiRoutes:
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all loras in the background"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
cache = await self.scanner.get_cached_data()
total = len(cache.raw_data)
processed = 0
@@ -312,6 +335,9 @@ class ApiRoutes:
async def get_folders(self, request: web.Request) -> web.Response:
"""Get all folders in the cache"""
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
cache = await self.scanner.get_cached_data()
return web.json_response({
'folders': cache.folders
@@ -320,6 +346,12 @@ class ApiRoutes:
async def get_civitai_versions(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai model with local availability info"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
model_id = request.match_info['model_id']
versions = await self.civitai_client.get_model_versions(model_id)
if not versions:
@@ -353,9 +385,12 @@ class ApiRoutes:
async def get_civitai_model(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by model version ID or hash"""
try:
model_version_id = request.match_info['modelVersionId']
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
model_version_id = request.match_info.get('modelVersionId')
if not model_version_id:
hash = request.match_info['hash']
hash = request.match_info.get('hash')
model = await self.civitai_client.get_model_by_hash(hash)
return web.json_response(model)
@@ -370,6 +405,9 @@ class ApiRoutes:
async def download_lora(self, request: web.Request) -> web.Response:
async with self._download_lock:
try:
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
data = await request.json()
# Create progress callback
@@ -447,6 +485,9 @@ class ApiRoutes:
async def move_model(self, request: web.Request) -> web.Response:
"""Handle model move request"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
data = await request.json()
file_path = data.get('file_path') # full path of the model file, e.g. /path/to/model.safetensors
target_path = data.get('target_path') # folder path to move the model to, e.g. /path/to/target_folder
@@ -485,12 +526,17 @@ class ApiRoutes:
@classmethod
async def cleanup(cls):
"""Add cleanup method for application shutdown"""
if hasattr(cls, '_instance'):
await cls._instance.civitai_client.close()
# Now we don't need to store an instance, as services are managed by ServiceRegistry
civitai_client = await ServiceRegistry.get_civitai_client()
if civitai_client:
await civitai_client.close()
async def save_metadata(self, request: web.Request) -> web.Response:
"""Handle saving metadata updates"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
data = await request.json()
file_path = data.get('file_path')
if not file_path:
@@ -536,6 +582,9 @@ class ApiRoutes:
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
"""Get the static preview URL for a LoRA file"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Get lora file name from query parameters
lora_name = request.query.get('name')
if not lora_name:
@@ -574,6 +623,9 @@ class ApiRoutes:
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
"""Get the Civitai URL for a LoRA file"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Get lora file name from query parameters
lora_name = request.query.get('name')
if not lora_name:
@@ -619,6 +671,9 @@ class ApiRoutes:
async def move_models_bulk(self, request: web.Request) -> web.Response:
"""Handle bulk model move request"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
data = await request.json()
file_paths = data.get('file_paths', []) # list of full paths of the model files, e.g. ["/path/to/model1.safetensors", "/path/to/model2.safetensors"]
target_path = data.get('target_path') # folder path to move the models to, e.g. "/path/to/target_folder"
@@ -677,6 +732,9 @@ class ApiRoutes:
async def get_lora_model_description(self, request: web.Request) -> web.Response:
"""Get model description for a Lora model"""
try:
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
# Get parameters
model_id = request.query.get('model_id')
file_path = request.query.get('file_path')
@@ -736,6 +794,9 @@ class ApiRoutes:
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Handle request for top tags sorted by frequency"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Parse query parameters
limit = int(request.query.get('limit', '20'))
@@ -761,6 +822,9 @@ class ApiRoutes:
async def get_base_models(self, request: web.Request) -> web.Response:
"""Get base models used in loras"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Parse query parameters
limit = int(request.query.get('limit', '20'))
@@ -785,6 +849,12 @@ class ApiRoutes:
async def rename_lora(self, request: web.Request) -> web.Response:
"""Handle renaming a LoRA file and its associated files"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
data = await request.json()
file_path = data.get('file_path')
new_file_name = data.get('new_file_name')
@@ -891,7 +961,7 @@ class ApiRoutes:
# Update recipe files and cache if hash is available
if hash_value:
recipe_scanner = RecipeScanner(self.scanner)
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
recipes_updated, cache_updated = await recipe_scanner.update_lora_filename_by_hash(hash_value, new_file_name)
logger.info(f"Updated {recipes_updated} recipe files and {cache_updated} cache entries for renamed LoRA")

View File

@@ -11,6 +11,7 @@ from ..services.civitai_client import CivitaiClient
from ..services.websocket_manager import ws_manager
from ..services.checkpoint_scanner import CheckpointScanner
from ..services.download_manager import DownloadManager
from ..services.service_registry import ServiceRegistry
from ..config import config
from ..services.settings_manager import settings
from ..utils.utils import fuzzy_match
@@ -21,16 +22,24 @@ class CheckpointsRoutes:
"""API routes for checkpoint management"""
def __init__(self):
self.scanner = CheckpointScanner()
self.scanner = None # Will be initialized in setup_routes
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
self.download_manager = DownloadManager()
self.download_manager = None # Will be initialized in setup_routes
self._download_lock = asyncio.Lock()
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
self.download_manager = await ServiceRegistry.get_download_manager()
def setup_routes(self, app):
"""Register routes with the aiohttp app"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
app.router.add_get('/api/checkpoints', self.get_checkpoints)
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
@@ -488,10 +497,9 @@ class CheckpointsRoutes:
async def download_checkpoint(self, request: web.Request) -> web.Response:
"""Handle checkpoint download request"""
async with self._download_lock:
# Initialize DownloadManager with the file monitor if the scanner has one
if not hasattr(self, 'download_manager') or self.download_manager is None:
file_monitor = getattr(self.scanner, 'file_monitor', None)
self.download_manager = DownloadManager(file_monitor)
# Get the download manager from service registry if not already initialized
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
# Use the common download handler with model_type="checkpoint"
return await ModelRouteUtils.handle_download_model(
@@ -503,6 +511,9 @@ class CheckpointsRoutes:
async def get_checkpoint_roots(self, request):
"""Return the checkpoint root directories"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
roots = self.scanner.get_model_roots()
return web.json_response({
"success": True,

View File

@@ -6,7 +6,8 @@ import logging
from ..services.lora_scanner import LoraScanner
from ..services.recipe_scanner import RecipeScanner
from ..config import config
from ..services.settings_manager import settings # Add this import
from ..services.settings_manager import settings
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
logger = logging.getLogger(__name__)
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
@@ -15,13 +16,24 @@ class LoraRoutes:
"""Route handlers for LoRA management endpoints"""
def __init__(self):
self.scanner = LoraScanner()
self.recipe_scanner = RecipeScanner(self.scanner)
# Initialize service references as None, will be set during async init
self.scanner = None
self.recipe_scanner = None
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
async def init_services(self):
"""Initialize services from ServiceRegistry"""
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
logger.info("LoraRoutes: Retrieved LoraScanner from ServiceRegistry")
if self.recipe_scanner is None:
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
logger.info("LoraRoutes: Retrieved RecipeScanner from ServiceRegistry")
def format_lora_data(self, lora: Dict) -> Dict:
"""Format LoRA data for template rendering"""
return {
@@ -58,7 +70,10 @@ class LoraRoutes:
async def handle_loras_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras request"""
try:
# 检查缓存初始化状态根据initialize_in_background的工作方式调整判断逻辑
# Ensure services are initialized
await self.init_services()
# Check if the LoraScanner is initializing
is_initializing = (
self.scanner._cache is None or
len(self.scanner._cache.raw_data) == 0 or
@@ -66,30 +81,29 @@ class LoraRoutes:
)
if is_initializing:
# 如果正在初始化,返回一个只包含加载提示的页面
# If still initializing, return loading page
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=[], # 空文件夹列表
is_initializing=True, # 新增标志
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Loras page is initializing, returning loading page")
else:
# 正常流程 - 获取已经初始化好的缓存数据
# Normal flow - get data from initialized cache
try:
cache = await self.scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=cache.folders,
is_initializing=False,
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
settings=settings,
request=request
)
except Exception as cache_error:
logger.error(f"Error loading cache data: {cache_error}")
# 如果获取缓存失败,也显示初始化页面
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=[],
@@ -114,7 +128,10 @@ class LoraRoutes:
async def handle_recipes_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras/recipes request"""
try:
# 检查缓存初始化状态与handle_loras_page保持一致的逻辑
# Ensure services are initialized
await self.init_services()
# Check if the RecipeScanner is initializing
is_initializing = (
self.recipe_scanner._cache is None or
len(self.recipe_scanner._cache.raw_data) == 0 or
@@ -183,5 +200,13 @@ class LoraRoutes:
def setup_routes(self, app: web.Application):
"""Register routes with the application"""
# Add an app startup handler to initialize services
app.on_startup.append(self._on_startup)
# Register routes
app.router.add_get('/loras', self.handle_loras_page)
app.router.add_get('/loras/recipes', self.handle_recipes_page)
async def _on_startup(self, app):
"""Initialize services when the app starts"""
await self.init_services()

View File

@@ -16,6 +16,7 @@ from ..services.lora_scanner import LoraScanner
from ..config import config
from ..workflow.parser import WorkflowParser
from ..utils.utils import download_civitai_image
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
logger = logging.getLogger(__name__)
@@ -23,13 +24,24 @@ class RecipeRoutes:
"""API route handlers for Recipe management"""
def __init__(self):
self.recipe_scanner = RecipeScanner(LoraScanner())
self.civitai_client = CivitaiClient()
# Initialize service references as None, will be set during async init
self.recipe_scanner = None
self.civitai_client = None
self.parser = WorkflowParser()
# Pre-warm the cache
self._init_cache_task = None
async def init_services(self):
"""Initialize services from ServiceRegistry"""
if self.recipe_scanner is None:
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
logger.info("RecipeRoutes: Retrieved RecipeScanner from ServiceRegistry")
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
logger.info("RecipeRoutes: Retrieved CivitaiClient from ServiceRegistry")
@classmethod
def setup_routes(cls, app: web.Application):
"""Register API routes"""
@@ -68,7 +80,10 @@ class RecipeRoutes:
async def _init_cache(self, app):
"""Initialize cache on startup"""
try:
# First, ensure the lora scanner is fully initialized
# Initialize services first
await self.init_services()
# Now that services are initialized, get the lora scanner
lora_scanner = self.recipe_scanner._lora_scanner
# Get lora cache to ensure it's initialized
@@ -86,6 +101,9 @@ class RecipeRoutes:
async def get_recipes(self, request: web.Request) -> web.Response:
"""API endpoint for getting paginated recipes"""
try:
# Ensure services are initialized
await self.init_services()
# Get query parameters with defaults
page = int(request.query.get('page', '1'))
page_size = int(request.query.get('page_size', '20'))
@@ -155,6 +173,9 @@ class RecipeRoutes:
async def get_recipe_detail(self, request: web.Request) -> web.Response:
"""Get detailed information about a specific recipe"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
# Use the new get_recipe_by_id method from recipe_scanner
@@ -208,6 +229,9 @@ class RecipeRoutes:
"""Analyze an uploaded image or URL for recipe metadata"""
temp_path = None
try:
# Ensure services are initialized
await self.init_services()
# Check if request contains multipart data (image) or JSON data (url)
content_type = request.headers.get('Content-Type', '')
@@ -326,6 +350,9 @@ class RecipeRoutes:
async def save_recipe(self, request: web.Request) -> web.Response:
"""Save a recipe to the recipes folder"""
try:
# Ensure services are initialized
await self.init_services()
reader = await request.multipart()
# Process form data
@@ -527,6 +554,9 @@ class RecipeRoutes:
async def delete_recipe(self, request: web.Request) -> web.Response:
"""Delete a recipe by ID"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
# Get recipes directory
@@ -574,6 +604,9 @@ class RecipeRoutes:
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Get top tags used in recipes"""
try:
# Ensure services are initialized
await self.init_services()
# Get limit parameter with default
limit = int(request.query.get('limit', '20'))
@@ -606,6 +639,9 @@ class RecipeRoutes:
async def get_base_models(self, request: web.Request) -> web.Response:
"""Get base models used in recipes"""
try:
# Ensure services are initialized
await self.init_services()
# Get all recipes from cache
cache = await self.recipe_scanner.get_cached_data()
@@ -634,6 +670,9 @@ class RecipeRoutes:
async def share_recipe(self, request: web.Request) -> web.Response:
"""Process a recipe image for sharing by adding metadata to EXIF"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
# Get all recipes from cache
@@ -693,6 +732,9 @@ class RecipeRoutes:
async def download_shared_recipe(self, request: web.Request) -> web.Response:
"""Serve a processed recipe image for download"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
# Check if we have this shared recipe
@@ -749,6 +791,9 @@ class RecipeRoutes:
async def save_recipe_from_widget(self, request: web.Request) -> web.Response:
"""Save a recipe from the LoRAs widget"""
try:
# Ensure services are initialized
await self.init_services()
reader = await request.multipart()
# Process form data
@@ -923,6 +968,9 @@ class RecipeRoutes:
async def get_recipe_syntax(self, request: web.Request) -> web.Response:
"""Generate recipe syntax for LoRAs in the recipe, looking up proper file names using hash_index"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
# Get all recipes from cache
@@ -1003,6 +1051,9 @@ class RecipeRoutes:
async def update_recipe(self, request: web.Request) -> web.Response:
"""Update recipe metadata (name and tags)"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
data = await request.json()
@@ -1030,6 +1081,9 @@ class RecipeRoutes:
async def reconnect_lora(self, request: web.Request) -> web.Response:
"""Reconnect a deleted LoRA in a recipe to a local LoRA file"""
try:
# Ensure services are initialized
await self.init_services()
# Parse request data
data = await request.json()
@@ -1140,6 +1194,9 @@ class RecipeRoutes:
async def get_recipes_for_lora(self, request: web.Request) -> web.Response:
"""Get recipes that use a specific Lora"""
try:
# Ensure services are initialized
await self.init_services()
lora_hash = request.query.get('hash')
# Hash is required