mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
checkpoint
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from typing import Optional, Dict
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any
|
||||
from .civitai_client import CivitaiClient
|
||||
from .file_monitor import LoraFileMonitor
|
||||
from ..utils.models import LoraMetadata, CheckpointMetadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from .service_registry import ServiceRegistry
|
||||
|
||||
# Download to temporary file first
|
||||
import tempfile
|
||||
@@ -14,9 +15,46 @@ import tempfile
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DownloadManager:
|
||||
def __init__(self, file_monitor: Optional[LoraFileMonitor] = None):
|
||||
self.civitai_client = CivitaiClient()
|
||||
self.file_monitor = file_monitor
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance of DownloadManager"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
# Check if already initialized for singleton pattern
|
||||
if hasattr(self, '_initialized'):
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
self._civitai_client = None # Will be lazily initialized
|
||||
|
||||
async def _get_civitai_client(self):
|
||||
"""Lazily initialize CivitaiClient from registry"""
|
||||
if self._civitai_client is None:
|
||||
self._civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
return self._civitai_client
|
||||
|
||||
async def _get_lora_monitor(self):
|
||||
"""Get the lora file monitor from registry"""
|
||||
return await ServiceRegistry.get_lora_monitor()
|
||||
|
||||
async def _get_checkpoint_monitor(self):
|
||||
"""Get the checkpoint file monitor from registry"""
|
||||
return await ServiceRegistry.get_checkpoint_monitor()
|
||||
|
||||
async def _get_lora_scanner(self):
|
||||
"""Get the lora scanner from registry"""
|
||||
return await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
async def _get_checkpoint_scanner(self):
|
||||
"""Get the checkpoint scanner from registry"""
|
||||
return await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
async def download_from_civitai(self, download_url: str = None, model_hash: str = None,
|
||||
model_version_id: str = None, save_dir: str = None,
|
||||
@@ -43,19 +81,22 @@ class DownloadManager:
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Get civitai client
|
||||
civitai_client = await self._get_civitai_client()
|
||||
|
||||
# Get version info based on the provided identifier
|
||||
version_info = None
|
||||
|
||||
if download_url:
|
||||
# Extract version ID from download URL
|
||||
version_id = download_url.split('/')[-1]
|
||||
version_info = await self.civitai_client.get_model_version_info(version_id)
|
||||
version_info = await civitai_client.get_model_version_info(version_id)
|
||||
elif model_version_id:
|
||||
# Use model version ID directly
|
||||
version_info = await self.civitai_client.get_model_version_info(model_version_id)
|
||||
version_info = await civitai_client.get_model_version_info(model_version_id)
|
||||
elif model_hash:
|
||||
# Get model by hash
|
||||
version_info = await self.civitai_client.get_model_by_hash(model_hash)
|
||||
version_info = await civitai_client.get_model_by_hash(model_hash)
|
||||
|
||||
|
||||
if not version_info:
|
||||
@@ -95,8 +136,9 @@ class DownloadManager:
|
||||
file_size = file_info.get('sizeKB', 0) * 1024
|
||||
|
||||
# 4. Notify file monitor - use normalized path and file size
|
||||
if self.file_monitor and self.file_monitor.handler:
|
||||
self.file_monitor.handler.add_ignore_path(
|
||||
file_monitor = await self._get_lora_monitor() if model_type == "lora" else await self._get_checkpoint_monitor()
|
||||
if file_monitor and file_monitor.handler:
|
||||
file_monitor.handler.add_ignore_path(
|
||||
save_path.replace(os.sep, '/'),
|
||||
file_size
|
||||
)
|
||||
@@ -112,7 +154,7 @@ class DownloadManager:
|
||||
# 5.1 Get and update model tags and description
|
||||
model_id = version_info.get('modelId')
|
||||
if model_id:
|
||||
model_metadata, _ = await self.civitai_client.get_model_metadata(str(model_id))
|
||||
model_metadata, _ = await civitai_client.get_model_metadata(str(model_id))
|
||||
if model_metadata:
|
||||
if model_metadata.get("tags"):
|
||||
metadata.tags = model_metadata.get("tags", [])
|
||||
@@ -146,6 +188,7 @@ class DownloadManager:
|
||||
model_type: str = "lora") -> Dict:
|
||||
"""Execute the actual download process including preview images and model files"""
|
||||
try:
|
||||
civitai_client = await self._get_civitai_client()
|
||||
save_path = metadata.file_path
|
||||
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
||||
|
||||
@@ -165,7 +208,7 @@ class DownloadManager:
|
||||
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
||||
|
||||
# Download video directly
|
||||
if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
|
||||
if await civitai_client.download_preview_image(images[0]['url'], preview_path):
|
||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
@@ -176,7 +219,7 @@ class DownloadManager:
|
||||
temp_path = temp_file.name
|
||||
|
||||
# Download the original image to temp path
|
||||
if await self.civitai_client.download_preview_image(images[0]['url'], temp_path):
|
||||
if await civitai_client.download_preview_image(images[0]['url'], temp_path):
|
||||
# Optimize and convert to WebP
|
||||
preview_path = os.path.splitext(save_path)[0] + '.webp'
|
||||
|
||||
@@ -210,7 +253,7 @@ class DownloadManager:
|
||||
await progress_callback(3) # 3% progress after preview download
|
||||
|
||||
# Download model file with progress tracking
|
||||
success, result = await self.civitai_client._download_file(
|
||||
success, result = await civitai_client._download_file(
|
||||
download_url,
|
||||
save_dir,
|
||||
os.path.basename(save_path),
|
||||
@@ -232,13 +275,14 @@ class DownloadManager:
|
||||
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 6. Update cache based on model type
|
||||
if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"):
|
||||
cache = await self.file_monitor.checkpoint_scanner.get_cached_data()
|
||||
if model_type == "checkpoint":
|
||||
scanner = await self._get_checkpoint_scanner()
|
||||
logger.info(f"Updating checkpoint cache for {save_path}")
|
||||
else:
|
||||
cache = await self.file_monitor.scanner.get_cached_data()
|
||||
scanner = await self._get_lora_scanner()
|
||||
logger.info(f"Updating lora cache for {save_path}")
|
||||
|
||||
cache = await scanner.get_cached_data()
|
||||
metadata_dict = metadata.to_dict()
|
||||
metadata_dict['folder'] = relative_path
|
||||
cache.raw_data.append(metadata_dict)
|
||||
@@ -248,10 +292,7 @@ class DownloadManager:
|
||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
# Update the hash index with the new model entry
|
||||
if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"):
|
||||
self.file_monitor.checkpoint_scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
||||
else:
|
||||
self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
||||
scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
||||
|
||||
# Report 100% completion
|
||||
if progress_callback:
|
||||
|
||||
Reference in New Issue
Block a user