mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
refactor(example-images): minimize async lock contention by moving I/O outside critical sections
- Extract progress file loading into async methods that run in an executor
- Refactor start_download to reduce lock time by pre-loading data before entering the lock
- Improve check_pending_models efficiency with single-pass model collection and async loading
- Add type hints to the get_status method
- Add tests for download task callback execution and error handling
This commit is contained in:
@@ -121,100 +121,65 @@ class DownloadManager:
|
||||
async def start_download(self, options: dict):
|
||||
"""Start downloading example images for models."""
|
||||
|
||||
# Step 1: Parse options (fast, non-blocking)
|
||||
data = options or {}
|
||||
auto_mode = data.get("auto_mode", False)
|
||||
optimize = data.get("optimize", True)
|
||||
model_types = data.get("model_types", ["lora", "checkpoint"])
|
||||
delay = float(data.get("delay", 0.2))
|
||||
force = data.get("force", False)
|
||||
|
||||
# Step 2: Validate configuration (fast lookup)
|
||||
settings_manager = get_settings_manager()
|
||||
base_path = settings_manager.get("example_images_path")
|
||||
|
||||
if not base_path:
|
||||
error_msg = "Example images path not configured in settings"
|
||||
if auto_mode:
|
||||
logger.debug(error_msg)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Example images path not configured, skipping auto download",
|
||||
}
|
||||
raise DownloadConfigurationError(error_msg)
|
||||
|
||||
active_library = settings_manager.get_active_library_name()
|
||||
output_dir = self._resolve_output_dir(active_library)
|
||||
if not output_dir:
|
||||
raise DownloadConfigurationError(
|
||||
"Example images path not configured in settings"
|
||||
)
|
||||
|
||||
# Step 3: Load progress file (I/O operation, done outside lock)
|
||||
processed_models = set()
|
||||
failed_models = set()
|
||||
|
||||
try:
|
||||
progress_file, processed_models, failed_models = await self._load_progress_file(output_dir)
|
||||
logger.debug(
|
||||
"Loaded previous progress, %s models already processed, %s models marked as failed",
|
||||
len(processed_models),
|
||||
len(failed_models),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load progress file: {e}")
|
||||
# Continue with empty sets
|
||||
|
||||
# Step 4: Quick state check and update (minimal lock time)
|
||||
async with self._state_lock:
|
||||
if self._is_downloading:
|
||||
raise DownloadInProgressError(self._progress.snapshot())
|
||||
|
||||
try:
|
||||
data = options or {}
|
||||
auto_mode = data.get("auto_mode", False)
|
||||
optimize = data.get("optimize", True)
|
||||
model_types = data.get("model_types", ["lora", "checkpoint"])
|
||||
delay = float(data.get("delay", 0.2))
|
||||
force = data.get("force", False)
|
||||
|
||||
settings_manager = get_settings_manager()
|
||||
base_path = settings_manager.get("example_images_path")
|
||||
|
||||
if not base_path:
|
||||
error_msg = "Example images path not configured in settings"
|
||||
if auto_mode:
|
||||
logger.debug(error_msg)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Example images path not configured, skipping auto download",
|
||||
}
|
||||
raise DownloadConfigurationError(error_msg)
|
||||
|
||||
active_library = get_settings_manager().get_active_library_name()
|
||||
output_dir = self._resolve_output_dir(active_library)
|
||||
if not output_dir:
|
||||
raise DownloadConfigurationError(
|
||||
"Example images path not configured in settings"
|
||||
)
|
||||
|
||||
# Reset progress with loaded data
|
||||
self._progress.reset()
|
||||
self._progress["processed_models"] = processed_models
|
||||
self._progress["failed_models"] = failed_models
|
||||
self._stop_requested = False
|
||||
self._progress["status"] = "running"
|
||||
self._progress["start_time"] = time.time()
|
||||
self._progress["end_time"] = None
|
||||
|
||||
progress_file = os.path.join(output_dir, ".download_progress.json")
|
||||
progress_source = progress_file
|
||||
if uses_library_scoped_folders():
|
||||
legacy_root = (
|
||||
get_settings_manager().get("example_images_path") or ""
|
||||
)
|
||||
legacy_progress = (
|
||||
os.path.join(legacy_root, ".download_progress.json")
|
||||
if legacy_root
|
||||
else ""
|
||||
)
|
||||
if (
|
||||
legacy_progress
|
||||
and os.path.exists(legacy_progress)
|
||||
and not os.path.exists(progress_file)
|
||||
):
|
||||
try:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
shutil.move(legacy_progress, progress_file)
|
||||
logger.info(
|
||||
"Migrated legacy download progress file '%s' to '%s'",
|
||||
legacy_progress,
|
||||
progress_file,
|
||||
)
|
||||
except OSError as exc:
|
||||
logger.warning(
|
||||
"Failed to migrate download progress file from '%s' to '%s': %s",
|
||||
legacy_progress,
|
||||
progress_file,
|
||||
exc,
|
||||
)
|
||||
progress_source = legacy_progress
|
||||
|
||||
if os.path.exists(progress_source):
|
||||
try:
|
||||
with open(progress_source, "r", encoding="utf-8") as f:
|
||||
saved_progress = json.load(f)
|
||||
self._progress["processed_models"] = set(
|
||||
saved_progress.get("processed_models", [])
|
||||
)
|
||||
self._progress["failed_models"] = set(
|
||||
saved_progress.get("failed_models", [])
|
||||
)
|
||||
logger.debug(
|
||||
"Loaded previous progress, %s models already processed, %s models marked as failed",
|
||||
len(self._progress["processed_models"]),
|
||||
len(self._progress["failed_models"]),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load progress file: {e}")
|
||||
self._progress["processed_models"] = set()
|
||||
self._progress["failed_models"] = set()
|
||||
else:
|
||||
self._progress["processed_models"] = set()
|
||||
self._progress["failed_models"] = set()
|
||||
|
||||
self._is_downloading = True
|
||||
snapshot = self._progress.snapshot()
|
||||
|
||||
@@ -268,7 +233,7 @@ class DownloadManager:
|
||||
except Exception as save_error:
|
||||
logger.error(f"Failed to save progress after task failure: {save_error}")
|
||||
|
||||
async def get_status(self, request):
|
||||
async def get_status(self, request) -> dict:
|
||||
"""Get the current status of example images download."""
|
||||
|
||||
return {
|
||||
@@ -277,6 +242,87 @@ class DownloadManager:
|
||||
"status": self._progress.snapshot(),
|
||||
}
|
||||
|
||||
async def _load_progress_file(self, output_dir: str) -> tuple[str, set, set]:
    """Load the download progress file without blocking the event loop.

    The blocking work is delegated to ``_load_progress_file_sync`` and
    scheduled on the loop's default executor.

    Args:
        output_dir: Directory expected to contain the progress file.

    Returns:
        A ``(progress_file_path, processed_models, failed_models)`` tuple,
        exactly as produced by ``_load_progress_file_sync``.
    """
    # This coroutine is always running on a loop, so ask for that loop
    # directly; get_event_loop() is discouraged inside coroutines.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None, self._load_progress_file_sync, output_dir
    )
