mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #482 from willmiao/codex/add-pytest-modules-for-untested-services
Add backend service and route test coverage
This commit is contained in:
50
tests/routes/test_embedding_routes.py
Normal file
50
tests/routes/test_embedding_routes.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from py.routes.embedding_routes import EmbeddingRoutes
|
||||
|
||||
|
||||
class DummyRequest:
|
||||
def __init__(self, *, match_info=None):
|
||||
self.match_info = match_info or {}
|
||||
|
||||
|
||||
class StubEmbeddingService:
|
||||
def __init__(self):
|
||||
self.info = {}
|
||||
|
||||
async def get_model_info_by_name(self, name):
|
||||
value = self.info.get(name)
|
||||
if isinstance(value, Exception):
|
||||
raise value
|
||||
return value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def routes():
|
||||
handler = EmbeddingRoutes()
|
||||
handler.service = StubEmbeddingService()
|
||||
return handler
|
||||
|
||||
|
||||
async def test_get_embedding_info_success(routes):
|
||||
routes.service.info["demo"] = {"name": "demo"}
|
||||
response = await routes.get_embedding_info(DummyRequest(match_info={"name": "demo"}))
|
||||
payload = json.loads(response.text)
|
||||
assert payload == {"name": "demo"}
|
||||
|
||||
|
||||
async def test_get_embedding_info_missing(routes):
|
||||
response = await routes.get_embedding_info(DummyRequest(match_info={"name": "missing"}))
|
||||
payload = json.loads(response.text)
|
||||
assert response.status == 404
|
||||
assert payload == {"error": "Embedding not found"}
|
||||
|
||||
|
||||
async def test_get_embedding_info_error(routes):
|
||||
routes.service.info["demo"] = RuntimeError("boom")
|
||||
response = await routes.get_embedding_info(DummyRequest(match_info={"name": "demo"}))
|
||||
payload = json.loads(response.text)
|
||||
assert response.status == 500
|
||||
assert payload == {"error": "boom"}
|
||||
213
tests/routes/test_lora_routes.py
Normal file
213
tests/routes/test_lora_routes.py
Normal file
@@ -0,0 +1,213 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.routes.lora_routes import LoraRoutes
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
class DummyRequest:
|
||||
def __init__(self, *, query=None, match_info=None, json_data=None):
|
||||
self.query = query or {}
|
||||
self.match_info = match_info or {}
|
||||
self._json_data = json_data or {}
|
||||
|
||||
async def json(self):
|
||||
return self._json_data
|
||||
|
||||
|
||||
class StubLoraService:
|
||||
def __init__(self):
|
||||
self.notes = {}
|
||||
self.trigger_words = {}
|
||||
self.usage_tips = {}
|
||||
self.previews = {}
|
||||
self.civitai = {}
|
||||
|
||||
async def get_lora_notes(self, name):
|
||||
return self.notes.get(name)
|
||||
|
||||
async def get_lora_trigger_words(self, name):
|
||||
return self.trigger_words.get(name, [])
|
||||
|
||||
async def get_lora_usage_tips_by_relative_path(self, path):
|
||||
return self.usage_tips.get(path)
|
||||
|
||||
async def get_lora_preview_url(self, name):
|
||||
return self.previews.get(name)
|
||||
|
||||
async def get_lora_civitai_url(self, name):
|
||||
return self.civitai.get(name, {"civitai_url": ""})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def routes():
|
||||
handler = LoraRoutes()
|
||||
handler.service = StubLoraService()
|
||||
return handler
|
||||
|
||||
|
||||
async def test_get_lora_notes_success(routes):
|
||||
routes.service.notes["demo"] = "Great notes"
|
||||
request = DummyRequest(query={"name": "demo"})
|
||||
|
||||
response = await routes.get_lora_notes(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload == {"success": True, "notes": "Great notes"}
|
||||
|
||||
|
||||
async def test_get_lora_notes_missing_name(routes):
|
||||
response = await routes.get_lora_notes(DummyRequest())
|
||||
assert response.status == 400
|
||||
assert response.text == "Lora file name is required"
|
||||
|
||||
|
||||
async def test_get_lora_notes_not_found(routes):
|
||||
response = await routes.get_lora_notes(DummyRequest(query={"name": "missing"}))
|
||||
payload = json.loads(response.text)
|
||||
assert response.status == 404
|
||||
assert payload == {"success": False, "error": "LoRA not found in cache"}
|
||||
|
||||
|
||||
async def test_get_lora_notes_error(routes, monkeypatch):
|
||||
async def failing(*_args, **_kwargs):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
routes.service.get_lora_notes = failing
|
||||
|
||||
response = await routes.get_lora_notes(DummyRequest(query={"name": "demo"}))
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert response.status == 500
|
||||
assert payload["success"] is False
|
||||
assert payload["error"] == "boom"
|
||||
|
||||
|
||||
async def test_get_lora_trigger_words_success(routes):
|
||||
routes.service.trigger_words["demo"] = ["trigger"]
|
||||
response = await routes.get_lora_trigger_words(DummyRequest(query={"name": "demo"}))
|
||||
payload = json.loads(response.text)
|
||||
assert payload == {"success": True, "trigger_words": ["trigger"]}
|
||||
|
||||
|
||||
async def test_get_lora_trigger_words_missing_name(routes):
|
||||
response = await routes.get_lora_trigger_words(DummyRequest())
|
||||
assert response.status == 400
|
||||
|
||||
|
||||
async def test_get_lora_trigger_words_error(routes):
|
||||
async def failing(*_args, **_kwargs):
|
||||
raise RuntimeError("fail")
|
||||
|
||||
routes.service.get_lora_trigger_words = failing
|
||||
|
||||
response = await routes.get_lora_trigger_words(DummyRequest(query={"name": "demo"}))
|
||||
payload = json.loads(response.text)
|
||||
assert response.status == 500
|
||||
assert payload["success"] is False
|
||||
|
||||
|
||||
async def test_get_usage_tips_success(routes):
|
||||
routes.service.usage_tips["path"] = "tips"
|
||||
response = await routes.get_lora_usage_tips_by_path(DummyRequest(query={"relative_path": "path"}))
|
||||
payload = json.loads(response.text)
|
||||
assert payload == {"success": True, "usage_tips": "tips"}
|
||||
|
||||
|
||||
async def test_get_usage_tips_missing_param(routes):
|
||||
response = await routes.get_lora_usage_tips_by_path(DummyRequest())
|
||||
assert response.status == 400
|
||||
|
||||
|
||||
async def test_get_usage_tips_error(routes):
|
||||
async def failing(*_args, **_kwargs):
|
||||
raise RuntimeError("bad")
|
||||
|
||||
routes.service.get_lora_usage_tips_by_relative_path = failing
|
||||
response = await routes.get_lora_usage_tips_by_path(DummyRequest(query={"relative_path": "path"}))
|
||||
payload = json.loads(response.text)
|
||||
assert response.status == 500
|
||||
assert payload["success"] is False
|
||||
|
||||
|
||||
async def test_get_preview_url_success(routes):
|
||||
routes.service.previews["demo"] = "http://preview"
|
||||
response = await routes.get_lora_preview_url(DummyRequest(query={"name": "demo"}))
|
||||
payload = json.loads(response.text)
|
||||
assert payload == {"success": True, "preview_url": "http://preview"}
|
||||
|
||||
|
||||
async def test_get_preview_url_missing(routes):
|
||||
response = await routes.get_lora_preview_url(DummyRequest())
|
||||
assert response.status == 400
|
||||
|
||||
|
||||
async def test_get_preview_url_not_found(routes):
|
||||
response = await routes.get_lora_preview_url(DummyRequest(query={"name": "missing"}))
|
||||
payload = json.loads(response.text)
|
||||
assert response.status == 404
|
||||
assert payload["success"] is False
|
||||
|
||||
|
||||
async def test_get_civitai_url_success(routes):
|
||||
routes.service.civitai["demo"] = {"civitai_url": "https://civitai.com"}
|
||||
response = await routes.get_lora_civitai_url(DummyRequest(query={"name": "demo"}))
|
||||
payload = json.loads(response.text)
|
||||
assert payload == {"success": True, "civitai_url": "https://civitai.com"}
|
||||
|
||||
|
||||
async def test_get_civitai_url_missing(routes):
|
||||
response = await routes.get_lora_civitai_url(DummyRequest())
|
||||
assert response.status == 400
|
||||
|
||||
|
||||
async def test_get_civitai_url_not_found(routes):
|
||||
response = await routes.get_lora_civitai_url(DummyRequest(query={"name": "missing"}))
|
||||
payload = json.loads(response.text)
|
||||
assert response.status == 404
|
||||
assert payload["success"] is False
|
||||
|
||||
|
||||
async def test_get_civitai_url_error(routes):
|
||||
async def failing(*_args, **_kwargs):
|
||||
raise RuntimeError("oops")
|
||||
|
||||
routes.service.get_lora_civitai_url = failing
|
||||
response = await routes.get_lora_civitai_url(DummyRequest(query={"name": "demo"}))
|
||||
payload = json.loads(response.text)
|
||||
assert response.status == 500
|
||||
assert payload["success"] is False
|
||||
|
||||
|
||||
async def test_get_trigger_words_broadcasts(monkeypatch, routes):
|
||||
send_mock = MagicMock()
|
||||
PromptServer.instance = SimpleNamespace(send_sync=send_mock)
|
||||
|
||||
monkeypatch.setattr("py.routes.lora_routes.get_lora_info", lambda name: (f"path/{name}", [f"trigger-{name}"]))
|
||||
|
||||
request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": ["node"]})
|
||||
|
||||
response = await routes.get_trigger_words(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload == {"success": True}
|
||||
send_mock.assert_called_once_with(
|
||||
"trigger_word_update",
|
||||
{"id": "node", "message": "trigger-one"},
|
||||
)
|
||||
|
||||
|
||||
async def test_get_trigger_words_error(monkeypatch, routes):
|
||||
async def failing_json():
|
||||
raise RuntimeError("bad json")
|
||||
|
||||
request = DummyRequest(json_data=None)
|
||||
request.json = failing_json
|
||||
|
||||
response = await routes.get_trigger_words(request)
|
||||
payload = json.loads(response.text)
|
||||
assert response.status == 500
|
||||
assert payload["success"] is False
|
||||
222
tests/services/test_civitai_client.py
Normal file
222
tests/services/test_civitai_client.py
Normal file
@@ -0,0 +1,222 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services import civitai_client as civitai_client_module
|
||||
from py.services.civitai_client import CivitaiClient
|
||||
from py.services.model_metadata_provider import ModelMetadataProviderManager
|
||||
|
||||
|
||||
class DummyDownloader:
|
||||
def __init__(self):
|
||||
self.download_calls = []
|
||||
self.memory_calls = []
|
||||
self.request_calls = []
|
||||
|
||||
async def download_file(self, **kwargs):
|
||||
self.download_calls.append(kwargs)
|
||||
return True, kwargs["save_path"]
|
||||
|
||||
async def download_to_memory(self, url, use_auth=False):
|
||||
self.memory_calls.append({"url": url, "use_auth": use_auth})
|
||||
return True, b"bytes", {"content-type": "image/jpeg"}
|
||||
|
||||
async def make_request(self, method, url, use_auth=True):
|
||||
self.request_calls.append({"method": method, "url": url, "use_auth": use_auth})
|
||||
return True, {}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_singletons():
|
||||
CivitaiClient._instance = None
|
||||
ModelMetadataProviderManager._instance = None
|
||||
yield
|
||||
CivitaiClient._instance = None
|
||||
ModelMetadataProviderManager._instance = None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def downloader(monkeypatch):
|
||||
instance = DummyDownloader()
|
||||
monkeypatch.setattr(civitai_client_module, "get_downloader", AsyncMock(return_value=instance))
|
||||
return instance
|
||||
|
||||
|
||||
async def test_download_file_uses_downloader(tmp_path, downloader):
|
||||
client = await CivitaiClient.get_instance()
|
||||
save_dir = tmp_path / "files"
|
||||
save_dir.mkdir()
|
||||
|
||||
success, path = await client.download_file(
|
||||
url="https://example.invalid/model",
|
||||
save_dir=str(save_dir),
|
||||
default_filename="model.safetensors",
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert path == str(save_dir / "model.safetensors")
|
||||
assert downloader.download_calls[0]["use_auth"] is True
|
||||
|
||||
|
||||
async def test_get_model_by_hash_enriches_metadata(monkeypatch, downloader):
|
||||
version_payload = {
|
||||
"modelId": 123,
|
||||
"model": {"description": "", "tags": []},
|
||||
"creator": {},
|
||||
}
|
||||
model_payload = {"description": "desc", "tags": ["tag"], "creator": {"username": "user"}}
|
||||
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
if url.endswith("by-hash/hash"):
|
||||
return True, version_payload.copy()
|
||||
if url.endswith("/models/123"):
|
||||
return True, model_payload
|
||||
return False, "unexpected"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result, error = await client.get_model_by_hash("hash")
|
||||
|
||||
assert error is None
|
||||
assert result["model"]["description"] == "desc"
|
||||
assert result["model"]["tags"] == ["tag"]
|
||||
assert result["creator"] == {"username": "user"}
|
||||
|
||||
|
||||
async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader):
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
return False, "not found"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result, error = await client.get_model_by_hash("missing")
|
||||
|
||||
assert result is None
|
||||
assert error == "Model not found"
|
||||
|
||||
|
||||
async def test_download_preview_image_writes_file(tmp_path, downloader):
|
||||
client = await CivitaiClient.get_instance()
|
||||
target = tmp_path / "preview" / "image.jpg"
|
||||
|
||||
success = await client.download_preview_image("https://example.invalid/preview", str(target))
|
||||
|
||||
assert success is True
|
||||
assert target.exists()
|
||||
assert target.read_bytes() == b"bytes"
|
||||
|
||||
|
||||
async def test_download_preview_image_failure(monkeypatch, downloader):
|
||||
async def failing_download(url, use_auth=False):
|
||||
return False, b"", {}
|
||||
|
||||
downloader.download_to_memory = failing_download
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
target = "/tmp/ignored.jpg"
|
||||
|
||||
success = await client.download_preview_image("https://example.invalid/preview", target)
|
||||
|
||||
assert success is False
|
||||
|
||||
|
||||
async def test_get_model_versions_success(monkeypatch, downloader):
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
return True, {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"}
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result = await client.get_model_versions("123")
|
||||
|
||||
assert result == {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"}
|
||||
|
||||
|
||||
async def test_get_model_version_by_version_id(monkeypatch, downloader):
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
if url.endswith("/model-versions/7"):
|
||||
return True, {
|
||||
"modelId": 321,
|
||||
"model": {"description": ""},
|
||||
"files": [],
|
||||
}
|
||||
if url.endswith("/models/321"):
|
||||
return True, {"description": "desc", "tags": ["tag"], "creator": {"username": "user"}}
|
||||
return False, "unexpected"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result = await client.get_model_version(version_id=7)
|
||||
|
||||
assert result["model"]["description"] == "desc"
|
||||
assert result["model"]["tags"] == ["tag"]
|
||||
assert result["creator"] == {"username": "user"}
|
||||
|
||||
|
||||
async def test_get_model_version_requires_identifier(monkeypatch, downloader):
|
||||
client = await CivitaiClient.get_instance()
|
||||
result = await client.get_model_version()
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_get_model_version_info_handles_not_found(monkeypatch, downloader):
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
return False, "not found"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result, error = await client.get_model_version_info("55")
|
||||
|
||||
assert result is None
|
||||
assert error == "Model not found"
|
||||
|
||||
|
||||
async def test_get_model_version_info_success(monkeypatch, downloader):
|
||||
expected = {"id": 55}
|
||||
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
return True, expected
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result, error = await client.get_model_version_info("55")
|
||||
|
||||
assert result == expected
|
||||
assert error is None
|
||||
|
||||
|
||||
async def test_get_image_info_returns_first_item(monkeypatch, downloader):
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
return True, {"items": [{"id": 1}, {"id": 2}]}
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result = await client.get_image_info("42")
|
||||
|
||||
assert result == {"id": 1}
|
||||
|
||||
|
||||
async def test_get_image_info_handles_missing(monkeypatch, downloader):
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
return True, {"items": []}
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result = await client.get_image_info("42")
|
||||
|
||||
assert result is None
|
||||
247
tests/services/test_download_manager.py
Normal file
247
tests/services/test_download_manager.py
Normal file
@@ -0,0 +1,247 @@
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.download_manager import DownloadManager
|
||||
from py.services import download_manager
|
||||
from py.services.service_registry import ServiceRegistry
|
||||
from py.services.settings_manager import settings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_download_manager():
|
||||
"""Ensure each test operates on a fresh singleton."""
|
||||
DownloadManager._instance = None
|
||||
yield
|
||||
DownloadManager._instance = None
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolate_settings(monkeypatch, tmp_path):
|
||||
"""Point settings writes at a temporary directory to avoid touching real files."""
|
||||
default_settings = settings._get_default_settings()
|
||||
default_settings.update(
|
||||
{
|
||||
"default_lora_root": str(tmp_path),
|
||||
"default_checkpoint_root": str(tmp_path / "checkpoints"),
|
||||
"default_embedding_root": str(tmp_path / "embeddings"),
|
||||
"download_path_templates": {
|
||||
"lora": "{base_model}/{first_tag}",
|
||||
"checkpoint": "{base_model}/{first_tag}",
|
||||
"embedding": "{base_model}/{first_tag}",
|
||||
},
|
||||
"base_model_path_mappings": {"BaseModel": "MappedModel"},
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(settings, "settings", default_settings)
|
||||
monkeypatch.setattr(type(settings), "_save_settings", lambda self: None)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def stub_metadata(monkeypatch):
|
||||
class _StubMetadata:
|
||||
def __init__(self, save_path: str):
|
||||
self.file_path = save_path
|
||||
self.sha256 = "sha256"
|
||||
self.file_name = Path(save_path).stem
|
||||
|
||||
def _factory(save_path: str):
|
||||
return _StubMetadata(save_path)
|
||||
|
||||
def _make_class():
|
||||
@staticmethod
|
||||
def from_civitai_info(_version_info, _file_info, save_path):
|
||||
return _factory(save_path)
|
||||
|
||||
return type("StubMetadata", (), {"from_civitai_info": from_civitai_info})
|
||||
|
||||
stub_class = _make_class()
|
||||
monkeypatch.setattr(download_manager, "LoraMetadata", stub_class)
|
||||
monkeypatch.setattr(download_manager, "CheckpointMetadata", stub_class)
|
||||
monkeypatch.setattr(download_manager, "EmbeddingMetadata", stub_class)
|
||||
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self, exists: bool = False):
|
||||
self.exists = exists
|
||||
self.calls = []
|
||||
|
||||
async def check_model_version_exists(self, version_id):
|
||||
self.calls.append(version_id)
|
||||
return self.exists
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def scanners(monkeypatch):
|
||||
lora_scanner = DummyScanner()
|
||||
checkpoint_scanner = DummyScanner()
|
||||
embedding_scanner = DummyScanner()
|
||||
|
||||
monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner))
|
||||
monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner))
|
||||
monkeypatch.setattr(ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=embedding_scanner))
|
||||
|
||||
return SimpleNamespace(
|
||||
lora=lora_scanner,
|
||||
checkpoint=checkpoint_scanner,
|
||||
embedding=embedding_scanner,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def metadata_provider(monkeypatch):
|
||||
class DummyProvider:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def get_model_version(self, model_id, model_version_id):
|
||||
self.calls.append((model_id, model_version_id))
|
||||
return {
|
||||
"id": 42,
|
||||
"model": {"type": "LoRA", "tags": ["fantasy"]},
|
||||
"baseModel": "BaseModel",
|
||||
"creator": {"username": "Author"},
|
||||
"files": [
|
||||
{
|
||||
"primary": True,
|
||||
"downloadUrl": "https://example.invalid/file.safetensors",
|
||||
"name": "file.safetensors",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
provider = DummyProvider()
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_default_metadata_provider",
|
||||
AsyncMock(return_value=provider),
|
||||
)
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def noop_cleanup(monkeypatch):
|
||||
async def _cleanup(self, task_id):
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]["cleaned"] = True
|
||||
|
||||
monkeypatch.setattr(DownloadManager, "_cleanup_download_record", _cleanup)
|
||||
|
||||
|
||||
async def test_download_requires_identifier():
|
||||
manager = DownloadManager()
|
||||
result = await manager.download_from_civitai()
|
||||
assert result == {
|
||||
"success": False,
|
||||
"error": "Either model_id or model_version_id must be provided",
|
||||
}
|
||||
|
||||
|
||||
async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata_provider, tmp_path):
|
||||
manager = DownloadManager()
|
||||
|
||||
captured = {}
|
||||
|
||||
async def fake_execute_download(
|
||||
self,
|
||||
*,
|
||||
download_url,
|
||||
save_dir,
|
||||
metadata,
|
||||
version_info,
|
||||
relative_path,
|
||||
progress_callback,
|
||||
model_type,
|
||||
download_id,
|
||||
):
|
||||
captured.update(
|
||||
{
|
||||
"download_url": download_url,
|
||||
"save_dir": Path(save_dir),
|
||||
"relative_path": relative_path,
|
||||
"progress_callback": progress_callback,
|
||||
"model_type": model_type,
|
||||
"download_id": download_id,
|
||||
"metadata_path": metadata.file_path,
|
||||
}
|
||||
)
|
||||
return {"success": True}
|
||||
|
||||
monkeypatch.setattr(DownloadManager, "_execute_download", fake_execute_download, raising=False)
|
||||
|
||||
result = await manager.download_from_civitai(
|
||||
model_version_id=99,
|
||||
save_dir=str(tmp_path),
|
||||
use_default_paths=True,
|
||||
progress_callback=None,
|
||||
source=None,
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert "download_id" in result
|
||||
assert manager._download_tasks == {}
|
||||
assert manager._active_downloads[result["download_id"]]["status"] == "completed"
|
||||
|
||||
assert captured["relative_path"] == "MappedModel/fantasy"
|
||||
expected_dir = Path(settings.get("default_lora_root")) / "MappedModel" / "fantasy"
|
||||
assert captured["save_dir"] == expected_dir
|
||||
assert captured["model_type"] == "lora"
|
||||
|
||||
|
||||
async def test_download_aborts_when_version_exists(monkeypatch, scanners, metadata_provider):
|
||||
scanners.lora.exists = True
|
||||
|
||||
manager = DownloadManager()
|
||||
|
||||
execute_mock = AsyncMock(return_value={"success": True})
|
||||
monkeypatch.setattr(DownloadManager, "_execute_download", execute_mock)
|
||||
|
||||
result = await manager.download_from_civitai(model_version_id=101, save_dir="/tmp")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["error"] == "Model version already exists in lora library"
|
||||
assert "download_id" in result
|
||||
assert execute_mock.await_count == 0
|
||||
|
||||
|
||||
async def test_download_handles_metadata_errors(monkeypatch, scanners):
|
||||
async def failing_provider(*_args, **_kwargs):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_default_metadata_provider",
|
||||
AsyncMock(return_value=SimpleNamespace(get_model_version=AsyncMock(return_value=None))),
|
||||
)
|
||||
|
||||
manager = DownloadManager()
|
||||
|
||||
result = await manager.download_from_civitai(model_version_id=5, save_dir="/tmp")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["error"] == "Failed to fetch model metadata"
|
||||
assert "download_id" in result
|
||||
|
||||
|
||||
async def test_download_rejects_unsupported_model_type(monkeypatch, scanners):
|
||||
class Provider:
|
||||
async def get_model_version(self, *_args, **_kwargs):
|
||||
return {
|
||||
"model": {"type": "Unsupported", "tags": []},
|
||||
"files": [],
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_default_metadata_provider",
|
||||
AsyncMock(return_value=Provider()),
|
||||
)
|
||||
|
||||
manager = DownloadManager()
|
||||
|
||||
result = await manager.download_from_civitai(model_version_id=5, save_dir="/tmp")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["error"].startswith("Model type")
|
||||
61
tests/services/test_settings_manager.py
Normal file
61
tests/services/test_settings_manager.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.settings_manager import SettingsManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
|
||||
mgr = SettingsManager()
|
||||
mgr.settings_file = str(tmp_path / "settings.json")
|
||||
return mgr
|
||||
|
||||
|
||||
def test_environment_variable_overrides_settings(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
|
||||
monkeypatch.setenv("CIVITAI_API_KEY", "secret")
|
||||
mgr = SettingsManager()
|
||||
mgr.settings_file = str(tmp_path / "settings.json")
|
||||
|
||||
assert mgr.get("civitai_api_key") == "secret"
|
||||
|
||||
|
||||
def test_download_path_template_parses_json_string(manager):
|
||||
templates = {"lora": "{author}", "checkpoint": "{author}", "embedding": "{author}"}
|
||||
manager.settings["download_path_templates"] = json.dumps(templates)
|
||||
|
||||
template = manager.get_download_path_template("lora")
|
||||
|
||||
assert template == "{author}"
|
||||
assert isinstance(manager.settings["download_path_templates"], dict)
|
||||
|
||||
|
||||
def test_download_path_template_invalid_json(manager):
|
||||
manager.settings["download_path_templates"] = "not json"
|
||||
|
||||
template = manager.get_download_path_template("checkpoint")
|
||||
|
||||
assert template == "{base_model}/{first_tag}"
|
||||
assert manager.settings["download_path_templates"]["lora"] == "{base_model}/{first_tag}"
|
||||
|
||||
|
||||
def test_auto_set_default_roots(manager):
|
||||
manager.settings["folder_paths"] = {
|
||||
"loras": ["/loras"],
|
||||
"checkpoints": ["/checkpoints"],
|
||||
"embeddings": ["/embeddings"],
|
||||
}
|
||||
|
||||
manager._auto_set_default_roots()
|
||||
|
||||
assert manager.get("default_lora_root") == "/loras"
|
||||
assert manager.get("default_checkpoint_root") == "/checkpoints"
|
||||
assert manager.get("default_embedding_root") == "/embeddings"
|
||||
|
||||
|
||||
def test_delete_setting(manager):
|
||||
manager.set("example", 1)
|
||||
manager.delete("example")
|
||||
assert manager.get("example") is None
|
||||
84
tests/services/test_websocket_manager.py
Normal file
84
tests/services/test_websocket_manager.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.websocket_manager import WebSocketManager
|
||||
|
||||
|
||||
class DummyWebSocket:
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
self.closed = False
|
||||
|
||||
async def send_json(self, data):
|
||||
if self.closed:
|
||||
raise RuntimeError("WebSocket closed")
|
||||
self.messages.append(data)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager():
|
||||
return WebSocketManager()
|
||||
|
||||
|
||||
async def test_broadcast_init_progress_adds_defaults(manager):
|
||||
ws = DummyWebSocket()
|
||||
manager._init_websockets.add(ws)
|
||||
|
||||
await manager.broadcast_init_progress({})
|
||||
|
||||
assert ws.messages == [
|
||||
{
|
||||
"stage": "processing",
|
||||
"progress": 0,
|
||||
"details": "Processing...",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
async def test_broadcast_download_progress_tracks_state(manager):
|
||||
ws = DummyWebSocket()
|
||||
download_id = "abc"
|
||||
manager._download_websockets[download_id] = ws
|
||||
|
||||
await manager.broadcast_download_progress(download_id, {"progress": 55})
|
||||
|
||||
assert ws.messages == [{"progress": 55}]
|
||||
assert manager.get_download_progress(download_id)["progress"] == 55
|
||||
|
||||
|
||||
async def test_broadcast_download_progress_missing_socket(manager):
|
||||
await manager.broadcast_download_progress("missing", {"progress": 30})
|
||||
# Progress should be stored even without a live websocket
|
||||
assert manager.get_download_progress("missing")["progress"] == 30
|
||||
|
||||
|
||||
async def test_auto_organize_progress_helpers(manager):
|
||||
payload = {"status": "processing", "progress": 10}
|
||||
await manager.broadcast_auto_organize_progress(payload)
|
||||
|
||||
assert manager.get_auto_organize_progress() == payload
|
||||
assert manager.is_auto_organize_running() is True
|
||||
|
||||
manager.cleanup_auto_organize_progress()
|
||||
assert manager.get_auto_organize_progress() is None
|
||||
assert manager.is_auto_organize_running() is False
|
||||
|
||||
|
||||
def test_cleanup_old_downloads(manager):
|
||||
now = datetime.now()
|
||||
manager._download_progress = {
|
||||
"recent": {"progress": 10, "timestamp": now},
|
||||
"stale": {"progress": 100, "timestamp": now - timedelta(hours=48)},
|
||||
}
|
||||
|
||||
manager.cleanup_old_downloads(max_age_hours=24)
|
||||
|
||||
assert "stale" not in manager._download_progress
|
||||
assert "recent" in manager._download_progress
|
||||
|
||||
|
||||
def test_generate_download_id(manager):
|
||||
download_id = manager.generate_download_id()
|
||||
assert isinstance(download_id, str)
|
||||
assert download_id
|
||||
Reference in New Issue
Block a user