mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(example-images): add check pending models endpoint and improve async handling
- Add /api/example-images/check-pending endpoint to quickly check models needing downloads - Improve DownloadManager.start_download() to return immediately without blocking - Add _handle_download_task_done callback for proper error handling and progress saving - Add check_pending_models() method for lightweight pre-download validation - Update frontend ExampleImagesManager to use new check-pending endpoint - Add comprehensive tests for new functionality
This commit is contained in:
368
tests/services/test_check_pending_models.py
Normal file
368
tests/services/test_check_pending_models.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""Tests for the check_pending_models lightweight pre-check functionality."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.settings_manager import get_settings_manager
|
||||
from py.utils import example_images_download_manager as download_module
|
||||
|
||||
|
||||
class StubScanner:
|
||||
"""Scanner double returning predetermined cache contents."""
|
||||
|
||||
def __init__(self, models: list[dict]) -> None:
|
||||
self._cache = SimpleNamespace(raw_data=models)
|
||||
|
||||
async def get_cached_data(self):
|
||||
return self._cache
|
||||
|
||||
|
||||
def _patch_scanners(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
lora_scanner: StubScanner | None = None,
|
||||
checkpoint_scanner: StubScanner | None = None,
|
||||
embedding_scanner: StubScanner | None = None,
|
||||
) -> None:
|
||||
"""Patch ServiceRegistry to return stub scanners."""
|
||||
|
||||
async def _get_lora_scanner(cls):
|
||||
return lora_scanner or StubScanner([])
|
||||
|
||||
async def _get_checkpoint_scanner(cls):
|
||||
return checkpoint_scanner or StubScanner([])
|
||||
|
||||
async def _get_embedding_scanner(cls):
|
||||
return embedding_scanner or StubScanner([])
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_module.ServiceRegistry,
|
||||
"get_lora_scanner",
|
||||
classmethod(_get_lora_scanner),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
download_module.ServiceRegistry,
|
||||
"get_checkpoint_scanner",
|
||||
classmethod(_get_checkpoint_scanner),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
download_module.ServiceRegistry,
|
||||
"get_embedding_scanner",
|
||||
classmethod(_get_embedding_scanner),
|
||||
)
|
||||
|
||||
|
||||
class RecordingWebSocketManager:
|
||||
"""Collects broadcast payloads for assertions."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.payloads: list[dict] = []
|
||||
|
||||
async def broadcast(self, payload: dict) -> None:
|
||||
self.payloads.append(payload)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_returns_zero_when_all_processed(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that check_pending_models returns 0 pending when all models are processed."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
# Create processed models
|
||||
processed_hashes = ["a" * 64, "b" * 64, "c" * 64]
|
||||
models = [
|
||||
{"sha256": h, "model_name": f"Model {i}"}
|
||||
for i, h in enumerate(processed_hashes)
|
||||
]
|
||||
|
||||
# Create progress file with all models processed
|
||||
progress_file = tmp_path / ".download_progress.json"
|
||||
progress_file.write_text(
|
||||
json.dumps({"processed_models": processed_hashes, "failed_models": []}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
# Create model directories with files (simulating completed downloads)
|
||||
for h in processed_hashes:
|
||||
model_dir = tmp_path / h
|
||||
model_dir.mkdir()
|
||||
(model_dir / "image_0.png").write_text("data")
|
||||
|
||||
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||
|
||||
result = await manager.check_pending_models(["lora"])
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["is_downloading"] is False
|
||||
assert result["total_models"] == 3
|
||||
assert result["pending_count"] == 0
|
||||
assert result["processed_count"] == 3
|
||||
assert result["needs_download"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_finds_unprocessed_models(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that check_pending_models correctly identifies unprocessed models."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
# Create models - some processed, some not
|
||||
processed_hash = "a" * 64
|
||||
unprocessed_hash = "b" * 64
|
||||
models = [
|
||||
{"sha256": processed_hash, "model_name": "Processed Model"},
|
||||
{"sha256": unprocessed_hash, "model_name": "Unprocessed Model"},
|
||||
]
|
||||
|
||||
# Create progress file with only one model processed
|
||||
progress_file = tmp_path / ".download_progress.json"
|
||||
progress_file.write_text(
|
||||
json.dumps({"processed_models": [processed_hash], "failed_models": []}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
# Create directory only for processed model
|
||||
processed_dir = tmp_path / processed_hash
|
||||
processed_dir.mkdir()
|
||||
(processed_dir / "image_0.png").write_text("data")
|
||||
|
||||
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||
|
||||
result = await manager.check_pending_models(["lora"])
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["total_models"] == 2
|
||||
assert result["pending_count"] == 1
|
||||
assert result["processed_count"] == 1
|
||||
assert result["needs_download"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_skips_models_without_hash(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that models without sha256 are not counted as pending."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
# Models - one with hash, one without
|
||||
models = [
|
||||
{"sha256": "a" * 64, "model_name": "Hashed Model"},
|
||||
{"sha256": None, "model_name": "No Hash Model"},
|
||||
{"model_name": "Missing Hash Model"}, # No sha256 key at all
|
||||
]
|
||||
|
||||
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||
|
||||
result = await manager.check_pending_models(["lora"])
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["total_models"] == 3
|
||||
assert result["pending_count"] == 1 # Only the one with hash
|
||||
assert result["needs_download"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_handles_multiple_model_types(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that check_pending_models aggregates counts across multiple model types."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
lora_models = [
|
||||
{"sha256": "a" * 64, "model_name": "Lora 1"},
|
||||
{"sha256": "b" * 64, "model_name": "Lora 2"},
|
||||
]
|
||||
checkpoint_models = [
|
||||
{"sha256": "c" * 64, "model_name": "Checkpoint 1"},
|
||||
]
|
||||
embedding_models = [
|
||||
{"sha256": "d" * 64, "model_name": "Embedding 1"},
|
||||
{"sha256": "e" * 64, "model_name": "Embedding 2"},
|
||||
{"sha256": "f" * 64, "model_name": "Embedding 3"},
|
||||
]
|
||||
|
||||
_patch_scanners(
|
||||
monkeypatch,
|
||||
lora_scanner=StubScanner(lora_models),
|
||||
checkpoint_scanner=StubScanner(checkpoint_models),
|
||||
embedding_scanner=StubScanner(embedding_models),
|
||||
)
|
||||
|
||||
result = await manager.check_pending_models(["lora", "checkpoint", "embedding"])
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["total_models"] == 6 # 2 + 1 + 3
|
||||
assert result["pending_count"] == 6 # All unprocessed
|
||||
assert result["needs_download"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_returns_error_when_download_in_progress(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that check_pending_models returns special response when download is running."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
# Simulate download in progress
|
||||
manager._is_downloading = True
|
||||
|
||||
result = await manager.check_pending_models(["lora"])
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["is_downloading"] is True
|
||||
assert result["needs_download"] is False
|
||||
assert result["pending_count"] == 0
|
||||
assert "already in progress" in result["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_handles_empty_library(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that check_pending_models handles empty model library."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
_patch_scanners(monkeypatch, lora_scanner=StubScanner([]))
|
||||
|
||||
result = await manager.check_pending_models(["lora"])
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["total_models"] == 0
|
||||
assert result["pending_count"] == 0
|
||||
assert result["processed_count"] == 0
|
||||
assert result["needs_download"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_reads_failed_models(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that check_pending_models correctly reports failed model count."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
models = [{"sha256": "a" * 64, "model_name": "Model"}]
|
||||
|
||||
# Create progress file with failed models
|
||||
progress_file = tmp_path / ".download_progress.json"
|
||||
progress_file.write_text(
|
||||
json.dumps({"processed_models": [], "failed_models": ["a" * 64, "b" * 64]}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||
|
||||
result = await manager.check_pending_models(["lora"])
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["failed_count"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_handles_missing_progress_file(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that check_pending_models works correctly when no progress file exists."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
models = [
|
||||
{"sha256": "a" * 64, "model_name": "Model 1"},
|
||||
{"sha256": "b" * 64, "model_name": "Model 2"},
|
||||
]
|
||||
|
||||
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||
|
||||
# No progress file created
|
||||
result = await manager.check_pending_models(["lora"])
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["total_models"] == 2
|
||||
assert result["pending_count"] == 2 # All pending since no progress
|
||||
assert result["processed_count"] == 0
|
||||
assert result["failed_count"] == 0
|
||||
assert result["needs_download"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_handles_corrupted_progress_file(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that check_pending_models handles corrupted progress file gracefully."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
models = [{"sha256": "a" * 64, "model_name": "Model"}]
|
||||
|
||||
# Create corrupted progress file
|
||||
progress_file = tmp_path / ".download_progress.json"
|
||||
progress_file.write_text("not valid json", encoding="utf-8")
|
||||
|
||||
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||
|
||||
result = await manager.check_pending_models(["lora"])
|
||||
|
||||
# Should still succeed, treating all as unprocessed
|
||||
assert result["success"] is True
|
||||
assert result["total_models"] == 1
|
||||
assert result["pending_count"] == 1
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def settings_manager():
|
||||
return get_settings_manager()
|
||||
Reference in New Issue
Block a user