From 095320ef72b2d92d6d3262d9dc1fc994272af7e1 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Thu, 25 Sep 2025 09:40:25 +0800 Subject: [PATCH] test(routes): tidy lora route test imports --- tests/routes/test_embedding_routes.py | 50 +++++ tests/routes/test_lora_routes.py | 213 +++++++++++++++++++ tests/services/test_civitai_client.py | 222 ++++++++++++++++++++ tests/services/test_download_manager.py | 247 +++++++++++++++++++++++ tests/services/test_settings_manager.py | 61 ++++++ tests/services/test_websocket_manager.py | 84 ++++++++ 6 files changed, 877 insertions(+) create mode 100644 tests/routes/test_embedding_routes.py create mode 100644 tests/routes/test_lora_routes.py create mode 100644 tests/services/test_civitai_client.py create mode 100644 tests/services/test_download_manager.py create mode 100644 tests/services/test_settings_manager.py create mode 100644 tests/services/test_websocket_manager.py diff --git a/tests/routes/test_embedding_routes.py b/tests/routes/test_embedding_routes.py new file mode 100644 index 00000000..fc1782a0 --- /dev/null +++ b/tests/routes/test_embedding_routes.py @@ -0,0 +1,50 @@ +import json + +import pytest + +from py.routes.embedding_routes import EmbeddingRoutes + + +class DummyRequest: + def __init__(self, *, match_info=None): + self.match_info = match_info or {} + + +class StubEmbeddingService: + def __init__(self): + self.info = {} + + async def get_model_info_by_name(self, name): + value = self.info.get(name) + if isinstance(value, Exception): + raise value + return value + + +@pytest.fixture +def routes(): + handler = EmbeddingRoutes() + handler.service = StubEmbeddingService() + return handler + + +async def test_get_embedding_info_success(routes): + routes.service.info["demo"] = {"name": "demo"} + response = await routes.get_embedding_info(DummyRequest(match_info={"name": "demo"})) + payload = json.loads(response.text) + assert payload == {"name": "demo"} + + +async def test_get_embedding_info_missing(routes): + response = await routes.get_embedding_info(DummyRequest(match_info={"name": "missing"})) + payload = json.loads(response.text) + assert response.status == 404 + assert payload == {"error": "Embedding not found"} + + +async def test_get_embedding_info_error(routes): + routes.service.info["demo"] = RuntimeError("boom") + response = await routes.get_embedding_info(DummyRequest(match_info={"name": "demo"})) + payload = json.loads(response.text) + assert response.status == 500 + assert payload == {"error": "boom"} diff --git a/tests/routes/test_lora_routes.py b/tests/routes/test_lora_routes.py new file mode 100644 index 00000000..2b447987 --- /dev/null +++ b/tests/routes/test_lora_routes.py @@ -0,0 +1,213 @@ +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from py.routes.lora_routes import LoraRoutes +from server import PromptServer + + +class DummyRequest: + def __init__(self, *, query=None, match_info=None, json_data=None): + self.query = query or {} + self.match_info = match_info or {} + self._json_data = json_data or {} + + async def json(self): + return self._json_data + + +class StubLoraService: + def __init__(self): + self.notes = {} + self.trigger_words = {} + self.usage_tips = {} + self.previews = {} + self.civitai = {} + + async def get_lora_notes(self, name): + return self.notes.get(name) + + async def get_lora_trigger_words(self, name): + return self.trigger_words.get(name, []) + + async def get_lora_usage_tips_by_relative_path(self, path): + return self.usage_tips.get(path) + + async def get_lora_preview_url(self, name): + return self.previews.get(name) + + async def get_lora_civitai_url(self, name): + return self.civitai.get(name, {"civitai_url": ""}) + + +@pytest.fixture +def routes(): + handler = LoraRoutes() + handler.service = StubLoraService() + return handler + + +async def test_get_lora_notes_success(routes): + routes.service.notes["demo"] = "Great notes" + request = DummyRequest(query={"name": "demo"}) + + response = await routes.get_lora_notes(request) + payload = json.loads(response.text) + + assert payload == {"success": True, "notes": "Great notes"} + + +async def test_get_lora_notes_missing_name(routes): + response = await routes.get_lora_notes(DummyRequest()) + assert response.status == 400 + assert response.text == "Lora file name is required" + + +async def test_get_lora_notes_not_found(routes): + response = await routes.get_lora_notes(DummyRequest(query={"name": "missing"})) + payload = json.loads(response.text) + assert response.status == 404 + assert payload == {"success": False, "error": "LoRA not found in cache"} + + +async def test_get_lora_notes_error(routes, monkeypatch): + async def failing(*_args, **_kwargs): + raise RuntimeError("boom") + + routes.service.get_lora_notes = failing + + response = await routes.get_lora_notes(DummyRequest(query={"name": "demo"})) + payload = json.loads(response.text) + + assert response.status == 500 + assert payload["success"] is False + assert payload["error"] == "boom" + + +async def test_get_lora_trigger_words_success(routes): + routes.service.trigger_words["demo"] = ["trigger"] + response = await routes.get_lora_trigger_words(DummyRequest(query={"name": "demo"})) + payload = json.loads(response.text) + assert payload == {"success": True, "trigger_words": ["trigger"]} + + +async def test_get_lora_trigger_words_missing_name(routes): + response = await routes.get_lora_trigger_words(DummyRequest()) + assert response.status == 400 + + +async def test_get_lora_trigger_words_error(routes): + async def failing(*_args, **_kwargs): + raise RuntimeError("fail") + + routes.service.get_lora_trigger_words = failing + + response = await routes.get_lora_trigger_words(DummyRequest(query={"name": "demo"})) + payload = json.loads(response.text) + assert response.status == 500 + assert payload["success"] is False + + +async def test_get_usage_tips_success(routes): + routes.service.usage_tips["path"] = "tips" + response = await routes.get_lora_usage_tips_by_path(DummyRequest(query={"relative_path": "path"})) + payload = json.loads(response.text) + assert payload == {"success": True, "usage_tips": "tips"} + + +async def test_get_usage_tips_missing_param(routes): + response = await routes.get_lora_usage_tips_by_path(DummyRequest()) + assert response.status == 400 + + +async def test_get_usage_tips_error(routes): + async def failing(*_args, **_kwargs): + raise RuntimeError("bad") + + routes.service.get_lora_usage_tips_by_relative_path = failing + response = await routes.get_lora_usage_tips_by_path(DummyRequest(query={"relative_path": "path"})) + payload = json.loads(response.text) + assert response.status == 500 + assert payload["success"] is False + + +async def test_get_preview_url_success(routes): + routes.service.previews["demo"] = "http://preview" + response = await routes.get_lora_preview_url(DummyRequest(query={"name": "demo"})) + payload = json.loads(response.text) + assert payload == {"success": True, "preview_url": "http://preview"} + + +async def test_get_preview_url_missing(routes): + response = await routes.get_lora_preview_url(DummyRequest()) + assert response.status == 400 + + +async def test_get_preview_url_not_found(routes): + response = await routes.get_lora_preview_url(DummyRequest(query={"name": "missing"})) + payload = json.loads(response.text) + assert response.status == 404 + assert payload["success"] is False + + +async def test_get_civitai_url_success(routes): + routes.service.civitai["demo"] = {"civitai_url": "https://civitai.com"} + response = await routes.get_lora_civitai_url(DummyRequest(query={"name": "demo"})) + payload = json.loads(response.text) + assert payload == {"success": True, "civitai_url": "https://civitai.com"} + + +async def test_get_civitai_url_missing(routes): + response = await routes.get_lora_civitai_url(DummyRequest()) + assert response.status == 400 + + +async def test_get_civitai_url_not_found(routes): + response = await routes.get_lora_civitai_url(DummyRequest(query={"name": "missing"})) + payload = json.loads(response.text) + assert response.status == 404 + assert payload["success"] is False + + +async def test_get_civitai_url_error(routes): + async def failing(*_args, **_kwargs): + raise RuntimeError("oops") + + routes.service.get_lora_civitai_url = failing + response = await routes.get_lora_civitai_url(DummyRequest(query={"name": "demo"})) + payload = json.loads(response.text) + assert response.status == 500 + assert payload["success"] is False + + +async def test_get_trigger_words_broadcasts(monkeypatch, routes): + send_mock = MagicMock() + PromptServer.instance = SimpleNamespace(send_sync=send_mock) + + monkeypatch.setattr("py.routes.lora_routes.get_lora_info", lambda name: (f"path/{name}", [f"trigger-{name}"])) + + request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": ["node"]}) + + response = await routes.get_trigger_words(request) + payload = json.loads(response.text) + + assert payload == {"success": True} + send_mock.assert_called_once_with( + "trigger_word_update", + {"id": "node", "message": "trigger-one"}, + ) + + +async def test_get_trigger_words_error(monkeypatch, routes): + async def failing_json(): + raise RuntimeError("bad json") + + request = DummyRequest(json_data=None) + request.json = failing_json + + response = await routes.get_trigger_words(request) + payload = json.loads(response.text) + assert response.status == 500 + assert payload["success"] is False diff --git a/tests/services/test_civitai_client.py b/tests/services/test_civitai_client.py new file mode 100644 index 00000000..f5283443 --- /dev/null +++ b/tests/services/test_civitai_client.py @@ -0,0 +1,222 @@ +from unittest.mock import AsyncMock + +import pytest + +from py.services import civitai_client as civitai_client_module +from py.services.civitai_client import CivitaiClient +from py.services.model_metadata_provider import ModelMetadataProviderManager + + +class DummyDownloader: + def __init__(self): + self.download_calls = [] + self.memory_calls = [] + self.request_calls = [] + + async def download_file(self, **kwargs): + self.download_calls.append(kwargs) + return True, kwargs["save_path"] + + async def download_to_memory(self, url, use_auth=False): + self.memory_calls.append({"url": url, "use_auth": use_auth}) + return True, b"bytes", {"content-type": "image/jpeg"} + + async def make_request(self, method, url, use_auth=True): + self.request_calls.append({"method": method, "url": url, "use_auth": use_auth}) + return True, {} + + +@pytest.fixture(autouse=True) +def reset_singletons(): + CivitaiClient._instance = None + ModelMetadataProviderManager._instance = None + yield + CivitaiClient._instance = None + ModelMetadataProviderManager._instance = None + + +@pytest.fixture +def downloader(monkeypatch): + instance = DummyDownloader() + monkeypatch.setattr(civitai_client_module, "get_downloader", AsyncMock(return_value=instance)) + return instance + + +async def test_download_file_uses_downloader(tmp_path, downloader): + client = await CivitaiClient.get_instance() + save_dir = tmp_path / "files" + save_dir.mkdir() + + success, path = await client.download_file( + url="https://example.invalid/model", + save_dir=str(save_dir), + default_filename="model.safetensors", + ) + + assert success is True + assert path == str(save_dir / "model.safetensors") + assert downloader.download_calls[0]["use_auth"] is True + + +async def test_get_model_by_hash_enriches_metadata(monkeypatch, downloader): + version_payload = { + "modelId": 123, + "model": {"description": "", "tags": []}, + "creator": {}, + } + model_payload = {"description": "desc", "tags": ["tag"], "creator": {"username": "user"}} + + async def fake_make_request(method, url, use_auth=True): + if url.endswith("by-hash/hash"): + return True, version_payload.copy() + if url.endswith("/models/123"): + return True, model_payload + return False, "unexpected" + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result, error = await client.get_model_by_hash("hash") + + assert error is None + assert result["model"]["description"] == "desc" + assert result["model"]["tags"] == ["tag"] + assert result["creator"] == {"username": "user"} + + +async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader): + async def fake_make_request(method, url, use_auth=True): + return False, "not found" + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result, error = await client.get_model_by_hash("missing") + + assert result is None + assert error == "Model not found" + + +async def test_download_preview_image_writes_file(tmp_path, downloader): + client = await CivitaiClient.get_instance() + target = tmp_path / "preview" / "image.jpg" + + success = await client.download_preview_image("https://example.invalid/preview", str(target)) + + assert success is True + assert target.exists() + assert target.read_bytes() == b"bytes" + + +async def test_download_preview_image_failure(monkeypatch, downloader): + async def failing_download(url, use_auth=False): + return False, b"", {} + + downloader.download_to_memory = failing_download + + client = await CivitaiClient.get_instance() + target = "/tmp/ignored.jpg" + + success = await client.download_preview_image("https://example.invalid/preview", target) + + assert success is False + + +async def test_get_model_versions_success(monkeypatch, downloader): + async def fake_make_request(method, url, use_auth=True): + return True, {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"} + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result = await client.get_model_versions("123") + + assert result == {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"} + + +async def test_get_model_version_by_version_id(monkeypatch, downloader): + async def fake_make_request(method, url, use_auth=True): + if url.endswith("/model-versions/7"): + return True, { + "modelId": 321, + "model": {"description": ""}, + "files": [], + } + if url.endswith("/models/321"): + return True, {"description": "desc", "tags": ["tag"], "creator": {"username": "user"}} + return False, "unexpected" + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result = await client.get_model_version(version_id=7) + + assert result["model"]["description"] == "desc" + assert result["model"]["tags"] == ["tag"] + assert result["creator"] == {"username": "user"} + + +async def test_get_model_version_requires_identifier(monkeypatch, downloader): + client = await CivitaiClient.get_instance() + result = await client.get_model_version() + assert result is None + + +async def test_get_model_version_info_handles_not_found(monkeypatch, downloader): + async def fake_make_request(method, url, use_auth=True): + return False, "not found" + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result, error = await client.get_model_version_info("55") + + assert result is None + assert error == "Model not found" + + +async def test_get_model_version_info_success(monkeypatch, downloader): + expected = {"id": 55} + + async def fake_make_request(method, url, use_auth=True): + return True, expected + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result, error = await client.get_model_version_info("55") + + assert result == expected + assert error is None + + +async def test_get_image_info_returns_first_item(monkeypatch, downloader): + async def fake_make_request(method, url, use_auth=True): + return True, {"items": [{"id": 1}, {"id": 2}]} + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result = await client.get_image_info("42") + + assert result == {"id": 1} + + +async def test_get_image_info_handles_missing(monkeypatch, downloader): + async def fake_make_request(method, url, use_auth=True): + return True, {"items": []} + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result = await client.get_image_info("42") + + assert result is None diff --git a/tests/services/test_download_manager.py b/tests/services/test_download_manager.py new file mode 100644 index 00000000..80701e7d --- /dev/null +++ b/tests/services/test_download_manager.py @@ -0,0 +1,247 @@ +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from py.services.download_manager import DownloadManager +from py.services import download_manager +from py.services.service_registry import ServiceRegistry +from py.services.settings_manager import settings + + +@pytest.fixture(autouse=True) +def reset_download_manager(): + """Ensure each test operates on a fresh singleton.""" + DownloadManager._instance = None + yield + DownloadManager._instance = None + + +@pytest.fixture(autouse=True) +def isolate_settings(monkeypatch, tmp_path): + """Point settings writes at a temporary directory to avoid touching real files.""" + default_settings = settings._get_default_settings() + default_settings.update( + { + "default_lora_root": str(tmp_path), + "default_checkpoint_root": str(tmp_path / "checkpoints"), + "default_embedding_root": str(tmp_path / "embeddings"), + "download_path_templates": { + "lora": "{base_model}/{first_tag}", + "checkpoint": "{base_model}/{first_tag}", + "embedding": "{base_model}/{first_tag}", + }, + "base_model_path_mappings": {"BaseModel": "MappedModel"}, + } + ) + monkeypatch.setattr(settings, "settings", default_settings) + monkeypatch.setattr(type(settings), "_save_settings", lambda self: None) + + +@pytest.fixture(autouse=True) +def stub_metadata(monkeypatch): + class _StubMetadata: + def __init__(self, save_path: str): + self.file_path = save_path + self.sha256 = "sha256" + self.file_name = Path(save_path).stem + + def _factory(save_path: str): + return _StubMetadata(save_path) + + def _make_class(): + @staticmethod + def from_civitai_info(_version_info, _file_info, save_path): + return _factory(save_path) + + return type("StubMetadata", (), {"from_civitai_info": from_civitai_info}) + + stub_class = _make_class() + monkeypatch.setattr(download_manager, "LoraMetadata", stub_class) + monkeypatch.setattr(download_manager, "CheckpointMetadata", stub_class) + monkeypatch.setattr(download_manager, "EmbeddingMetadata", stub_class) + + +class DummyScanner: + def __init__(self, exists: bool = False): + self.exists = exists + self.calls = [] + + async def check_model_version_exists(self, version_id): + self.calls.append(version_id) + return self.exists + + +@pytest.fixture +def scanners(monkeypatch): + lora_scanner = DummyScanner() + checkpoint_scanner = DummyScanner() + embedding_scanner = DummyScanner() + + monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner)) + monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner)) + monkeypatch.setattr(ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=embedding_scanner)) + + return SimpleNamespace( + lora=lora_scanner, + checkpoint=checkpoint_scanner, + embedding=embedding_scanner, + ) + + +@pytest.fixture +def metadata_provider(monkeypatch): + class DummyProvider: + def __init__(self): + self.calls = [] + + async def get_model_version(self, model_id, model_version_id): + self.calls.append((model_id, model_version_id)) + return { + "id": 42, + "model": {"type": "LoRA", "tags": ["fantasy"]}, + "baseModel": "BaseModel", + "creator": {"username": "Author"}, + "files": [ + { + "primary": True, + "downloadUrl": "https://example.invalid/file.safetensors", + "name": "file.safetensors", + } + ], + } + + provider = DummyProvider() + monkeypatch.setattr( + download_manager, + "get_default_metadata_provider", + AsyncMock(return_value=provider), + ) + return provider + + +@pytest.fixture(autouse=True) +def noop_cleanup(monkeypatch): + async def _cleanup(self, task_id): + if task_id in self._active_downloads: + self._active_downloads[task_id]["cleaned"] = True + + monkeypatch.setattr(DownloadManager, "_cleanup_download_record", _cleanup) + + +async def test_download_requires_identifier(): + manager = DownloadManager() + result = await manager.download_from_civitai() + assert result == { + "success": False, + "error": "Either model_id or model_version_id must be provided", + } + + +async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata_provider, tmp_path): + manager = DownloadManager() + + captured = {} + + async def fake_execute_download( + self, + *, + download_url, + save_dir, + metadata, + version_info, + relative_path, + progress_callback, + model_type, + download_id, + ): + captured.update( + { + "download_url": download_url, + "save_dir": Path(save_dir), + "relative_path": relative_path, + "progress_callback": progress_callback, + "model_type": model_type, + "download_id": download_id, + "metadata_path": metadata.file_path, + } + ) + return {"success": True} + + monkeypatch.setattr(DownloadManager, "_execute_download", fake_execute_download, raising=False) + + result = await manager.download_from_civitai( + model_version_id=99, + save_dir=str(tmp_path), + use_default_paths=True, + progress_callback=None, + source=None, + ) + + assert result["success"] is True + assert "download_id" in result + assert manager._download_tasks == {} + assert manager._active_downloads[result["download_id"]]["status"] == "completed" + + assert captured["relative_path"] == "MappedModel/fantasy" + expected_dir = Path(settings.get("default_lora_root")) / "MappedModel" / "fantasy" + assert captured["save_dir"] == expected_dir + assert captured["model_type"] == "lora" + + +async def test_download_aborts_when_version_exists(monkeypatch, scanners, metadata_provider): + scanners.lora.exists = True + + manager = DownloadManager() + + execute_mock = AsyncMock(return_value={"success": True}) + monkeypatch.setattr(DownloadManager, "_execute_download", execute_mock) + + result = await manager.download_from_civitai(model_version_id=101, save_dir="/tmp") + + assert result["success"] is False + assert result["error"] == "Model version already exists in lora library" + assert "download_id" in result + assert execute_mock.await_count == 0 + + +async def test_download_handles_metadata_errors(monkeypatch, scanners): + async def failing_provider(*_args, **_kwargs): + return None + + monkeypatch.setattr( + download_manager, + "get_default_metadata_provider", + AsyncMock(return_value=SimpleNamespace(get_model_version=AsyncMock(return_value=None))), + ) + + manager = DownloadManager() + + result = await manager.download_from_civitai(model_version_id=5, save_dir="/tmp") + + assert result["success"] is False + assert result["error"] == "Failed to fetch model metadata" + assert "download_id" in result + + +async def test_download_rejects_unsupported_model_type(monkeypatch, scanners): + class Provider: + async def get_model_version(self, *_args, **_kwargs): + return { + "model": {"type": "Unsupported", "tags": []}, + "files": [], + } + + monkeypatch.setattr( + download_manager, + "get_default_metadata_provider", + AsyncMock(return_value=Provider()), + ) + + manager = DownloadManager() + + result = await manager.download_from_civitai(model_version_id=5, save_dir="/tmp") + + assert result["success"] is False + assert result["error"].startswith("Model type") diff --git a/tests/services/test_settings_manager.py b/tests/services/test_settings_manager.py new file mode 100644 index 00000000..7e547680 --- /dev/null +++ b/tests/services/test_settings_manager.py @@ -0,0 +1,61 @@ +import json + +import pytest + +from py.services.settings_manager import SettingsManager + + +@pytest.fixture +def manager(tmp_path, monkeypatch): + monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None) + mgr = SettingsManager() + mgr.settings_file = str(tmp_path / "settings.json") + return mgr + + +def test_environment_variable_overrides_settings(tmp_path, monkeypatch): + monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None) + monkeypatch.setenv("CIVITAI_API_KEY", "secret") + mgr = SettingsManager() + mgr.settings_file = str(tmp_path / "settings.json") + + assert mgr.get("civitai_api_key") == "secret" + + +def test_download_path_template_parses_json_string(manager): + templates = {"lora": "{author}", "checkpoint": "{author}", "embedding": "{author}"} + manager.settings["download_path_templates"] = json.dumps(templates) + + template = manager.get_download_path_template("lora") + + assert template == "{author}" + assert isinstance(manager.settings["download_path_templates"], dict) + + +def test_download_path_template_invalid_json(manager): + manager.settings["download_path_templates"] = "not json" + + template = manager.get_download_path_template("checkpoint") + + assert template == "{base_model}/{first_tag}" + assert manager.settings["download_path_templates"]["lora"] == "{base_model}/{first_tag}" + + +def test_auto_set_default_roots(manager): + manager.settings["folder_paths"] = { + "loras": ["/loras"], + "checkpoints": ["/checkpoints"], + "embeddings": ["/embeddings"], + } + + manager._auto_set_default_roots() + + assert manager.get("default_lora_root") == "/loras" + assert manager.get("default_checkpoint_root") == "/checkpoints" + assert manager.get("default_embedding_root") == "/embeddings" + + +def test_delete_setting(manager): + manager.set("example", 1) + manager.delete("example") + assert manager.get("example") is None diff --git a/tests/services/test_websocket_manager.py b/tests/services/test_websocket_manager.py new file mode 100644 index 00000000..b85c2197 --- /dev/null +++ b/tests/services/test_websocket_manager.py @@ -0,0 +1,84 @@ +from datetime import datetime, timedelta + +import pytest + +from py.services.websocket_manager import WebSocketManager + + +class DummyWebSocket: + def __init__(self): + self.messages = [] + self.closed = False + + async def send_json(self, data): + if self.closed: + raise RuntimeError("WebSocket closed") + self.messages.append(data) + + +@pytest.fixture +def manager(): + return WebSocketManager() + + +async def test_broadcast_init_progress_adds_defaults(manager): + ws = DummyWebSocket() + manager._init_websockets.add(ws) + + await manager.broadcast_init_progress({}) + + assert ws.messages == [ + { + "stage": "processing", + "progress": 0, + "details": "Processing...", + } + ] + + +async def test_broadcast_download_progress_tracks_state(manager): + ws = DummyWebSocket() + download_id = "abc" + manager._download_websockets[download_id] = ws + + await manager.broadcast_download_progress(download_id, {"progress": 55}) + + assert ws.messages == [{"progress": 55}] + assert manager.get_download_progress(download_id)["progress"] == 55 + + +async def test_broadcast_download_progress_missing_socket(manager): + await manager.broadcast_download_progress("missing", {"progress": 30}) + # Progress should be stored even without a live websocket + assert manager.get_download_progress("missing")["progress"] == 30 + + +async def test_auto_organize_progress_helpers(manager): + payload = {"status": "processing", "progress": 10} + await manager.broadcast_auto_organize_progress(payload) + + assert manager.get_auto_organize_progress() == payload + assert manager.is_auto_organize_running() is True + + manager.cleanup_auto_organize_progress() + assert manager.get_auto_organize_progress() is None + assert manager.is_auto_organize_running() is False + + +def test_cleanup_old_downloads(manager): + now = datetime.now() + manager._download_progress = { + "recent": {"progress": 10, "timestamp": now}, + "stale": {"progress": 100, "timestamp": now - timedelta(hours=48)}, + } + + manager.cleanup_old_downloads(max_age_hours=24) + + assert "stale" not in manager._download_progress + assert "recent" in manager._download_progress + + +def test_generate_download_id(manager): + download_id = manager.generate_download_id() + assert isinstance(download_id, str) + assert download_id