diff --git a/tests/services/test_websocket_manager.py b/tests/services/test_websocket_manager.py index 9fc5250e..9f512aad 100644 --- a/tests/services/test_websocket_manager.py +++ b/tests/services/test_websocket_manager.py @@ -21,6 +21,32 @@ def manager(): return WebSocketManager() +async def test_broadcast_init_progress_replays_cached_payloads(manager): + first_payload = {"pageType": "loras", "progress": 15} + second_payload = {"scanner_type": "loras", "progress": 45} + + await manager.broadcast_init_progress(first_payload) + await manager.broadcast_init_progress(second_payload) + + replay_socket = DummyWebSocket() + await manager._send_cached_init_progress(replay_socket) + + assert replay_socket.messages == [ + { + "pageType": "loras", + "progress": 15, + "stage": "processing", + "details": "Processing...", + }, + { + "scanner_type": "loras", + "progress": 45, + "stage": "processing", + "details": "Processing...", + }, + ] + + async def test_broadcast_init_progress_adds_defaults(manager): ws = DummyWebSocket() manager._init_websockets.add(ws) @@ -66,6 +92,18 @@ async def test_broadcast_download_progress_tracks_state(manager): assert manager.get_download_progress(download_id)["progress"] == 55 +async def test_broadcast_download_progress_to_multiple_updates(manager): + ws = DummyWebSocket() + download_id = "batch" + manager._download_websockets[download_id] = ws + + await manager.broadcast_download_progress(download_id, {"progress": 10}) + await manager.broadcast_download_progress(download_id, {"progress": 75}) + + assert ws.messages == [{"progress": 10}, {"progress": 75}] + assert manager.get_download_progress(download_id)["progress"] == 75 + + async def test_broadcast_download_progress_missing_socket(manager): await manager.broadcast_download_progress("missing", {"progress": 30}) # Progress should be stored even without a live websocket @@ -84,6 +122,18 @@ async def test_auto_organize_progress_helpers(manager): assert manager.is_auto_organize_running() is False +async def test_broadcast_auto_organize_progress_notifies_all_clients(manager): + ws_primary = DummyWebSocket() + ws_secondary = DummyWebSocket() + manager._websockets.update({ws_primary, ws_secondary}) + + payload = {"status": "started", "progress": 5} + await manager.broadcast_auto_organize_progress(payload) + + assert ws_primary.messages == [payload] + assert ws_secondary.messages == [payload] + + def test_cleanup_old_downloads(manager): now = datetime.now() manager._download_progress = { diff --git a/tests/utils/test_usage_stats.py b/tests/utils/test_usage_stats.py new file mode 100644 index 00000000..213d47cb --- /dev/null +++ b/tests/utils/test_usage_stats.py @@ -0,0 +1,153 @@ +import asyncio +import contextlib +import json +import os +from datetime import datetime +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1") + +from py.metadata_collector.constants import LORAS, MODELS +from py.services.service_registry import ServiceRegistry +from py.utils import usage_stats as usage_stats_module +from py.utils.usage_stats import UsageStats + + +async def _finalize_usage_stats(tasks): + for task in tasks: + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + UsageStats._instance = None + + +def _prepare_usage_stats(tmp_path: Path, monkeypatch: pytest.MonkeyPatch, *, sleep_override=None): + UsageStats._instance = None + stats_root = tmp_path / "loras" + stats_root.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(usage_stats_module.config, "loras_roots", [str(stats_root)]) + + created_tasks = [] + real_create_task = usage_stats_module.asyncio.create_task + + def _track_task(coro): + task = real_create_task(coro) + created_tasks.append(task) + return task + + monkeypatch.setattr(usage_stats_module.asyncio, "create_task", _track_task) + + if sleep_override is not None: + monkeypatch.setattr(usage_stats_module.asyncio, "sleep", sleep_override) + + stats = UsageStats() + return stats, created_tasks, stats_root + + +async def test_usage_stats_converts_legacy_format(tmp_path, monkeypatch): + legacy_stats = { + "checkpoints": {"hash1": 3}, + "loras": {"hash2": 5}, + "total_executions": 9, + "last_save_time": 123.0, + } + + UsageStats._instance = None + stats_root = tmp_path / "loras" + stats_root.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(usage_stats_module.config, "loras_roots", [str(stats_root)]) + + stats_path = stats_root / UsageStats.STATS_FILENAME + stats_path.write_text(json.dumps(legacy_stats), encoding="utf-8") + + created_tasks = [] + real_create_task = usage_stats_module.asyncio.create_task + + def _track_task(coro): + task = real_create_task(coro) + created_tasks.append(task) + return task + + monkeypatch.setattr(usage_stats_module.asyncio, "create_task", _track_task) + + stats = UsageStats() + + today = datetime.now().strftime("%Y-%m-%d") + converted = stats.stats + + assert converted["total_executions"] == 9 + assert converted["checkpoints"]["hash1"] == {"total": 3, "history": {today: 3}} + assert converted["loras"]["hash2"] == {"total": 5, "history": {today: 5}} + + backup_path = stats_path.with_suffix(stats_path.suffix + UsageStats.BACKUP_SUFFIX) + assert backup_path.exists() + + await _finalize_usage_stats(created_tasks) + + +async def test_usage_stats_save_stats_persists_file(tmp_path, monkeypatch): + stats, tasks, stats_root = _prepare_usage_stats(tmp_path, monkeypatch) + stats.stats["total_executions"] = 4 + + saved = await stats.save_stats(force=True) + assert saved is True + + stats_path = stats_root / UsageStats.STATS_FILENAME + persisted = json.loads(stats_path.read_text(encoding="utf-8")) + assert persisted["total_executions"] == 4 + assert persisted["last_save_time"] == stats.stats["last_save_time"] + + await _finalize_usage_stats(tasks) + + +async def test_usage_stats_background_processor_handles_pending_prompts(tmp_path, monkeypatch): + real_sleep = usage_stats_module.asyncio.sleep + + async def fast_sleep(_seconds): + await real_sleep(0.01) + + stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch, sleep_override=fast_sleep) + + metadata_calls = [] + metadata_payload = { + MODELS: { + "1": {"type": "checkpoint", "name": "model.ckpt"}, + }, + LORAS: { + "2": {"lora_list": [{"name": "awesome_lora.safetensors"}]}, + }, + } + + class FakeMetadataRegistry: + def get_metadata(self, prompt_id): + metadata_calls.append(prompt_id) + return metadata_payload + + monkeypatch.setattr(usage_stats_module, "MetadataRegistry", FakeMetadataRegistry) + + checkpoint_scanner = SimpleNamespace(get_hash_by_filename=lambda name: {"model": "ckpt-hash"}.get(name)) + lora_scanner = SimpleNamespace(get_hash_by_filename=lambda name: {"awesome_lora.safetensors": "lora-hash"}.get(name)) + + monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner)) + monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner)) + + save_spy = AsyncMock(return_value=True) + monkeypatch.setattr(stats, "save_stats", save_spy) + + stats.pending_prompt_ids.add("prompt-42") + + await real_sleep(0.05) + + assert metadata_calls == ["prompt-42"] + assert stats.pending_prompt_ids == set() + assert stats.stats["total_executions"] == 1 + + today = datetime.now().strftime("%Y-%m-%d") + assert stats.stats["checkpoints"]["ckpt-hash"]["history"][today] == 1 + assert stats.stats["loras"]["lora-hash"]["history"][today] == 1 + + await _finalize_usage_stats(tasks)