mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-26 07:35:44 -03:00
refactor: Enhance checkpoint download functionality with new modal and manager integration
This commit is contained in:
@@ -3,12 +3,14 @@ import json
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..services.civitai_client import CivitaiClient
|
||||
from ..services.websocket_manager import ws_manager
|
||||
from ..services.checkpoint_scanner import CheckpointScanner
|
||||
from ..services.download_manager import DownloadManager
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..utils.utils import fuzzy_match
|
||||
@@ -24,6 +26,8 @@ class CheckpointsRoutes:
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
self.download_manager = DownloadManager()
|
||||
self._download_lock = asyncio.Lock()
|
||||
|
||||
def setup_routes(self, app):
|
||||
"""Register routes with the aiohttp app"""
|
||||
@@ -34,11 +38,13 @@ class CheckpointsRoutes:
|
||||
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
|
||||
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
|
||||
app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info)
|
||||
app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots)
|
||||
|
||||
# Add new routes for model management similar to LoRA routes
|
||||
app.router.add_post('/api/checkpoints/delete', self.delete_model)
|
||||
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
|
||||
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
|
||||
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
|
||||
|
||||
async def get_checkpoints(self, request):
|
||||
"""Get paginated checkpoint data"""
|
||||
@@ -478,3 +484,33 @@ class CheckpointsRoutes:
|
||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||
"""Handle preview image replacement for checkpoints"""
|
||||
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
|
||||
|
||||
async def download_checkpoint(self, request: web.Request) -> web.Response:
|
||||
"""Handle checkpoint download request"""
|
||||
async with self._download_lock:
|
||||
# Initialize DownloadManager with the file monitor if the scanner has one
|
||||
if not hasattr(self, 'download_manager') or self.download_manager is None:
|
||||
file_monitor = getattr(self.scanner, 'file_monitor', None)
|
||||
self.download_manager = DownloadManager(file_monitor)
|
||||
|
||||
# Use the common download handler with model_type="checkpoint"
|
||||
return await ModelRouteUtils.handle_download_model(
|
||||
request=request,
|
||||
download_manager=self.download_manager,
|
||||
model_type="checkpoint"
|
||||
)
|
||||
|
||||
async def get_checkpoint_roots(self, request):
|
||||
"""Return the checkpoint root directories"""
|
||||
try:
|
||||
roots = self.scanner.get_model_roots()
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
from typing import Optional, Dict
|
||||
from .civitai_client import CivitaiClient
|
||||
from .file_monitor import LoraFileMonitor
|
||||
from ..utils.models import LoraMetadata
|
||||
from ..utils.models import LoraMetadata, CheckpointMetadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
|
||||
@@ -20,7 +20,22 @@ class DownloadManager:
|
||||
|
||||
async def download_from_civitai(self, download_url: str = None, model_hash: str = None,
|
||||
model_version_id: str = None, save_dir: str = None,
|
||||
relative_path: str = '', progress_callback=None) -> Dict:
|
||||
relative_path: str = '', progress_callback=None,
|
||||
model_type: str = "lora") -> Dict:
|
||||
"""Download model from Civitai
|
||||
|
||||
Args:
|
||||
download_url: Direct download URL for the model
|
||||
model_hash: SHA256 hash of the model
|
||||
model_version_id: Civitai model version ID
|
||||
save_dir: Directory to save the model to
|
||||
relative_path: Relative path within save_dir
|
||||
progress_callback: Callback function for progress updates
|
||||
model_type: Type of model ('lora' or 'checkpoint')
|
||||
|
||||
Returns:
|
||||
Dict with download result
|
||||
"""
|
||||
try:
|
||||
# Update save directory with relative path if provided
|
||||
if relative_path:
|
||||
@@ -46,7 +61,7 @@ class DownloadManager:
|
||||
if not version_info:
|
||||
return {'success': False, 'error': 'Failed to fetch model metadata'}
|
||||
|
||||
# Check if this is an early access LoRA
|
||||
# Check if this is an early access model
|
||||
if version_info.get('earlyAccessEndsAt'):
|
||||
early_access_date = version_info.get('earlyAccessEndsAt', '')
|
||||
# Convert to a readable date if possible
|
||||
@@ -54,12 +69,12 @@ class DownloadManager:
|
||||
from datetime import datetime
|
||||
date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00'))
|
||||
formatted_date = date_obj.strftime('%Y-%m-%d')
|
||||
early_access_msg = f"This LoRA requires early access payment (until {formatted_date}). "
|
||||
early_access_msg = f"This model requires early access payment (until {formatted_date}). "
|
||||
except:
|
||||
early_access_msg = "This LoRA requires early access payment. "
|
||||
early_access_msg = "This model requires early access payment. "
|
||||
|
||||
early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai."
|
||||
logger.warning(f"Early access LoRA detected: {version_info.get('name', 'Unknown')}")
|
||||
logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}")
|
||||
|
||||
# We'll still try to download, but log a warning and prepare for potential failure
|
||||
if progress_callback:
|
||||
@@ -69,26 +84,32 @@ class DownloadManager:
|
||||
if progress_callback:
|
||||
await progress_callback(0)
|
||||
|
||||
# 2. 获取文件信息
|
||||
# 2. Get file information
|
||||
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
|
||||
if not file_info:
|
||||
return {'success': False, 'error': 'No primary file found in metadata'}
|
||||
|
||||
# 3. 准备下载
|
||||
# 3. Prepare download
|
||||
file_name = file_info['name']
|
||||
save_path = os.path.join(save_dir, file_name)
|
||||
file_size = file_info.get('sizeKB', 0) * 1024
|
||||
|
||||
# 4. 通知文件监控系统 - 使用规范化路径和文件大小
|
||||
self.file_monitor.handler.add_ignore_path(
|
||||
save_path.replace(os.sep, '/'),
|
||||
file_size
|
||||
)
|
||||
# 4. Notify file monitor - use normalized path and file size
|
||||
if self.file_monitor and self.file_monitor.handler:
|
||||
self.file_monitor.handler.add_ignore_path(
|
||||
save_path.replace(os.sep, '/'),
|
||||
file_size
|
||||
)
|
||||
|
||||
# 5. 准备元数据
|
||||
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
# 5. Prepare metadata based on model type
|
||||
if model_type == "checkpoint":
|
||||
metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
logger.info(f"Creating CheckpointMetadata for {file_name}")
|
||||
else:
|
||||
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
logger.info(f"Creating LoraMetadata for {file_name}")
|
||||
|
||||
# 5.1 获取并更新模型标签和描述信息
|
||||
# 5.1 Get and update model tags and description
|
||||
model_id = version_info.get('modelId')
|
||||
if model_id:
|
||||
model_metadata, _ = await self.civitai_client.get_model_metadata(str(model_id))
|
||||
@@ -98,14 +119,15 @@ class DownloadManager:
|
||||
if model_metadata.get("description"):
|
||||
metadata.modelDescription = model_metadata.get("description", "")
|
||||
|
||||
# 6. 开始下载流程
|
||||
# 6. Start download process
|
||||
result = await self._execute_download(
|
||||
download_url=file_info.get('downloadUrl', ''),
|
||||
save_dir=save_dir,
|
||||
metadata=metadata,
|
||||
version_info=version_info,
|
||||
relative_path=relative_path,
|
||||
progress_callback=progress_callback
|
||||
progress_callback=progress_callback,
|
||||
model_type=model_type
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -119,8 +141,9 @@ class DownloadManager:
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
async def _execute_download(self, download_url: str, save_dir: str,
|
||||
metadata: LoraMetadata, version_info: Dict,
|
||||
relative_path: str, progress_callback=None) -> Dict:
|
||||
metadata, version_info: Dict,
|
||||
relative_path: str, progress_callback=None,
|
||||
model_type: str = "lora") -> Dict:
|
||||
"""Execute the actual download process including preview images and model files"""
|
||||
try:
|
||||
save_path = metadata.file_path
|
||||
@@ -201,15 +224,21 @@ class DownloadManager:
|
||||
os.remove(path)
|
||||
return {'success': False, 'error': result}
|
||||
|
||||
# 4. 更新文件信息(大小和修改时间)
|
||||
# 4. Update file information (size and modified time)
|
||||
metadata.update_file_info(save_path)
|
||||
|
||||
# 5. 最终更新元数据
|
||||
# 5. Final metadata update
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 6. update lora cache
|
||||
cache = await self.file_monitor.scanner.get_cached_data()
|
||||
# 6. Update cache based on model type
|
||||
if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"):
|
||||
cache = await self.file_monitor.checkpoint_scanner.get_cached_data()
|
||||
logger.info(f"Updating checkpoint cache for {save_path}")
|
||||
else:
|
||||
cache = await self.file_monitor.scanner.get_cached_data()
|
||||
logger.info(f"Updating lora cache for {save_path}")
|
||||
|
||||
metadata_dict = metadata.to_dict()
|
||||
metadata_dict['folder'] = relative_path
|
||||
cache.raw_data.append(metadata_dict)
|
||||
@@ -218,11 +247,11 @@ class DownloadManager:
|
||||
all_folders.add(relative_path)
|
||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
# Update the hash index with the new LoRA entry
|
||||
self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
||||
|
||||
# Update the hash index with the new LoRA entry
|
||||
self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
||||
# Update the hash index with the new model entry
|
||||
if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"):
|
||||
self.file_monitor.checkpoint_scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
||||
else:
|
||||
self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
||||
|
||||
# Report 100% completion
|
||||
if progress_callback:
|
||||
|
||||
@@ -9,6 +9,7 @@ from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
|
||||
from ..config import config
|
||||
from ..services.civitai_client import CivitaiClient
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..services.download_manager import DownloadManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -421,4 +422,82 @@ class ModelRouteUtils:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error replacing preview: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response:
|
||||
"""Handle model download request
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
download_manager: Instance of DownloadManager
|
||||
model_type: Type of model ('lora' or 'checkpoint')
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response
|
||||
"""
|
||||
try:
|
||||
data = await request.json()
|
||||
|
||||
# Create progress callback
|
||||
async def progress_callback(progress):
|
||||
from ..services.websocket_manager import ws_manager
|
||||
await ws_manager.broadcast({
|
||||
'status': 'progress',
|
||||
'progress': progress
|
||||
})
|
||||
|
||||
# Check which identifier is provided
|
||||
download_url = data.get('download_url')
|
||||
model_hash = data.get('model_hash')
|
||||
model_version_id = data.get('model_version_id')
|
||||
|
||||
# Validate that at least one identifier is provided
|
||||
if not any([download_url, model_hash, model_version_id]):
|
||||
return web.Response(
|
||||
status=400,
|
||||
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
|
||||
)
|
||||
|
||||
# Use the correct root directory based on model type
|
||||
root_key = 'checkpoint_root' if model_type == 'checkpoint' else 'lora_root'
|
||||
save_dir = data.get(root_key)
|
||||
|
||||
result = await download_manager.download_from_civitai(
|
||||
download_url=download_url,
|
||||
model_hash=model_hash,
|
||||
model_version_id=model_version_id,
|
||||
save_dir=save_dir,
|
||||
relative_path=data.get('relative_path', ''),
|
||||
progress_callback=progress_callback,
|
||||
model_type=model_type
|
||||
)
|
||||
|
||||
if not result.get('success', False):
|
||||
error_message = result.get('error', 'Unknown error')
|
||||
|
||||
# Return 401 for early access errors
|
||||
if 'early access' in error_message.lower():
|
||||
logger.warning(f"Early access download failed: {error_message}")
|
||||
return web.Response(
|
||||
status=401, # Use 401 status code to match Civitai's response
|
||||
text=f"Early Access Restriction: {error_message}"
|
||||
)
|
||||
|
||||
return web.Response(status=500, text=error_message)
|
||||
|
||||
return web.json_response(result)
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
|
||||
# Check if this might be an early access error
|
||||
if '401' in error_message:
|
||||
logger.warning(f"Early access error (401): {error_message}")
|
||||
return web.Response(
|
||||
status=401,
|
||||
text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
|
||||
)
|
||||
|
||||
logger.error(f"Error downloading {model_type}: {error_message}")
|
||||
return web.Response(status=500, text=error_message)
|
||||
Reference in New Issue
Block a user