mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 14:42:11 -03:00
test: Complete Phase 2 - Integration & Coverage improvements
- Create tests/integration/ directory with conftest.py fixtures - Add 7 download flow integration tests (test_download_flow.py) - Add 9 recipe flow integration tests (test_recipe_flow.py) - Add 12 ModelLifecycleService tests (exclude_model, bulk_delete, error paths) - Add 5 PersistentRecipeCache concurrent access tests - Update backend-testing-improvement-plan.md with Phase 2 completion Total: 28 new tests, all passing (51/51)
This commit is contained in:
210
tests/integration/conftest.py
Normal file
210
tests/integration/conftest.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Shared fixtures for integration tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncGenerator, Dict, Generator, List
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_download_dir(tmp_path: Path) -> Path:
|
||||
"""Create a temporary directory for download tests."""
|
||||
download_dir = tmp_path / "downloads"
|
||||
download_dir.mkdir()
|
||||
return download_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_model_file() -> bytes:
|
||||
"""Create sample model file content for testing."""
|
||||
return b"fake model data for testing purposes"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_recipe_data() -> Dict[str, Any]:
|
||||
"""Create sample recipe data for testing."""
|
||||
return {
|
||||
"id": "test-recipe-001",
|
||||
"title": "Test Recipe",
|
||||
"file_path": "/path/to/recipe.png",
|
||||
"folder": "test-folder",
|
||||
"base_model": "SD1.5",
|
||||
"fingerprint": "abc123def456",
|
||||
"created_date": 1700000000.0,
|
||||
"modified": 1700000100.0,
|
||||
"favorite": False,
|
||||
"repair_version": 1,
|
||||
"preview_nsfw_level": 0,
|
||||
"loras": [
|
||||
{"hash": "lora1hash", "file_name": "test_lora1", "strength": 0.8},
|
||||
{"hash": "lora2hash", "file_name": "test_lora2", "strength": 1.0},
|
||||
],
|
||||
"checkpoint": {"name": "model.safetensors", "hash": "cphash123"},
|
||||
"gen_params": {
|
||||
"prompt": "masterpiece, best quality, test subject",
|
||||
"negative_prompt": "low quality, blurry",
|
||||
"steps": 20,
|
||||
"cfg": 7.0,
|
||||
"sampler": "DPM++ 2M Karras",
|
||||
},
|
||||
"tags": ["test", "integration", "recipe"],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket_manager():
|
||||
"""Provide a recording WebSocket manager for integration tests."""
|
||||
class RecordingWebSocketManager:
|
||||
def __init__(self):
|
||||
self.payloads: List[Dict[str, Any]] = []
|
||||
self.download_progress: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
async def broadcast(self, payload: Dict[str, Any]) -> None:
|
||||
self.payloads.append(payload)
|
||||
|
||||
async def broadcast_download_progress(
|
||||
self, download_id: str, data: Dict[str, Any]
|
||||
) -> None:
|
||||
if download_id not in self.download_progress:
|
||||
self.download_progress[download_id] = []
|
||||
self.download_progress[download_id].append(data)
|
||||
|
||||
def get_download_progress(self, download_id: str) -> Dict[str, Any] | None:
|
||||
progress_list = self.download_progress.get(download_id, [])
|
||||
if not progress_list:
|
||||
return None
|
||||
# Return the latest progress
|
||||
latest = progress_list[-1]
|
||||
return {"download_id": download_id, **latest}
|
||||
|
||||
return RecordingWebSocketManager()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_scanner():
|
||||
"""Provide a mock model scanner with configurable behavior."""
|
||||
class MockScanner:
|
||||
def __init__(self):
|
||||
self._cache = MagicMock()
|
||||
self._cache.raw_data = []
|
||||
self._hash_index = MagicMock()
|
||||
self.model_type = "lora"
|
||||
self._tags_count: Dict[str, int] = {}
|
||||
self._excluded_models: List[str] = []
|
||||
|
||||
async def get_cached_data(self, force_refresh: bool = False):
|
||||
return self._cache
|
||||
|
||||
async def update_single_model_cache(
|
||||
self, original_path: str, new_path: str, metadata: Dict[str, Any]
|
||||
) -> bool:
|
||||
for item in self._cache.raw_data:
|
||||
if item.get("file_path") == original_path:
|
||||
item.update(metadata)
|
||||
return True
|
||||
return False
|
||||
|
||||
def remove_by_path(self, path: str) -> None:
|
||||
pass
|
||||
|
||||
return MockScanner
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_metadata_manager():
|
||||
"""Provide a mock metadata manager."""
|
||||
class MockMetadataManager:
|
||||
def __init__(self):
|
||||
self.saved_metadata: List[tuple] = []
|
||||
self.loaded_payloads: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
async def save_metadata(self, file_path: str, metadata: Dict[str, Any]) -> None:
|
||||
self.saved_metadata.append((file_path, metadata.copy()))
|
||||
|
||||
async def load_metadata_payload(self, file_path: str) -> Dict[str, Any]:
|
||||
return self.loaded_payloads.get(file_path, {})
|
||||
|
||||
def set_payload(self, file_path: str, payload: Dict[str, Any]) -> None:
|
||||
self.loaded_payloads[file_path] = payload
|
||||
|
||||
return MockMetadataManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_download_coordinator():
|
||||
"""Provide a mock download coordinator."""
|
||||
class MockDownloadCoordinator:
|
||||
def __init__(self):
|
||||
self.active_downloads: Dict[str, Any] = {}
|
||||
self.cancelled_downloads: List[str] = []
|
||||
self.paused_downloads: List[str] = []
|
||||
self.resumed_downloads: List[str] = []
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||
self.cancelled_downloads.append(download_id)
|
||||
return {"success": True, "message": f"Download {download_id} cancelled"}
|
||||
|
||||
async def pause_download(self, download_id: str) -> Dict[str, Any]:
|
||||
self.paused_downloads.append(download_id)
|
||||
return {"success": True, "message": f"Download {download_id} paused"}
|
||||
|
||||
async def resume_download(self, download_id: str) -> Dict[str, Any]:
|
||||
self.resumed_downloads.append(download_id)
|
||||
return {"success": True, "message": f"Download {download_id} resumed"}
|
||||
|
||||
return MockDownloadCoordinator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_http_server(
|
||||
tmp_path: Path,
|
||||
) -> AsyncGenerator[tuple[str, int], None]:
|
||||
"""Create a test HTTP server that serves files from a temporary directory."""
|
||||
from aiohttp import web
|
||||
|
||||
async def handle_download(request):
|
||||
"""Handle file download requests."""
|
||||
filename = request.match_info.get("filename", "test_model.safetensors")
|
||||
file_path = tmp_path / filename
|
||||
if file_path.exists():
|
||||
return web.FileResponse(path=file_path)
|
||||
return web.Response(status=404, text="File not found")
|
||||
|
||||
async def handle_status(request):
|
||||
"""Return server status."""
|
||||
return web.json_response({"status": "ok", "server": "test"})
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/download/{filename}", handle_download)
|
||||
app.router.add_get("/status", handle_status)
|
||||
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
|
||||
# Use port 0 to get an available port
|
||||
site = web.TCPSite(runner, "127.0.0.1", 0)
|
||||
await site.start()
|
||||
|
||||
port = site._server.sockets[0].getsockname()[1]
|
||||
base_url = f"http://127.0.0.1:{port}"
|
||||
|
||||
yield base_url, port
|
||||
|
||||
await runner.cleanup()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_loop():
|
||||
"""Create an event loop for async tests."""
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
Reference in New Issue
Block a user