refactor(example-images): minimize async lock contention by moving I/O outside critical sections

- Extract progress file loading to async methods to run in executor
- Refactor start_download to reduce lock time by pre-loading data before entering lock
- Improve check_pending_models efficiency with single-pass model collection and async loading
- Add type hints to get_status method
- Add tests for download task callback execution and error handling
This commit is contained in:
Will Miao
2026-02-11 09:24:00 +08:00
parent 94edde7744
commit 6b1e3f06ed
2 changed files with 305 additions and 125 deletions

View File

@@ -121,100 +121,65 @@ class DownloadManager:
async def start_download(self, options: dict): async def start_download(self, options: dict):
"""Start downloading example images for models.""" """Start downloading example images for models."""
# Step 1: Parse options (fast, non-blocking)
data = options or {}
auto_mode = data.get("auto_mode", False)
optimize = data.get("optimize", True)
model_types = data.get("model_types", ["lora", "checkpoint"])
delay = float(data.get("delay", 0.2))
force = data.get("force", False)
# Step 2: Validate configuration (fast lookup)
settings_manager = get_settings_manager()
base_path = settings_manager.get("example_images_path")
if not base_path:
error_msg = "Example images path not configured in settings"
if auto_mode:
logger.debug(error_msg)
return {
"success": True,
"message": "Example images path not configured, skipping auto download",
}
raise DownloadConfigurationError(error_msg)
active_library = settings_manager.get_active_library_name()
output_dir = self._resolve_output_dir(active_library)
if not output_dir:
raise DownloadConfigurationError(
"Example images path not configured in settings"
)
# Step 3: Load progress file (I/O operation, done outside lock)
processed_models = set()
failed_models = set()
try:
progress_file, processed_models, failed_models = await self._load_progress_file(output_dir)
logger.debug(
"Loaded previous progress, %s models already processed, %s models marked as failed",
len(processed_models),
len(failed_models),
)
except Exception as e:
logger.error(f"Failed to load progress file: {e}")
# Continue with empty sets
# Step 4: Quick state check and update (minimal lock time)
async with self._state_lock: async with self._state_lock:
if self._is_downloading: if self._is_downloading:
raise DownloadInProgressError(self._progress.snapshot()) raise DownloadInProgressError(self._progress.snapshot())
try: try:
data = options or {} # Reset progress with loaded data
auto_mode = data.get("auto_mode", False)
optimize = data.get("optimize", True)
model_types = data.get("model_types", ["lora", "checkpoint"])
delay = float(data.get("delay", 0.2))
force = data.get("force", False)
settings_manager = get_settings_manager()
base_path = settings_manager.get("example_images_path")
if not base_path:
error_msg = "Example images path not configured in settings"
if auto_mode:
logger.debug(error_msg)
return {
"success": True,
"message": "Example images path not configured, skipping auto download",
}
raise DownloadConfigurationError(error_msg)
active_library = get_settings_manager().get_active_library_name()
output_dir = self._resolve_output_dir(active_library)
if not output_dir:
raise DownloadConfigurationError(
"Example images path not configured in settings"
)
self._progress.reset() self._progress.reset()
self._progress["processed_models"] = processed_models
self._progress["failed_models"] = failed_models
self._stop_requested = False self._stop_requested = False
self._progress["status"] = "running" self._progress["status"] = "running"
self._progress["start_time"] = time.time() self._progress["start_time"] = time.time()
self._progress["end_time"] = None self._progress["end_time"] = None
progress_file = os.path.join(output_dir, ".download_progress.json")
progress_source = progress_file
if uses_library_scoped_folders():
legacy_root = (
get_settings_manager().get("example_images_path") or ""
)
legacy_progress = (
os.path.join(legacy_root, ".download_progress.json")
if legacy_root
else ""
)
if (
legacy_progress
and os.path.exists(legacy_progress)
and not os.path.exists(progress_file)
):
try:
os.makedirs(output_dir, exist_ok=True)
shutil.move(legacy_progress, progress_file)
logger.info(
"Migrated legacy download progress file '%s' to '%s'",
legacy_progress,
progress_file,
)
except OSError as exc:
logger.warning(
"Failed to migrate download progress file from '%s' to '%s': %s",
legacy_progress,
progress_file,
exc,
)
progress_source = legacy_progress
if os.path.exists(progress_source):
try:
with open(progress_source, "r", encoding="utf-8") as f:
saved_progress = json.load(f)
self._progress["processed_models"] = set(
saved_progress.get("processed_models", [])
)
self._progress["failed_models"] = set(
saved_progress.get("failed_models", [])
)
logger.debug(
"Loaded previous progress, %s models already processed, %s models marked as failed",
len(self._progress["processed_models"]),
len(self._progress["failed_models"]),
)
except Exception as e:
logger.error(f"Failed to load progress file: {e}")
self._progress["processed_models"] = set()
self._progress["failed_models"] = set()
else:
self._progress["processed_models"] = set()
self._progress["failed_models"] = set()
self._is_downloading = True self._is_downloading = True
snapshot = self._progress.snapshot() snapshot = self._progress.snapshot()
@@ -268,7 +233,7 @@ class DownloadManager:
except Exception as save_error: except Exception as save_error:
logger.error(f"Failed to save progress after task failure: {save_error}") logger.error(f"Failed to save progress after task failure: {save_error}")
async def get_status(self, request): async def get_status(self, request) -> dict:
"""Get the current status of example images download.""" """Get the current status of example images download."""
return { return {
@@ -277,6 +242,87 @@ class DownloadManager:
"status": self._progress.snapshot(), "status": self._progress.snapshot(),
} }
async def _load_progress_file(self, output_dir: str) -> tuple[str, set, set]:
"""Load progress file from disk. Returns (progress_file_path, processed_models, failed_models).
This is a separate async method to allow running in executor to avoid blocking event loop.
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None, self._load_progress_file_sync, output_dir
)
def _load_progress_file_sync(self, output_dir: str) -> tuple[str, set, set]:
    """Synchronously resolve, migrate, and read the progress file.

    Performs a one-time migration of a legacy, library-unscoped progress
    file into *output_dir*, then reads the processed/failed model sets.

    Args:
        output_dir: Library-scoped directory for ``.download_progress.json``.

    Returns:
        Tuple of ``(progress_file_path, processed_models, failed_models)``.
        The returned path is always the scoped location, even when the data
        had to be read from a legacy file that could not be moved.
    """
    progress_file = os.path.join(output_dir, ".download_progress.json")
    progress_source = progress_file

    # Handle legacy migration if needed: older versions stored the progress
    # file directly under the global example_images_path, not per-library.
    if uses_library_scoped_folders():
        legacy_root = get_settings_manager().get("example_images_path") or ""
        legacy_progress = (
            os.path.join(legacy_root, ".download_progress.json")
            if legacy_root
            else ""
        )
        if (
            legacy_progress
            and os.path.exists(legacy_progress)
            and not os.path.exists(progress_file)
        ):
            try:
                os.makedirs(output_dir, exist_ok=True)
                shutil.move(legacy_progress, progress_file)
                logger.info(
                    "Migrated legacy download progress file '%s' to '%s'",
                    legacy_progress,
                    progress_file,
                )
            except OSError as exc:
                logger.warning(
                    "Failed to migrate download progress file from '%s' to '%s': %s",
                    legacy_progress,
                    progress_file,
                    exc,
                )
                # Migration failed: fall back to reading the legacy file
                # in place so prior progress is not lost.
                progress_source = legacy_progress

    # Delegate the actual read to the lightweight reader so the parsing
    # (and its best-effort error handling) lives in exactly one place.
    processed_models, failed_models = self._load_progress_sets_sync(progress_source)
    return progress_file, processed_models, failed_models
def _load_progress_sets_sync(self, progress_file: str) -> tuple[set, set]:
"""Load only the processed and failed model sets from progress file.
This is a lighter version for quick checks without legacy migration.
Returns (processed_models, failed_models).
"""
processed_models = set()
failed_models = set()
if os.path.exists(progress_file):
try:
with open(progress_file, "r", encoding="utf-8") as f:
saved_progress = json.load(f)
processed_models = set(saved_progress.get("processed_models", []))
failed_models = set(saved_progress.get("failed_models", []))
except Exception:
# Return empty sets on error
pass
return processed_models, failed_models
async def check_pending_models(self, model_types: list[str]) -> dict: async def check_pending_models(self, model_types: list[str]) -> dict:
"""Quickly check how many models need example images downloaded. """Quickly check how many models need example images downloaded.
@@ -320,62 +366,49 @@ class DownloadManager:
embedding_scanner = await ServiceRegistry.get_embedding_scanner() embedding_scanner = await ServiceRegistry.get_embedding_scanner()
scanners.append(("embedding", embedding_scanner)) scanners.append(("embedding", embedding_scanner))
# Load progress file to check processed models # Load progress file to check processed models (async to avoid blocking)
settings_manager = get_settings_manager() settings_manager = get_settings_manager()
active_library = settings_manager.get_active_library_name() active_library = settings_manager.get_active_library_name()
output_dir = self._resolve_output_dir(active_library) output_dir = self._resolve_output_dir(active_library)
processed_models: set[str] = set() processed_models: set[str] = set()
failed_models: set[str] = set() failed_models: set[str] = set()
if output_dir: if output_dir:
progress_file = os.path.join(output_dir, ".download_progress.json") progress_file = os.path.join(output_dir, ".download_progress.json")
if os.path.exists(progress_file): loop = asyncio.get_event_loop()
try: processed_models, failed_models = await loop.run_in_executor(
with open(progress_file, "r", encoding="utf-8") as f: None, self._load_progress_sets_sync, progress_file
saved_progress = json.load(f) )
processed_models = set(saved_progress.get("processed_models", []))
failed_models = set(saved_progress.get("failed_models", []))
except Exception:
pass # Ignore progress file errors for quick check
# Count models # Collect all models and count in a single pass per scanner
total_models = 0 total_models = 0
models_with_hash = 0 all_models_with_hash: list[tuple[str, str]] = [] # (hash, name) pairs
for scanner_type, scanner in scanners: for scanner_type, scanner in scanners:
cache = await scanner.get_cached_data() cache = await scanner.get_cached_data()
if cache and cache.raw_data: if cache and cache.raw_data:
for model in cache.raw_data: for model in cache.raw_data:
total_models += 1 total_models += 1
if model.get("sha256"):
models_with_hash += 1
# Calculate pending count
# A model is pending if it has a hash and is not in processed_models
# We also exclude failed_models unless force mode would be used
pending_count = models_with_hash - len(processed_models.intersection(
{m.get("sha256", "").lower() for scanner_type, scanner in scanners
for m in (await scanner.get_cached_data()).raw_data if m.get("sha256")}
))
# More accurate pending count: check which models actually need processing
pending_hashes = set()
for scanner_type, scanner in scanners:
cache = await scanner.get_cached_data()
if cache and cache.raw_data:
for model in cache.raw_data:
raw_hash = model.get("sha256") raw_hash = model.get("sha256")
if not raw_hash: if raw_hash:
continue model_hash = raw_hash.lower()
model_hash = raw_hash.lower() all_models_with_hash.append((model_hash, model.get("model_name", "Unknown")))
if model_hash not in processed_models:
# Check if model folder exists with files models_with_hash = len(all_models_with_hash)
model_dir = ExampleImagePathResolver.get_model_folder(
model_hash, active_library # Calculate pending count: check which models actually need processing
) # A model is pending if it has a hash, is not in processed_models,
if not _model_directory_has_files(model_dir): # and its folder doesn't exist or is empty
pending_hashes.add(model_hash) pending_hashes = set()
for model_hash, model_name in all_models_with_hash:
if model_hash not in processed_models:
# Check if model folder exists with files
model_dir = ExampleImagePathResolver.get_model_folder(
model_hash, active_library
)
if not _model_directory_has_files(model_dir):
pending_hashes.add(model_hash)
pending_count = len(pending_hashes) pending_count = len(pending_hashes)

View File

@@ -203,3 +203,150 @@ async def test_pause_or_resume_without_running_download(
with pytest.raises(download_module.DownloadNotRunningError): with pytest.raises(download_module.DownloadNotRunningError):
await manager.stop_download(object()) await manager.stop_download(object())
async def test_download_task_callback_executes_on_completion(
    monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
    """Test that _handle_download_task_done callback is executed when download completes."""
    # Point the manager at a temp library so start_download passes its
    # configuration checks without touching real settings.
    settings_manager = get_settings_manager()
    settings_manager.settings["example_images_path"] = str(tmp_path)
    settings_manager.settings["libraries"] = {"default": {}}
    settings_manager.settings["active_library"] = "default"
    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)
    # Wrap the real done-callback so the test can await its execution
    # instead of polling manager state.
    callback_executed = asyncio.Event()
    original_callback = manager._handle_download_task_done

    def tracking_callback(task, output_dir) -> None:
        original_callback(task, output_dir)
        callback_executed.set()

    monkeypatch.setattr(
        manager, "_handle_download_task_done", tracking_callback
    )

    async def fake_download(self, *_args) -> None:
        # Simulate successful completion
        async with self._state_lock:
            self._progress["status"] = "completed"
            self._is_downloading = False
            self._download_task = None

    monkeypatch.setattr(
        download_module.DownloadManager,
        "_download_all_example_images",
        fake_download,
    )
    result = await manager.start_download({"model_types": ["lora"], "delay": 0})
    assert result["success"] is True
    # Wait for callback to execute
    await asyncio.wait_for(callback_executed.wait(), timeout=1)
    assert manager._progress["status"] == "completed"
async def test_download_task_callback_handles_errors(
    monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
    """Test that _handle_download_task_done properly handles task errors and saves progress."""
    # Minimal valid configuration so start_download can proceed.
    settings_manager = get_settings_manager()
    settings_manager.settings["example_images_path"] = str(tmp_path)
    settings_manager.settings["libraries"] = {"default": {}}
    settings_manager.settings["active_library"] = "default"
    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)
    callback_executed = asyncio.Event()
    # Track that _save_progress runs on the error path while still
    # delegating to the real implementation.
    progress_saved = False
    original_save_progress = manager._save_progress

    def tracking_save_progress(output_dir):
        nonlocal progress_saved
        progress_saved = True
        return original_save_progress(output_dir)

    monkeypatch.setattr(manager, "_save_progress", tracking_save_progress)
    # Wrap the real done-callback so the test can await its execution.
    original_callback = manager._handle_download_task_done

    def tracking_callback(task, output_dir) -> None:
        original_callback(task, output_dir)
        callback_executed.set()

    monkeypatch.setattr(
        manager, "_handle_download_task_done", tracking_callback
    )

    async def fake_download_with_error(self, *_args) -> None:
        raise RuntimeError("Simulated download error")

    monkeypatch.setattr(
        download_module.DownloadManager,
        "_download_all_example_images",
        fake_download_with_error,
    )
    result = await manager.start_download({"model_types": ["lora"], "delay": 0})
    assert result["success"] is True
    # Wait for callback to execute (it should handle the error)
    await asyncio.wait_for(callback_executed.wait(), timeout=1)
    # Progress should be saved even on error
    assert progress_saved is True
async def test_get_status_returns_correct_state(
    monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
    """Test that get_status returns the correct download state."""
    # Minimal valid configuration so start_download can proceed.
    settings_manager = get_settings_manager()
    settings_manager.settings["example_images_path"] = str(tmp_path)
    settings_manager.settings["libraries"] = {"default": {}}
    settings_manager.settings["active_library"] = "default"
    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)
    # Test idle state
    status = await manager.get_status(object())
    assert status["success"] is True
    assert status["is_downloading"] is False
    assert status["status"]["status"] == "idle"
    # Events to hold the fake download in the "running" state until the
    # test has observed it via get_status.
    started = asyncio.Event()
    release = asyncio.Event()

    async def fake_download(self, *_args) -> None:
        started.set()
        await release.wait()
        async with self._state_lock:
            self._is_downloading = False
            self._download_task = None

    monkeypatch.setattr(
        download_module.DownloadManager,
        "_download_all_example_images",
        fake_download,
    )
    # Start download
    result = await manager.start_download({"model_types": ["lora"], "delay": 0})
    assert result["success"] is True
    await asyncio.wait_for(started.wait(), timeout=1)
    # Test running state
    status = await manager.get_status(object())
    assert status["success"] is True
    assert status["is_downloading"] is True
    assert status["status"]["status"] == "running"
    # Cleanup
    release.set()
    if manager._download_task:
        await asyncio.wait_for(manager._download_task, timeout=1)