|
||||
|
||||
def _load_progress_file_sync(self, output_dir: str) -> tuple[str, set, set]:
    """Synchronous implementation of progress file loading.

    When library-scoped folders are in use and a legacy progress file exists
    under the configured ``example_images_path`` while none exists in
    ``output_dir``, the legacy file is moved into ``output_dir`` first. If
    that move fails, the legacy file is read in place.

    Args:
        output_dir: Directory that holds (or will hold) ``.download_progress.json``.

    Returns:
        A ``(progress_file_path, processed_models, failed_models)`` tuple.
        ``progress_file_path`` always points inside ``output_dir``, even when
        the data was read from the legacy location. Both sets are empty when
        the file is missing or cannot be read or parsed.
    """
    progress_file = os.path.join(output_dir, ".download_progress.json")
    progress_source = progress_file

    # Handle legacy migration if needed
    if uses_library_scoped_folders():
        legacy_root = get_settings_manager().get("example_images_path") or ""
        legacy_progress = (
            os.path.join(legacy_root, ".download_progress.json")
            if legacy_root
            else ""
        )
        if (
            legacy_progress
            and os.path.exists(legacy_progress)
            and not os.path.exists(progress_file)
        ):
            try:
                os.makedirs(output_dir, exist_ok=True)
                shutil.move(legacy_progress, progress_file)
                logger.info(
                    "Migrated legacy download progress file '%s' to '%s'",
                    legacy_progress,
                    progress_file,
                )
            except OSError as exc:
                logger.warning(
                    "Failed to migrate download progress file from '%s' to '%s': %s",
                    legacy_progress,
                    progress_file,
                    exc,
                )
                # Migration failed: fall back to reading the legacy file in place.
                progress_source = legacy_progress

    processed_models: set = set()
    failed_models: set = set()

    if os.path.exists(progress_source):
        try:
            with open(progress_source, "r", encoding="utf-8") as f:
                saved_progress = json.load(f)
            # Build both sets before committing either, so a malformed entry
            # under one key cannot leave a half-loaded result behind.
            loaded_processed = set(saved_progress.get("processed_models", []))
            loaded_failed = set(saved_progress.get("failed_models", []))
        except Exception as exc:
            # Best effort: an unreadable or corrupt progress file must not
            # prevent a download from starting, but the failure is reported.
            logger.error(
                "Failed to load progress file '%s': %s", progress_source, exc
            )
        else:
            processed_models = loaded_processed
            failed_models = loaded_failed

    return progress_file, processed_models, failed_models
