mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #524 from willmiao/codex/add-async-tests-for-websocketmanager
Add websocket broadcast and usage stats tests
This commit is contained in:
@@ -21,6 +21,32 @@ def manager():
|
||||
return WebSocketManager()
|
||||
|
||||
|
||||
async def test_broadcast_init_progress_replays_cached_payloads(manager):
|
||||
first_payload = {"pageType": "loras", "progress": 15}
|
||||
second_payload = {"scanner_type": "loras", "progress": 45}
|
||||
|
||||
await manager.broadcast_init_progress(first_payload)
|
||||
await manager.broadcast_init_progress(second_payload)
|
||||
|
||||
replay_socket = DummyWebSocket()
|
||||
await manager._send_cached_init_progress(replay_socket)
|
||||
|
||||
assert replay_socket.messages == [
|
||||
{
|
||||
"pageType": "loras",
|
||||
"progress": 15,
|
||||
"stage": "processing",
|
||||
"details": "Processing...",
|
||||
},
|
||||
{
|
||||
"scanner_type": "loras",
|
||||
"progress": 45,
|
||||
"stage": "processing",
|
||||
"details": "Processing...",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def test_broadcast_init_progress_adds_defaults(manager):
|
||||
ws = DummyWebSocket()
|
||||
manager._init_websockets.add(ws)
|
||||
@@ -66,6 +92,18 @@ async def test_broadcast_download_progress_tracks_state(manager):
|
||||
assert manager.get_download_progress(download_id)["progress"] == 55
|
||||
|
||||
|
||||
async def test_broadcast_download_progress_to_multiple_updates(manager):
|
||||
ws = DummyWebSocket()
|
||||
download_id = "batch"
|
||||
manager._download_websockets[download_id] = ws
|
||||
|
||||
await manager.broadcast_download_progress(download_id, {"progress": 10})
|
||||
await manager.broadcast_download_progress(download_id, {"progress": 75})
|
||||
|
||||
assert ws.messages == [{"progress": 10}, {"progress": 75}]
|
||||
assert manager.get_download_progress(download_id)["progress"] == 75
|
||||
|
||||
|
||||
async def test_broadcast_download_progress_missing_socket(manager):
|
||||
await manager.broadcast_download_progress("missing", {"progress": 30})
|
||||
# Progress should be stored even without a live websocket
|
||||
@@ -84,6 +122,18 @@ async def test_auto_organize_progress_helpers(manager):
|
||||
assert manager.is_auto_organize_running() is False
|
||||
|
||||
|
||||
async def test_broadcast_auto_organize_progress_notifies_all_clients(manager):
|
||||
ws_primary = DummyWebSocket()
|
||||
ws_secondary = DummyWebSocket()
|
||||
manager._websockets.update({ws_primary, ws_secondary})
|
||||
|
||||
payload = {"status": "started", "progress": 5}
|
||||
await manager.broadcast_auto_organize_progress(payload)
|
||||
|
||||
assert ws_primary.messages == [payload]
|
||||
assert ws_secondary.messages == [payload]
|
||||
|
||||
|
||||
def test_cleanup_old_downloads(manager):
|
||||
now = datetime.now()
|
||||
manager._download_progress = {
|
||||
|
||||
153
tests/utils/test_usage_stats.py
Normal file
153
tests/utils/test_usage_stats.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
|
||||
|
||||
from py.metadata_collector.constants import LORAS, MODELS
|
||||
from py.services.service_registry import ServiceRegistry
|
||||
from py.utils import usage_stats as usage_stats_module
|
||||
from py.utils.usage_stats import UsageStats
|
||||
|
||||
|
||||
async def _finalize_usage_stats(tasks):
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
UsageStats._instance = None
|
||||
|
||||
|
||||
def _prepare_usage_stats(tmp_path: Path, monkeypatch: pytest.MonkeyPatch, *, sleep_override=None):
|
||||
UsageStats._instance = None
|
||||
stats_root = tmp_path / "loras"
|
||||
stats_root.mkdir(parents=True, exist_ok=True)
|
||||
monkeypatch.setattr(usage_stats_module.config, "loras_roots", [str(stats_root)])
|
||||
|
||||
created_tasks = []
|
||||
real_create_task = usage_stats_module.asyncio.create_task
|
||||
|
||||
def _track_task(coro):
|
||||
task = real_create_task(coro)
|
||||
created_tasks.append(task)
|
||||
return task
|
||||
|
||||
monkeypatch.setattr(usage_stats_module.asyncio, "create_task", _track_task)
|
||||
|
||||
if sleep_override is not None:
|
||||
monkeypatch.setattr(usage_stats_module.asyncio, "sleep", sleep_override)
|
||||
|
||||
stats = UsageStats()
|
||||
return stats, created_tasks, stats_root
|
||||
|
||||
|
||||
async def test_usage_stats_converts_legacy_format(tmp_path, monkeypatch):
|
||||
legacy_stats = {
|
||||
"checkpoints": {"hash1": 3},
|
||||
"loras": {"hash2": 5},
|
||||
"total_executions": 9,
|
||||
"last_save_time": 123.0,
|
||||
}
|
||||
|
||||
UsageStats._instance = None
|
||||
stats_root = tmp_path / "loras"
|
||||
stats_root.mkdir(parents=True, exist_ok=True)
|
||||
monkeypatch.setattr(usage_stats_module.config, "loras_roots", [str(stats_root)])
|
||||
|
||||
stats_path = stats_root / UsageStats.STATS_FILENAME
|
||||
stats_path.write_text(json.dumps(legacy_stats), encoding="utf-8")
|
||||
|
||||
created_tasks = []
|
||||
real_create_task = usage_stats_module.asyncio.create_task
|
||||
|
||||
def _track_task(coro):
|
||||
task = real_create_task(coro)
|
||||
created_tasks.append(task)
|
||||
return task
|
||||
|
||||
monkeypatch.setattr(usage_stats_module.asyncio, "create_task", _track_task)
|
||||
|
||||
stats = UsageStats()
|
||||
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
converted = stats.stats
|
||||
|
||||
assert converted["total_executions"] == 9
|
||||
assert converted["checkpoints"]["hash1"] == {"total": 3, "history": {today: 3}}
|
||||
assert converted["loras"]["hash2"] == {"total": 5, "history": {today: 5}}
|
||||
|
||||
backup_path = stats_path.with_suffix(stats_path.suffix + UsageStats.BACKUP_SUFFIX)
|
||||
assert backup_path.exists()
|
||||
|
||||
await _finalize_usage_stats(created_tasks)
|
||||
|
||||
|
||||
async def test_usage_stats_save_stats_persists_file(tmp_path, monkeypatch):
|
||||
stats, tasks, stats_root = _prepare_usage_stats(tmp_path, monkeypatch)
|
||||
stats.stats["total_executions"] = 4
|
||||
|
||||
saved = await stats.save_stats(force=True)
|
||||
assert saved is True
|
||||
|
||||
stats_path = stats_root / UsageStats.STATS_FILENAME
|
||||
persisted = json.loads(stats_path.read_text(encoding="utf-8"))
|
||||
assert persisted["total_executions"] == 4
|
||||
assert persisted["last_save_time"] == stats.stats["last_save_time"]
|
||||
|
||||
await _finalize_usage_stats(tasks)
|
||||
|
||||
|
||||
async def test_usage_stats_background_processor_handles_pending_prompts(tmp_path, monkeypatch):
|
||||
real_sleep = usage_stats_module.asyncio.sleep
|
||||
|
||||
async def fast_sleep(_seconds):
|
||||
await real_sleep(0.01)
|
||||
|
||||
stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch, sleep_override=fast_sleep)
|
||||
|
||||
metadata_calls = []
|
||||
metadata_payload = {
|
||||
MODELS: {
|
||||
"1": {"type": "checkpoint", "name": "model.ckpt"},
|
||||
},
|
||||
LORAS: {
|
||||
"2": {"lora_list": [{"name": "awesome_lora.safetensors"}]},
|
||||
},
|
||||
}
|
||||
|
||||
class FakeMetadataRegistry:
|
||||
def get_metadata(self, prompt_id):
|
||||
metadata_calls.append(prompt_id)
|
||||
return metadata_payload
|
||||
|
||||
monkeypatch.setattr(usage_stats_module, "MetadataRegistry", FakeMetadataRegistry, raising=False)
|
||||
|
||||
checkpoint_scanner = SimpleNamespace(get_hash_by_filename=lambda name: {"model": "ckpt-hash"}.get(name))
|
||||
lora_scanner = SimpleNamespace(get_hash_by_filename=lambda name: {"awesome_lora.safetensors": "lora-hash"}.get(name))
|
||||
|
||||
monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner))
|
||||
monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner))
|
||||
|
||||
save_spy = AsyncMock(return_value=True)
|
||||
monkeypatch.setattr(stats, "save_stats", save_spy)
|
||||
|
||||
stats.pending_prompt_ids.add("prompt-42")
|
||||
|
||||
await real_sleep(0.05)
|
||||
|
||||
assert metadata_calls == ["prompt-42"]
|
||||
assert stats.pending_prompt_ids == set()
|
||||
assert stats.stats["total_executions"] == 1
|
||||
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
assert stats.stats["checkpoints"]["ckpt-hash"]["history"][today] == 1
|
||||
assert stats.stats["loras"]["lora-hash"]["history"][today] == 1
|
||||
|
||||
await _finalize_usage_stats(tasks)
|
||||
Reference in New Issue
Block a user