mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 05:32:12 -03:00
- Create tests/integration/ directory with conftest.py fixtures - Add 7 download flow integration tests (test_download_flow.py) - Add 9 recipe flow integration tests (test_recipe_flow.py) - Add 12 ModelLifecycleService tests (exclude_model, bulk_delete, error paths) - Add 5 PersistentRecipeCache concurrent access tests - Update backend-testing-improvement-plan.md with Phase 2 completion Total: 28 new tests, all passing (51/51)
211 lines
6.8 KiB
Python
211 lines
6.8 KiB
Python
"""Shared fixtures for integration tests."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import os
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Any, AsyncGenerator, Dict, Generator, List
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
import aiohttp
|
|
from aiohttp import web
|
|
|
|
|
|
@pytest.fixture
def temp_download_dir(tmp_path: Path) -> Path:
    """Return a freshly created ``downloads`` subdirectory of ``tmp_path``.

    The directory is created here (``mkdir`` without ``exist_ok``), so the
    test receives a new, empty directory.
    """
    target = tmp_path.joinpath("downloads")
    target.mkdir()
    return target
|
|
|
|
|
|
@pytest.fixture
def sample_model_file() -> bytes:
    """Return a small, fixed byte payload that stands in for model file contents."""
    payload: bytes = b"fake model data for testing purposes"
    return payload
|
|
|
|
|
|
@pytest.fixture
def sample_recipe_data() -> Dict[str, Any]:
    """Return a sample recipe dict with two LoRAs, a checkpoint and gen params.

    The dict and all nested lists/dicts are rebuilt on every call, so a test
    may mutate its copy without affecting other tests.
    """
    loras = [
        dict(hash="lora1hash", file_name="test_lora1", strength=0.8),
        dict(hash="lora2hash", file_name="test_lora2", strength=1.0),
    ]
    checkpoint = dict(name="model.safetensors", hash="cphash123")
    gen_params = dict(
        prompt="masterpiece, best quality, test subject",
        negative_prompt="low quality, blurry",
        steps=20,
        cfg=7.0,
        sampler="DPM++ 2M Karras",
    )
    return dict(
        id="test-recipe-001",
        title="Test Recipe",
        file_path="/path/to/recipe.png",
        folder="test-folder",
        base_model="SD1.5",
        fingerprint="abc123def456",
        created_date=1700000000.0,
        modified=1700000100.0,
        favorite=False,
        repair_version=1,
        preview_nsfw_level=0,
        loras=loras,
        checkpoint=checkpoint,
        gen_params=gen_params,
        tags=["test", "integration", "recipe"],
    )
|
|
|
|
|
|
@pytest.fixture
def mock_websocket_manager():
    """Provide a WebSocket manager stand-in that records broadcasts in memory.

    Unlike the other mock fixtures in this module, this returns an *instance*.
    """

    class RecordingWebSocketManager:
        """Captures everything it is asked to broadcast instead of sending it."""

        def __init__(self):
            # Every payload passed to broadcast(), in call order.
            self.payloads: List[Dict[str, Any]] = []
            # download_id -> progress dicts, in the order they were reported.
            self.download_progress: Dict[str, List[Dict[str, Any]]] = {}

        async def broadcast(self, payload: Dict[str, Any]) -> None:
            """Record a generic broadcast payload."""
            self.payloads.append(payload)

        async def broadcast_download_progress(
            self, download_id: str, data: Dict[str, Any]
        ) -> None:
            """Append ``data`` to the progress history for ``download_id``."""
            self.download_progress.setdefault(download_id, []).append(data)

        def get_download_progress(self, download_id: str) -> Dict[str, Any] | None:
            """Return the latest progress for ``download_id``, or ``None`` if none.

            The result is a new dict starting with ``download_id`` and then
            merging in the most recent update; keys in the update win on conflict.
            """
            history = self.download_progress.get(download_id)
            if not history:
                return None
            return {"download_id": download_id, **history[-1]}

    return RecordingWebSocketManager()
|
|
|
|
|
|
@pytest.fixture
def mock_scanner():
    """Provide the ``MockScanner`` class (not an instance) as a scanner double.

    Tests instantiate it themselves and may seed ``_cache.raw_data`` with
    model dicts before exercising the code under test.
    """

    class MockScanner:
        """Minimal model-scanner double whose cache is a ``MagicMock``."""

        def __init__(self):
            cache = MagicMock()
            # Model entries; update_single_model_cache expects dict-like items
            # (it calls .get and .update on them).
            cache.raw_data = []
            self._cache = cache
            self._hash_index = MagicMock()
            self.model_type = "lora"
            self._tags_count: Dict[str, int] = {}
            self._excluded_models: List[str] = []

        async def get_cached_data(self, force_refresh: bool = False):
            """Return the cache object; ``force_refresh`` is accepted but ignored."""
            return self._cache

        async def update_single_model_cache(
            self, original_path: str, new_path: str, metadata: Dict[str, Any]
        ) -> bool:
            """Merge ``metadata`` into the first entry whose ``file_path`` matches.

            ``new_path`` is accepted for signature compatibility but unused; an
            entry's ``file_path`` only changes if ``metadata`` contains one.

            Returns:
                ``True`` if an entry was updated, ``False`` if none matched.
            """
            match = next(
                (
                    entry
                    for entry in self._cache.raw_data
                    if entry.get("file_path") == original_path
                ),
                None,
            )
            if match is None:
                return False
            match.update(metadata)
            return True

        def remove_by_path(self, path: str) -> None:
            """No-op placeholder: the cached data is left unchanged."""

    return MockScanner
|
|
|
|
|
|
@pytest.fixture
def mock_metadata_manager():
    """Provide the ``MockMetadataManager`` class (not an instance).

    Instances record every save and serve pre-seeded payloads from memory.
    """

    class MockMetadataManager:
        """In-memory metadata manager double."""

        def __init__(self):
            # One (file_path, shallow copy of metadata) tuple per save call.
            self.saved_metadata: List[tuple] = []
            # file_path -> payload handed back by load_metadata_payload().
            self.loaded_payloads: Dict[str, Dict[str, Any]] = {}

        async def save_metadata(self, file_path: str, metadata: Dict[str, Any]) -> None:
            """Record a save; the shallow copy snapshots top-level keys at call time."""
            snapshot = metadata.copy()
            self.saved_metadata.append((file_path, snapshot))

        async def load_metadata_payload(self, file_path: str) -> Dict[str, Any]:
            """Return the seeded payload for ``file_path``, else a new empty dict."""
            fallback: Dict[str, Any] = {}
            return self.loaded_payloads.get(file_path, fallback)

        def set_payload(self, file_path: str, payload: Dict[str, Any]) -> None:
            """Seed what ``load_metadata_payload`` returns for ``file_path``."""
            self.loaded_payloads[file_path] = payload

    return MockMetadataManager
|
|
|
|
|
|
@pytest.fixture
def mock_download_coordinator():
    """Provide the ``MockDownloadCoordinator`` class (not an instance).

    Each control method records the ``download_id`` it was called with in a
    per-action list and reports success.
    """

    class MockDownloadCoordinator:
        """Download coordinator double that only records requests."""

        def __init__(self):
            # Not read or written by any method of this mock; tests may populate it.
            self.active_downloads: Dict[str, Any] = {}
            self.cancelled_downloads: List[str] = []
            self.paused_downloads: List[str] = []
            self.resumed_downloads: List[str] = []

        async def cancel_download(self, download_id: str) -> Dict[str, Any]:
            """Record a cancel request and return a success response."""
            self.cancelled_downloads.append(download_id)
            response = {"success": True, "message": f"Download {download_id} cancelled"}
            return response

        async def pause_download(self, download_id: str) -> Dict[str, Any]:
            """Record a pause request and return a success response."""
            self.paused_downloads.append(download_id)
            response = {"success": True, "message": f"Download {download_id} paused"}
            return response

        async def resume_download(self, download_id: str) -> Dict[str, Any]:
            """Record a resume request and return a success response."""
            self.resumed_downloads.append(download_id)
            response = {"success": True, "message": f"Download {download_id} resumed"}
            return response

    return MockDownloadCoordinator
|
|
|
|
|
|
@pytest.fixture
async def test_http_server(
    tmp_path: Path,
) -> AsyncGenerator[tuple[str, int], None]:
    """Serve files from ``tmp_path`` over HTTP on an ephemeral localhost port.

    Routes:
        ``GET /download/{filename}``: respond with ``tmp_path / filename`` if it
            is a regular file, otherwise 404 with body ``"File not found"``.
        ``GET /status``: respond with ``{"status": "ok", "server": "test"}``.

    Yields:
        ``(base_url, port)``, where ``base_url`` is ``"http://127.0.0.1:<port>"``.

    The runner is cleaned up on teardown, and also if startup fails after
    ``runner.setup()``.
    """

    async def handle_download(request: web.Request) -> web.StreamResponse:
        """Stream the requested file from ``tmp_path`` or answer 404."""
        filename = request.match_info.get("filename", "test_model.safetensors")
        file_path = tmp_path / filename
        # is_file() rather than exists(): a name such as ".." points at a
        # directory, which FileResponse cannot serve; answer 404 instead.
        if file_path.is_file():
            return web.FileResponse(path=file_path)
        return web.Response(status=404, text="File not found")

    async def handle_status(request: web.Request) -> web.Response:
        """Report that the test server is up."""
        return web.json_response({"status": "ok", "server": "test"})

    app = web.Application()
    app.router.add_get("/download/{filename}", handle_download)
    app.router.add_get("/status", handle_status)

    runner = web.AppRunner(app)
    await runner.setup()
    try:
        # Port 0 lets the OS pick a free ephemeral port.
        site = web.TCPSite(runner, "127.0.0.1", 0)
        await site.start()

        # Public API for the bound socket addresses (getsockname() tuples),
        # instead of reaching into the private site._server attribute.
        port = runner.addresses[0][1]
        base_url = f"http://127.0.0.1:{port}"

        yield base_url, port
    finally:
        await runner.cleanup()
|
|
|
|
|
|
@pytest.fixture
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Provide a fresh event loop per test and shut it down afterwards.

    NOTE(review): overriding pytest-asyncio's ``event_loop`` fixture is
    deprecated in newer pytest-asyncio releases (loop scope is configured via
    ``loop_scope`` / ``asyncio_default_fixture_loop_scope`` instead) -- confirm
    against the pinned pytest-asyncio version before relying on this override.

    Yields:
        A new loop from ``asyncio.new_event_loop()``.
    """
    # Equivalent to get_event_loop_policy().new_event_loop(), but avoids the
    # policy API that Python 3.14 deprecates.
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        # Guard: the test (or pytest-asyncio's own finalizer) may already have
        # closed the loop, and run_until_complete() fails on a closed loop.
        if not loop.is_closed():
            # Finalize async generators the test left open before closing,
            # as asyncio.run() does on shutdown.
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.close()
|