checkpoint

This commit is contained in:
Will Miao
2025-04-11 20:22:12 +08:00
parent 1db49a4dd4
commit 0618541527
13 changed files with 793 additions and 276 deletions

View File

@@ -19,22 +19,33 @@ from .update_routes import UpdateRoutes
from ..services.recipe_scanner import RecipeScanner
from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
from ..utils.exif_utils import ExifUtils
from ..services.service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
class ApiRoutes:
"""API route handlers for LoRA management"""
def __init__(self, file_monitor: LoraFileMonitor):
self.scanner = LoraScanner()
self.civitai_client = CivitaiClient()
self.download_manager = DownloadManager(file_monitor)
def __init__(self):
self.scanner = None # Will be initialized in setup_routes
self.civitai_client = None # Will be initialized in setup_routes
self.download_manager = None # Will be initialized in setup_routes
self._download_lock = asyncio.Lock()
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
self.scanner = await ServiceRegistry.get_lora_scanner()
self.civitai_client = await ServiceRegistry.get_civitai_client()
self.download_manager = await ServiceRegistry.get_download_manager()
@classmethod
def setup_routes(cls, app: web.Application, monitor: LoraFileMonitor):
def setup_routes(cls, app: web.Application):
"""Register API routes"""
routes = cls(monitor)
routes = cls()
# Schedule service initialization on app startup
app.on_startup.append(lambda _: routes.initialize_services())
app.router.add_post('/api/delete_model', routes.delete_model)
app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
app.router.add_post('/api/replace_preview', routes.replace_preview)
@@ -63,19 +74,28 @@ class ApiRoutes:
async def delete_model(self, request: web.Request) -> web.Response:
"""Handle model deletion request"""
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
async def fetch_civitai(self, request: web.Request) -> web.Response:
"""Handle CivitAI metadata fetch request"""
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
async def replace_preview(self, request: web.Request) -> web.Response:
"""Handle preview image replacement request"""
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
async def get_loras(self, request: web.Request) -> web.Response:
"""Handle paginated LoRA data request"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Parse query parameters
page = int(request.query.get('page', '1'))
page_size = int(request.query.get('page_size', '20'))
@@ -231,6 +251,9 @@ class ApiRoutes:
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all loras in the background"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
cache = await self.scanner.get_cached_data()
total = len(cache.raw_data)
processed = 0
@@ -312,6 +335,9 @@ class ApiRoutes:
async def get_folders(self, request: web.Request) -> web.Response:
"""Get all folders in the cache"""
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
cache = await self.scanner.get_cached_data()
return web.json_response({
'folders': cache.folders
@@ -320,6 +346,12 @@ class ApiRoutes:
async def get_civitai_versions(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai model with local availability info"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
model_id = request.match_info['model_id']
versions = await self.civitai_client.get_model_versions(model_id)
if not versions:
@@ -353,9 +385,12 @@ class ApiRoutes:
async def get_civitai_model(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by model version ID or hash"""
try:
model_version_id = request.match_info['modelVersionId']
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
model_version_id = request.match_info.get('modelVersionId')
if not model_version_id:
hash = request.match_info['hash']
hash = request.match_info.get('hash')
model = await self.civitai_client.get_model_by_hash(hash)
return web.json_response(model)
@@ -370,6 +405,9 @@ class ApiRoutes:
async def download_lora(self, request: web.Request) -> web.Response:
async with self._download_lock:
try:
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
data = await request.json()
# Create progress callback
@@ -447,6 +485,9 @@ class ApiRoutes:
async def move_model(self, request: web.Request) -> web.Response:
"""Handle model move request"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
data = await request.json()
file_path = data.get('file_path') # full path of the model file, e.g. /path/to/model.safetensors
target_path = data.get('target_path') # folder path to move the model to, e.g. /path/to/target_folder
@@ -485,12 +526,17 @@ class ApiRoutes:
@classmethod
async def cleanup(cls):
"""Add cleanup method for application shutdown"""
if hasattr(cls, '_instance'):
await cls._instance.civitai_client.close()
# Now we don't need to store an instance, as services are managed by ServiceRegistry
civitai_client = await ServiceRegistry.get_civitai_client()
if civitai_client:
await civitai_client.close()
async def save_metadata(self, request: web.Request) -> web.Response:
"""Handle saving metadata updates"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
data = await request.json()
file_path = data.get('file_path')
if not file_path:
@@ -536,6 +582,9 @@ class ApiRoutes:
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
"""Get the static preview URL for a LoRA file"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Get lora file name from query parameters
lora_name = request.query.get('name')
if not lora_name:
@@ -574,6 +623,9 @@ class ApiRoutes:
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
"""Get the Civitai URL for a LoRA file"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Get lora file name from query parameters
lora_name = request.query.get('name')
if not lora_name:
@@ -619,6 +671,9 @@ class ApiRoutes:
async def move_models_bulk(self, request: web.Request) -> web.Response:
"""Handle bulk model move request"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
data = await request.json()
file_paths = data.get('file_paths', []) # list of full paths of the model files, e.g. ["/path/to/model1.safetensors", "/path/to/model2.safetensors"]
target_path = data.get('target_path') # folder path to move the models to, e.g. "/path/to/target_folder"
@@ -677,6 +732,9 @@ class ApiRoutes:
async def get_lora_model_description(self, request: web.Request) -> web.Response:
"""Get model description for a Lora model"""
try:
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
# Get parameters
model_id = request.query.get('model_id')
file_path = request.query.get('file_path')
@@ -736,6 +794,9 @@ class ApiRoutes:
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Handle request for top tags sorted by frequency"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Parse query parameters
limit = int(request.query.get('limit', '20'))
@@ -761,6 +822,9 @@ class ApiRoutes:
async def get_base_models(self, request: web.Request) -> web.Response:
"""Get base models used in loras"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Parse query parameters
limit = int(request.query.get('limit', '20'))
@@ -785,6 +849,12 @@ class ApiRoutes:
async def rename_lora(self, request: web.Request) -> web.Response:
"""Handle renaming a LoRA file and its associated files"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
data = await request.json()
file_path = data.get('file_path')
new_file_name = data.get('new_file_name')
@@ -891,7 +961,7 @@ class ApiRoutes:
# Update recipe files and cache if hash is available
if hash_value:
recipe_scanner = RecipeScanner(self.scanner)
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
recipes_updated, cache_updated = await recipe_scanner.update_lora_filename_by_hash(hash_value, new_file_name)
logger.info(f"Updated {recipes_updated} recipe files and {cache_updated} cache entries for renamed LoRA")