|
||||
|
||||
def _load_progress_sets_sync(self, progress_file: str) -> tuple[set, set]:
    """Load only the processed and failed model sets from a progress file.

    This is a lighter version for quick checks: it performs no legacy
    migration and never raises for a missing, unreadable, or malformed file.

    Args:
        progress_file: Path to the JSON progress file.

    Returns:
        A ``(processed_models, failed_models)`` tuple. Both sets are empty if
        the file does not exist or anything goes wrong while reading it.
    """
    try:
        # Open directly instead of checking os.path.exists first: a missing
        # file is just one of the failures handled below.
        with open(progress_file, "r", encoding="utf-8") as f:
            saved_progress = json.load(f)
        # Build both sets before returning either, so a malformed entry under
        # one key cannot produce a half-loaded result.
        processed_models = set(saved_progress.get("processed_models", []))
        failed_models = set(saved_progress.get("failed_models", []))
    except Exception:
        # Deliberately best-effort: callers treat any failure as "no progress".
        return set(), set()

    return processed_models, failed_models
|
||||
|
||||
async def check_pending_models(self, model_types: list[str]) -> dict:
|
||||
"""Quickly check how many models need example images downloaded.
|
||||
|
||||
@@ -320,62 +366,49 @@ class DownloadManager:
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
scanners.append(("embedding", embedding_scanner))
|
||||
|
||||
# Load progress file to check processed models
|
||||
# Load progress file to check processed models (async to avoid blocking)
|
||||
settings_manager = get_settings_manager()
|
||||
active_library = settings_manager.get_active_library_name()
|
||||
output_dir = self._resolve_output_dir(active_library)
|
||||
|
||||
|
||||
processed_models: set[str] = set()
|
||||
failed_models: set[str] = set()
|
||||
|
||||
|
||||
if output_dir:
|
||||
progress_file = os.path.join(output_dir, ".download_progress.json")
|
||||
if os.path.exists(progress_file):
|
||||
try:
|
||||
with open(progress_file, "r", encoding="utf-8") as f:
|
||||
saved_progress = json.load(f)
|
||||
processed_models = set(saved_progress.get("processed_models", []))
|
||||
failed_models = set(saved_progress.get("failed_models", []))
|
||||
except Exception:
|
||||
pass # Ignore progress file errors for quick check
|
||||
loop = asyncio.get_event_loop()
|
||||
processed_models, failed_models = await loop.run_in_executor(
|
||||
None, self._load_progress_sets_sync, progress_file
|
||||
)
|
||||
|
||||
# Count models
|
||||
# Collect all models and count in a single pass per scanner
|
||||
total_models = 0
|
||||
models_with_hash = 0
|
||||
|
||||
all_models_with_hash: list[tuple[str, str]] = [] # (hash, name) pairs
|
||||
|
||||
for scanner_type, scanner in scanners:
|
||||
cache = await scanner.get_cached_data()
|
||||
if cache and cache.raw_data:
|
||||
for model in cache.raw_data:
|
||||
total_models += 1
|
||||
if model.get("sha256"):
|
||||
models_with_hash += 1
|
||||
|
||||
# Calculate pending count
|
||||
# A model is pending if it has a hash and is not in processed_models
|
||||
# We also exclude failed_models unless force mode would be used
|
||||
pending_count = models_with_hash - len(processed_models.intersection(
|
||||
{m.get("sha256", "").lower() for scanner_type, scanner in scanners
|
||||
for m in (await scanner.get_cached_data()).raw_data if m.get("sha256")}
|
||||
))
|
||||
|
||||
# More accurate pending count: check which models actually need processing
|
||||
pending_hashes = set()
|
||||
for scanner_type, scanner in scanners:
|
||||
cache = await scanner.get_cached_data()
|
||||
if cache and cache.raw_data:
|
||||
for model in cache.raw_data:
|
||||
raw_hash = model.get("sha256")
|
||||
if not raw_hash:
|
||||
continue
|
||||
model_hash = raw_hash.lower()
|
||||
if model_hash not in processed_models:
|
||||
# Check if model folder exists with files
|
||||
model_dir = ExampleImagePathResolver.get_model_folder(
|
||||
model_hash, active_library
|
||||
)
|
||||
if not _model_directory_has_files(model_dir):
|
||||
pending_hashes.add(model_hash)
|
||||
if raw_hash:
|
||||
model_hash = raw_hash.lower()
|
||||
all_models_with_hash.append((model_hash, model.get("model_name", "Unknown")))
|
||||
|
||||
models_with_hash = len(all_models_with_hash)
|
||||
|
||||
# Calculate pending count: check which models actually need processing
|
||||
# A model is pending if it has a hash, is not in processed_models,
|
||||
# and its folder doesn't exist or is empty
|
||||
pending_hashes = set()
|
||||
for model_hash, model_name in all_models_with_hash:
|
||||
if model_hash not in processed_models:
|
||||
# Check if model folder exists with files
|
||||
model_dir = ExampleImagePathResolver.get_model_folder(
|
||||
model_hash, active_library
|
||||
)
|
||||
if not _model_directory_has_files(model_dir):
|
||||
pending_hashes.add(model_hash)
|
||||
|
||||
pending_count = len(pending_hashes)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user