mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
checkpoint
This commit is contained in:
@@ -19,22 +19,33 @@ from .update_routes import UpdateRoutes
|
||||
from ..services.recipe_scanner import RecipeScanner
|
||||
from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ApiRoutes:
|
||||
"""API route handlers for LoRA management"""
|
||||
|
||||
def __init__(self, file_monitor: LoraFileMonitor):
|
||||
self.scanner = LoraScanner()
|
||||
self.civitai_client = CivitaiClient()
|
||||
self.download_manager = DownloadManager(file_monitor)
|
||||
def __init__(self):
|
||||
self.scanner = None # Will be initialized in setup_routes
|
||||
self.civitai_client = None # Will be initialized in setup_routes
|
||||
self.download_manager = None # Will be initialized in setup_routes
|
||||
self._download_lock = asyncio.Lock()
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
@classmethod
|
||||
def setup_routes(cls, app: web.Application, monitor: LoraFileMonitor):
|
||||
def setup_routes(cls, app: web.Application):
|
||||
"""Register API routes"""
|
||||
routes = cls(monitor)
|
||||
routes = cls()
|
||||
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: routes.initialize_services())
|
||||
|
||||
app.router.add_post('/api/delete_model', routes.delete_model)
|
||||
app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
|
||||
app.router.add_post('/api/replace_preview', routes.replace_preview)
|
||||
@@ -63,19 +74,28 @@ class ApiRoutes:
|
||||
|
||||
async def delete_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model deletion request"""
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
|
||||
|
||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata fetch request"""
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
|
||||
|
||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||
"""Handle preview image replacement request"""
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
|
||||
|
||||
async def get_loras(self, request: web.Request) -> web.Response:
|
||||
"""Handle paginated LoRA data request"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Parse query parameters
|
||||
page = int(request.query.get('page', '1'))
|
||||
page_size = int(request.query.get('page_size', '20'))
|
||||
@@ -231,6 +251,9 @@ class ApiRoutes:
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Fetch CivitAI metadata for all loras in the background"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
cache = await self.scanner.get_cached_data()
|
||||
total = len(cache.raw_data)
|
||||
processed = 0
|
||||
@@ -312,6 +335,9 @@ class ApiRoutes:
|
||||
|
||||
async def get_folders(self, request: web.Request) -> web.Response:
|
||||
"""Get all folders in the cache"""
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
cache = await self.scanner.get_cached_data()
|
||||
return web.json_response({
|
||||
'folders': cache.folders
|
||||
@@ -320,6 +346,12 @@ class ApiRoutes:
|
||||
async def get_civitai_versions(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai model with local availability info"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
if self.civitai_client is None:
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
model_id = request.match_info['model_id']
|
||||
versions = await self.civitai_client.get_model_versions(model_id)
|
||||
if not versions:
|
||||
@@ -353,9 +385,12 @@ class ApiRoutes:
|
||||
async def get_civitai_model(self, request: web.Request) -> web.Response:
|
||||
"""Get CivitAI model details by model version ID or hash"""
|
||||
try:
|
||||
model_version_id = request.match_info['modelVersionId']
|
||||
if self.civitai_client is None:
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
model_version_id = request.match_info.get('modelVersionId')
|
||||
if not model_version_id:
|
||||
hash = request.match_info['hash']
|
||||
hash = request.match_info.get('hash')
|
||||
model = await self.civitai_client.get_model_by_hash(hash)
|
||||
return web.json_response(model)
|
||||
|
||||
@@ -370,6 +405,9 @@ class ApiRoutes:
|
||||
async def download_lora(self, request: web.Request) -> web.Response:
|
||||
async with self._download_lock:
|
||||
try:
|
||||
if self.download_manager is None:
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
data = await request.json()
|
||||
|
||||
# Create progress callback
|
||||
@@ -447,6 +485,9 @@ class ApiRoutes:
|
||||
async def move_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model move request"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path') # full path of the model file, e.g. /path/to/model.safetensors
|
||||
target_path = data.get('target_path') # folder path to move the model to, e.g. /path/to/target_folder
|
||||
@@ -485,12 +526,17 @@ class ApiRoutes:
|
||||
@classmethod
|
||||
async def cleanup(cls):
|
||||
"""Add cleanup method for application shutdown"""
|
||||
if hasattr(cls, '_instance'):
|
||||
await cls._instance.civitai_client.close()
|
||||
# Now we don't need to store an instance, as services are managed by ServiceRegistry
|
||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
if civitai_client:
|
||||
await civitai_client.close()
|
||||
|
||||
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||
"""Handle saving metadata updates"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
if not file_path:
|
||||
@@ -536,6 +582,9 @@ class ApiRoutes:
|
||||
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the static preview URL for a LoRA file"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Get lora file name from query parameters
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
@@ -574,6 +623,9 @@ class ApiRoutes:
|
||||
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the Civitai URL for a LoRA file"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Get lora file name from query parameters
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
@@ -619,6 +671,9 @@ class ApiRoutes:
|
||||
async def move_models_bulk(self, request: web.Request) -> web.Response:
|
||||
"""Handle bulk model move request"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
data = await request.json()
|
||||
file_paths = data.get('file_paths', []) # list of full paths of the model files, e.g. ["/path/to/model1.safetensors", "/path/to/model2.safetensors"]
|
||||
target_path = data.get('target_path') # folder path to move the models to, e.g. "/path/to/target_folder"
|
||||
@@ -677,6 +732,9 @@ class ApiRoutes:
|
||||
async def get_lora_model_description(self, request: web.Request) -> web.Response:
|
||||
"""Get model description for a Lora model"""
|
||||
try:
|
||||
if self.civitai_client is None:
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Get parameters
|
||||
model_id = request.query.get('model_id')
|
||||
file_path = request.query.get('file_path')
|
||||
@@ -736,6 +794,9 @@ class ApiRoutes:
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
"""Handle request for top tags sorted by frequency"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
@@ -761,6 +822,9 @@ class ApiRoutes:
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
"""Get base models used in loras"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
@@ -785,6 +849,12 @@ class ApiRoutes:
|
||||
async def rename_lora(self, request: web.Request) -> web.Response:
|
||||
"""Handle renaming a LoRA file and its associated files"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
if self.download_manager is None:
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
new_file_name = data.get('new_file_name')
|
||||
@@ -891,7 +961,7 @@ class ApiRoutes:
|
||||
|
||||
# Update recipe files and cache if hash is available
|
||||
if hash_value:
|
||||
recipe_scanner = RecipeScanner(self.scanner)
|
||||
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
recipes_updated, cache_updated = await recipe_scanner.update_lora_filename_by_hash(hash_value, new_file_name)
|
||||
logger.info(f"Updated {recipes_updated} recipe files and {cache_updated} cache entries for renamed LoRA")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user