mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
test(routes): clean smoke test module
This commit is contained in:
85
docs/architecture/model_routes.md
Normal file
85
docs/architecture/model_routes.md
Normal file
@@ -0,0 +1,85 @@
# Base model route architecture

The `BaseModelRoutes` controller centralizes HTTP endpoints that every model type
|
||||
(LoRAs, checkpoints, embeddings, etc.) share. Each handler either forwards the
|
||||
request to the injected service, delegates to a utility in
|
||||
`ModelRouteUtils`, or orchestrates long‑running operations via helper services
|
||||
such as the download or WebSocket managers. The table below lists every handler
|
||||
exposed in `py/routes/base_model_routes.py`, the collaborators it leans on, and
|
||||
any cache or WebSocket side effects implemented in
|
||||
`py/utils/routes_common.py`.
|
||||
|
||||
## Handler catalogue
|
||||
|
||||
| Endpoint(s) | Handler | Purpose | Collaborators | Cache / WebSocket side effects |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| `/{prefix}` | `handle_models_page` | Renders the HTML page for a model type, populating the template from cached scanner data when available. | `settings`, `server_i18n`, `service.scanner.get_cached_data()` | Reads scanner cache to build folder list; flags initialization state without mutating cache. |
|
||||
| `/api/lm/{prefix}/list` | `get_models` | Returns paginated model metadata. | `service.get_paginated_data()`, `service.format_response()` | None (read-only). |
|
||||
| `/api/lm/{prefix}/delete` | `delete_model` | Removes a single model from disk and cache. | `ModelRouteUtils.handle_delete_model()` | Deletes files, prunes `scanner.get_cached_data().raw_data`, calls `cache.resort()`, and updates `scanner._hash_index`. |
|
||||
| `/api/lm/{prefix}/exclude` | `exclude_model` | Marks a model as excluded so it no longer appears in listings. | `ModelRouteUtils.handle_exclude_model()` | Updates metadata, decrements `scanner._tags_count`, removes cache entry and hash index entry, and appends to `scanner._excluded_models`. |
|
||||
| `/api/lm/{prefix}/fetch-civitai` | `fetch_civitai` | Fetches metadata for a specific model from CivitAI. | `ModelRouteUtils.fetch_and_update_model()` | Uses scanner cache to find the target record and updates it via `scanner.update_single_model_cache`. |
|
||||
| `/api/lm/{prefix}/fetch-all-civitai` | `fetch_all_civitai` | Bulk refreshes metadata for models missing CivitAI info. | `ModelRouteUtils.fetch_and_update_model()`, `ws_manager.broadcast()` | Streams progress to all clients, updates cache entries, optionally resorts cached data. |
|
||||
| `/api/lm/{prefix}/relink-civitai` | `relink_civitai` | Re-associates a local file with a CivitAI entry. | `ModelRouteUtils.handle_relink_civitai()` | Updates metadata, refreshes cache via `scanner.update_single_model_cache`. |
|
||||
| `/api/lm/{prefix}/replace-preview` | `replace_preview` | Replaces the preview asset attached to a model. | `ModelRouteUtils.handle_replace_preview()` | Writes new preview file, updates metadata, and calls `scanner.update_preview_in_cache()`. |
|
||||
| `/api/lm/{prefix}/save-metadata` | `save_metadata` | Persists edits to model metadata. | `ModelRouteUtils.handle_save_metadata()` | Saves metadata file and syncs the cache entry. |
|
||||
| `/api/lm/{prefix}/add-tags` | `add_tags` | Adds or increments tags for a model. | `ModelRouteUtils.handle_add_tags()` | Mutates metadata, increments `scanner._tags_count`, and updates the cached model. |
|
||||
| `/api/lm/{prefix}/rename` | `rename_model` | Renames a model and its related assets. | `ModelRouteUtils.handle_rename_model()` | Renames files on disk, updates cache indices, refreshes metadata. |
|
||||
| `/api/lm/{prefix}/bulk-delete` | `bulk_delete_models` | Deletes multiple models in one request. | `ModelRouteUtils.handle_bulk_delete_models()` | Delegates to `scanner.bulk_delete_models()` which removes disk assets and cache records in bulk. |
|
||||
| `/api/lm/{prefix}/verify-duplicates` | `verify_duplicates` | Confirms that a list of files share the same hash. | `ModelRouteUtils.handle_verify_duplicates()` | Recalculates hashes, updates metadata, and patches cache entries when stored hashes change. |
|
||||
| `/api/lm/{prefix}/top-tags` | `get_top_tags` | Returns the most frequently used tags. | `service.get_top_tags()` | None (read-only). |
|
||||
| `/api/lm/{prefix}/base-models` | `get_base_models` | Lists base models referenced by this model type. | `service.get_base_models()` | None (read-only). |
|
||||
| `/api/lm/{prefix}/scan` | `scan_models` | Triggers a rescan of the filesystem. | `service.scan_models()` | Scanner rebuilds its cache as part of the service workflow. |
|
||||
| `/api/lm/{prefix}/roots` | `get_model_roots` | Enumerates root directories searched for this model type. | `service.get_model_roots()` | None (read-only). |
|
||||
| `/api/lm/{prefix}/folders` | `get_folders` | Returns cached folder summaries. | `service.scanner.get_cached_data()` | Reads cached structure without mutation. |
|
||||
| `/api/lm/{prefix}/folder-tree` | `get_folder_tree` | Builds a nested folder tree of cached items. | `service.scanner.get_cached_data()` | Reads cache; does not mutate. |
|
||||
| `/api/lm/{prefix}/unified-folder-tree` | `get_unified_folder_tree` | Returns a tree aggregating all roots. | `service.scanner.get_cached_data()` | Reads cache; does not mutate. |
|
||||
| `/api/lm/{prefix}/find-duplicates` | `find_duplicate_models` | Finds duplicate hashes within the cache. | `service.scanner.get_duplicates()`, `service.scanner.get_hash_by_filename()` | Uses cache data to assemble duplicate groups; no mutation. |
|
||||
| `/api/lm/{prefix}/find-filename-conflicts` | `find_filename_conflicts` | Groups models that share a filename across directories. | `service.scanner.get_filename_conflicts()`, `service.get_path_by_hash()` | Reads cache while formatting results. |
|
||||
| `/api/lm/{prefix}/get-notes` | `get_model_notes` | Retrieves saved notes for a model. | `service.get_model_notes()` | None (read-only). |
|
||||
| `/api/lm/{prefix}/preview-url` | `get_model_preview_url` | Resolves the static preview URL for a model. | `service.get_model_preview_url()` | None (read-only). |
|
||||
| `/api/lm/{prefix}/civitai-url` | `get_model_civitai_url` | Returns the CivitAI permalink for a model. | `service.get_model_civitai_url()` | None (read-only). |
|
||||
| `/api/lm/{prefix}/metadata` | `get_model_metadata` | Loads the raw metadata payload for a model. | `service.get_model_metadata()` | None (read-only). |
|
||||
| `/api/lm/{prefix}/model-description` | `get_model_description` | Returns a formatted description for the UI. | `service.get_model_description()` | None (read-only). |
|
||||
| `/api/lm/{prefix}/relative-paths` | `get_relative_paths` | Provides filesystem auto-complete suggestions. | `service.get_relative_paths()` | None (read-only). |
|
||||
| `/api/lm/{prefix}/civitai/versions/{model_id}` | `get_civitai_versions` | Lists remote versions and indicates which exist locally. | `get_default_metadata_provider()`, `self.service.has_hash()`, `self.service.get_path_by_hash()` | Read-only; consults cache/service indices to mark local availability. |
|
||||
| `/api/lm/{prefix}/civitai/model/version/{modelVersionId}` | `get_civitai_model_by_version` | Fetches detailed metadata for a specific CivitAI version. | `get_default_metadata_provider()` | None (read-only). |
|
||||
| `/api/lm/{prefix}/civitai/model/hash/{hash}` | `get_civitai_model_by_hash` | Fetches CivitAI details using a hash. | `get_default_metadata_provider()` | None (read-only). |
|
||||
| `/api/lm/download-model` (POST) & `/api/lm/download-model-get` (GET) | `download_model`, `download_model_get` | Starts a download through the shared download manager. | `ModelRouteUtils.handle_download_model()`, `ServiceRegistry.get_download_manager()` | The helper broadcasts download progress via `ws_manager.broadcast_download_progress()` and stores state in `ws_manager._download_progress`. |
|
||||
| `/api/lm/cancel-download-get` | `cancel_download_get` | Cancels an active download. | `ModelRouteUtils.handle_cancel_download()` | Broadcasts a cancellation message via `ws_manager.broadcast_download_progress()` and prunes download progress entries. |
|
||||
| `/api/lm/download-progress/{download_id}` | `get_download_progress` | Reports cached download progress for a download ID. | `ws_manager.get_download_progress()` | Read-only view of cached progress. |
|
||||
| `/api/lm/{prefix}/move_model` | `move_model` | Moves a model to a new folder. | `ModelMoveService.move_model()` | File operations performed by the injected service may update scanner caches downstream. |
|
||||
| `/api/lm/{prefix}/move_models_bulk` | `move_models_bulk` | Bulk move models to a new location. | `ModelMoveService.move_models_bulk()` | File operations delegated to the service. |
|
||||
| `/api/lm/{prefix}/auto-organize` (GET/POST) | `auto_organize_models` | Launches auto-organization for models, optionally limited to selected files. | `ModelFileService.auto_organize_models()`, `ws_manager.get_auto_organize_lock()`, `WebSocketProgressCallback` | Uses a shared asyncio lock, streams progress through `ws_manager.broadcast_auto_organize_progress()`, and relies on `ws_manager.is_auto_organize_running()` state. |
|
||||
| `/api/lm/{prefix}/auto-organize-progress` | `get_auto_organize_progress` | Polls the latest auto-organize progress snapshot. | `ws_manager.get_auto_organize_progress()` | Read-only view of the WebSocket manager’s cached progress. |
|
||||

## Shared utility side effects

The delegated helpers in `ModelRouteUtils` encapsulate most cache and
WebSocket mutations. The smoke tests in this repository exercise the
following contracts from `py/utils/routes_common.py`:
|
||||
|
||||
* `handle_delete_model` removes matching records from
|
||||
`scanner.get_cached_data().raw_data`, awaits `cache.resort()`, and calls
|
||||
`scanner._hash_index.remove_by_path()` when an index is present before
|
||||
returning a success payload.
|
||||
* `handle_replace_preview` writes a new preview file, persists metadata via
|
||||
`MetadataManager.save_metadata()`, and then invokes
|
||||
`scanner.update_preview_in_cache()` with the normalized preview path and
|
||||
NSFW level so downstream requests surface the updated asset.
|
||||
* `handle_download_model` acquires the shared download manager from
|
||||
`ServiceRegistry`, injects a WebSocket progress callback, and relies on
|
||||
`ws_manager.broadcast_download_progress()` to update the cached progress map
|
||||
that `get_download_progress` later reads.
|
||||
* `handle_bulk_delete_models`, `handle_add_tags`, `handle_exclude_model`, and
|
||||
`handle_verify_duplicates` all mutate scanner-maintained collections (hash
|
||||
indices, tag counts, exclusion lists, or cached metadata) so route handlers
|
||||
can stay thin while cache consistency remains centralized in the utility
|
||||
module.
|
||||
* `ws_manager.broadcast_auto_organize_progress()` stores the latest progress
|
||||
snapshot consumed by `get_auto_organize_progress`, while
|
||||
`ws_manager.broadcast()` is used to notify clients during CivitAI bulk
|
||||
refreshes and other background operations.
|
||||

Keeping these side effects in mind is essential when refactoring route logic:
any replacement must continue to honor the implicit contracts the utilities
expect from scanners, caches, and the WebSocket manager.
|
||||
154
tests/conftest.py
Normal file
154
tests/conftest.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@dataclass
class MockHashIndex:
    """Stand-in for the scanner's hash index; it only records removals."""

    # Paths passed to ``remove_by_path``, kept in call order for assertions.
    removed_paths: List[str] = field(default_factory=list)

    def remove_by_path(self, path: str) -> None:
        """Record *path* instead of touching a real index."""
        self.removed_paths += [path]
|
||||
|
||||
|
||||
class MockCache:
    """In-memory stand-in for the scanner cache consumed by ``ModelRouteUtils``."""

    def __init__(self, items: Optional[Sequence[Dict[str, Any]]] = None):
        # Copy the seed records so tests may mutate ``raw_data`` freely.
        self.raw_data: List[Dict[str, Any]] = [] if not items else list(items)
        # Number of times ``resort`` has been awaited.
        self.resort_calls = 0

    async def resort(self) -> None:
        """Count the call; ``ModelRouteUtils`` awaits this coroutine but
        never uses its return value."""
        self.resort_calls += 1
|
||||
|
||||
|
||||
class MockScanner:
    """Scanner double exposing exactly the surface the route utilities touch."""

    def __init__(self, cache: Optional[MockCache] = None, hash_index: Optional[MockHashIndex] = None):
        self._cache = cache or MockCache()
        self._hash_index = hash_index or MockHashIndex()
        self._tags_count: Dict[str, int] = {}
        self._excluded_models: List[str] = []
        # Call recorders inspected by the smoke tests.
        self.updated_models: List[Dict[str, Any]] = []
        self.preview_updates: List[Dict[str, Any]] = []
        self.bulk_deleted: List[Sequence[str]] = []

    async def get_cached_data(self, force_refresh: bool = False):
        """Return the backing cache; ``force_refresh`` is accepted but ignored."""
        return self._cache

    async def update_single_model_cache(self, original_path: str, new_path: str, metadata: Dict[str, Any]) -> bool:
        """Record the update and merge *metadata* into every matching record."""
        call = {
            "original_path": original_path,
            "new_path": new_path,
            "metadata": metadata,
        }
        self.updated_models.append(call)
        for record in self._cache.raw_data:
            if record.get("file_path") == original_path:
                record.update(metadata)
        return True

    async def update_preview_in_cache(self, model_path: str, preview_path: str, nsfw_level: int) -> bool:
        """Record the preview change and patch the matching cache record."""
        self.preview_updates.append({
            "model_path": model_path,
            "preview_path": preview_path,
            "nsfw_level": nsfw_level,
        })
        for record in self._cache.raw_data:
            if record.get("file_path") != model_path:
                continue
            record["preview_url"] = preview_path
            record["preview_nsfw_level"] = nsfw_level
        return True

    async def bulk_delete_models(self, file_paths: Sequence[str]) -> Dict[str, Any]:
        """Drop the given paths from the cache and hash index, mirroring the
        real scanner's bulk-delete contract (resort once, unindex each path)."""
        self.bulk_deleted.append(tuple(file_paths))
        self._cache.raw_data = [
            record
            for record in self._cache.raw_data
            if record.get("file_path") not in file_paths
        ]
        await self._cache.resort()
        for doomed in file_paths:
            self._hash_index.remove_by_path(doomed)
        return {"success": True, "deleted": list(file_paths)}
|
||||
|
||||
|
||||
class MockModelService:
    """Service double exposing the surface area the shared routes call into."""

    def __init__(self, scanner: MockScanner):
        self.scanner = scanner
        self.model_type = "test-model"
        # Items served by ``get_paginated_data`` and a log of formatted results.
        self.paginated_items: List[Dict[str, Any]] = []
        self.formatted: List[Dict[str, Any]] = []

    async def get_paginated_data(self, **params: Any) -> Dict[str, Any]:
        """Return a shallow copy of ``paginated_items`` wrapped in pagination
        metadata derived from the ``page``/``page_size`` params."""
        page = params.get("page", 1)
        page_size = params.get("page_size", 20)
        entries = [dict(entry) for entry in self.paginated_items]
        count = len(entries)
        return {
            "items": entries,
            "total": count,
            "page": page,
            "page_size": page_size,
            # Ceiling division, floored at one page even when empty.
            "total_pages": max(1, (count + page_size - 1) // page_size),
        }

    async def format_response(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Tag *item* as formatted and remember the result for assertions."""
        shaped = dict(item, formatted=True)
        self.formatted.append(shaped)
        return shaped

    # The helpers below are inert for the smoke tests but document the
    # expected surface area of the real services.
    def get_model_roots(self) -> List[str]:
        return ["."]

    async def scan_models(self, *_, **__):  # pragma: no cover - behaviour exercised via mocks
        return None

    async def get_model_notes(self, *_args, **_kwargs):  # pragma: no cover
        return None

    async def get_model_preview_url(self, *_args, **_kwargs):  # pragma: no cover
        return ""

    async def get_model_civitai_url(self, *_args, **_kwargs):  # pragma: no cover
        return {"civitai_url": ""}

    async def get_model_metadata(self, *_args, **_kwargs):  # pragma: no cover
        return {}

    async def get_model_description(self, *_args, **_kwargs):  # pragma: no cover
        return ""

    async def get_relative_paths(self, *_args, **_kwargs):  # pragma: no cover
        return []

    def has_hash(self, *_args, **_kwargs):  # pragma: no cover
        return False

    def get_path_by_hash(self, *_args, **_kwargs):  # pragma: no cover
        return ""
|
||||
|
||||
|
||||
@pytest.fixture
def mock_hash_index() -> MockHashIndex:
    """Provide a fresh hash-index stub for each test."""
    return MockHashIndex()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_cache() -> MockCache:
    """Provide an empty cache stub for each test."""
    return MockCache()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockScanner:
    """Scanner stub wired to the per-test cache and hash-index fixtures."""
    return MockScanner(cache=mock_cache, hash_index=mock_hash_index)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_service(mock_scanner: MockScanner) -> MockModelService:
    """Service stub backed by the per-test scanner fixture."""
    return MockModelService(scanner=mock_scanner)
|
||||
228
tests/routes/test_base_model_routes_smoke.py
Normal file
228
tests/routes/test_base_model_routes_smoke.py
Normal file
@@ -0,0 +1,228 @@
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import types
|
||||
|
||||
# Stub ComfyUI's ``folder_paths`` module BEFORE any project import pulls it in;
# ``setdefault`` keeps a real module if one is already registered.
folder_paths_stub = types.SimpleNamespace(get_folder_paths=lambda *_: [])
sys.modules.setdefault("folder_paths", folder_paths_stub)

import pytest
from aiohttp import FormData, web
from aiohttp.test_utils import TestClient, TestServer

# Locate the repository's ``py`` package relative to this test file
# (tests/routes/ -> repo root).
REPO_ROOT = Path(__file__).resolve().parents[2]
PY_PACKAGE_PATH = REPO_ROOT / "py"

# Load the package under the alias ``py_local`` so submodule imports resolve
# without colliding with any installed top-level ``py`` package.
spec = importlib.util.spec_from_file_location(
    "py_local",
    PY_PACKAGE_PATH / "__init__.py",
    submodule_search_locations=[str(PY_PACKAGE_PATH)],
)
py_local = importlib.util.module_from_spec(spec)
assert spec.loader is not None  # for mypy/static analyzers
spec.loader.exec_module(py_local)
sys.modules.setdefault("py_local", py_local)

from py_local.routes.base_model_routes import BaseModelRoutes
from py_local.services.service_registry import ServiceRegistry
from py_local.services.websocket_manager import ws_manager
from py_local.utils.routes_common import ExifUtils
from py_local.config import config
|
||||
|
||||
|
||||
class DummyRoutes(BaseModelRoutes):
    """Concrete routes subclass adding no model-specific endpoints."""

    template_name = "dummy.html"

    def setup_specific_routes(self, app: web.Application, prefix: str) -> None:  # pragma: no cover - no extra routes in smoke tests
        pass
|
||||
|
||||
|
||||
async def create_test_client(service) -> TestClient:
    """Spin up an aiohttp test client serving the dummy routes for *service*."""
    app = web.Application()
    DummyRoutes(service).setup_routes(app, "test-models")

    client = TestClient(TestServer(app))
    await client.start_server()
    return client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_ws_manager_state():
    """Start and finish every test with pristine WebSocket-manager state."""
    # NOTE(review): reaches into the private ``_download_progress`` map; this
    # mirrors what the routes read — revisit if a public reset API appears.
    ws_manager.cleanup_auto_organize_progress()
    ws_manager._download_progress.clear()
    yield
    ws_manager.cleanup_auto_organize_progress()
    ws_manager._download_progress.clear()
|
||||
|
||||
|
||||
@pytest.fixture
def download_manager_stub():
    """Swap a fake download manager into the ServiceRegistry for one test."""

    class FakeDownloadManager:
        def __init__(self):
            self.calls = []

        async def download_from_civitai(self, **kwargs):
            # Record the call, emit a single progress tick, report success.
            self.calls.append(kwargs)
            await kwargs["progress_callback"](42)
            return {"success": True, "path": "/tmp/model.safetensors"}

    stub = FakeDownloadManager()
    previous = ServiceRegistry._services.get("download_manager")
    asyncio.run(ServiceRegistry.register_service("download_manager", stub))
    try:
        yield stub
    finally:
        # Restore whatever was registered before the test ran.
        if previous is None:
            ServiceRegistry._services.pop("download_manager", None)
        else:
            ServiceRegistry._services["download_manager"] = previous
|
||||
|
||||
|
||||
def test_list_models_returns_formatted_items(mock_service, mock_scanner):
    """The list endpoint returns paginated items run through ``format_response``."""
    mock_service.paginated_items = [{"file_path": "/tmp/demo.safetensors", "name": "Demo"}]

    async def scenario():
        client = await create_test_client(mock_service)
        try:
            resp = await client.get("/api/lm/test-models/list")
            body = await resp.json()

            expected = [{"file_path": "/tmp/demo.safetensors", "name": "Demo", "formatted": True}]
            assert resp.status == 200
            assert body["items"] == expected
            assert body["total"] == 1
            # The service recorded the same formatted dicts it returned.
            assert mock_service.formatted == body["items"]
        finally:
            await client.close()

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_delete_model_updates_cache_and_hash_index(mock_service, mock_scanner, tmp_path: Path):
    """Deleting a model removes the file, the cache record, and the hash entry."""
    model_path = tmp_path / "sample.safetensors"
    model_path.write_bytes(b"model")
    mock_scanner._cache.raw_data = [{"file_path": str(model_path)}]

    async def scenario():
        client = await create_test_client(mock_service)
        try:
            resp = await client.post(
                "/api/lm/test-models/delete",
                json={"file_path": str(model_path)},
            )
            body = await resp.json()

            assert resp.status == 200
            assert body["success"] is True
            # Cache pruned, resorted exactly once, and the path unindexed.
            assert mock_scanner._cache.raw_data == []
            assert mock_scanner._cache.resort_calls == 1
            assert mock_scanner._hash_index.removed_paths == [str(model_path)]
        finally:
            await client.close()

    asyncio.run(scenario())
    # The handler also deleted the file from disk.
    assert not model_path.exists()
|
||||
|
||||
|
||||
def test_replace_preview_writes_file_and_updates_cache(
    mock_service,
    mock_scanner,
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
):
    """Uploading a preview writes the optimized file, persists metadata, and
    pushes the new preview into the scanner cache."""
    model_path = tmp_path / "preview-model.safetensors"
    model_path.write_bytes(b"model")
    metadata_path = tmp_path / "preview-model.metadata.json"
    metadata_path.write_text(json.dumps({"file_path": str(model_path)}))

    mock_scanner._cache.raw_data = [{"file_path": str(model_path)}]

    # Skip real image optimization: echo the bytes back and force ``.webp``.
    monkeypatch.setattr(
        ExifUtils,
        "optimize_image",
        staticmethod(lambda image_data, **_: (image_data, ".webp")),
    )
    # Deterministic static URL derived only from the preview's basename.
    monkeypatch.setattr(
        config,
        "get_preview_static_url",
        lambda preview_path: f"/static/{Path(preview_path).name}",
    )

    form = FormData()
    form.add_field("preview_file", b"binary-data", filename="preview.png", content_type="image/png")
    form.add_field("model_path", str(model_path))
    form.add_field("nsfw_level", "2")

    async def scenario():
        client = await create_test_client(mock_service)
        try:
            resp = await client.post("/api/lm/test-models/replace-preview", data=form)
            body = await resp.json()

            # The handler normalizes path separators to forward slashes.
            expected_preview = str((tmp_path / "preview-model.webp")).replace(os.sep, "/")
            assert resp.status == 200
            assert body["success"] is True
            assert body["preview_url"] == "/static/preview-model.webp"
            assert Path(expected_preview).exists()
            assert mock_scanner.preview_updates[-1]["preview_path"] == expected_preview

            saved = json.loads(metadata_path.read_text())
            assert saved["preview_url"] == expected_preview
            assert saved["preview_nsfw_level"] == 2
        finally:
            await client.close()

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_download_model_invokes_download_manager(
    mock_service,
    download_manager_stub,
    tmp_path: Path,
):
    """The download endpoint forwards to the registered download manager and
    progress callbacks land in the WebSocket manager's progress map."""

    async def scenario():
        client = await create_test_client(mock_service)
        try:
            resp = await client.post(
                "/api/lm/download-model",
                json={"model_id": 1, "model_root": str(tmp_path)},
            )
            body = await resp.json()

            assert resp.status == 200
            assert body["success"] is True
            assert download_manager_stub.calls

            first_call = download_manager_stub.calls[0]
            assert first_call["model_id"] == 1
            assert first_call["download_id"] == body["download_id"]
            # The stub fired one progress tick of 42 through the callback.
            progress = ws_manager.get_download_progress(body["download_id"])
            assert progress is not None
            assert progress["progress"] == 42
            ws_manager.cleanup_download_progress(body["download_id"])
        finally:
            await client.close()

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_auto_organize_progress_returns_latest_snapshot(mock_service):
    """Polling the progress endpoint returns the last broadcast snapshot."""

    async def scenario():
        client = await create_test_client(mock_service)
        try:
            await ws_manager.broadcast_auto_organize_progress({"status": "processing", "percent": 50})

            resp = await client.get("/api/lm/test-models/auto-organize-progress")
            body = await resp.json()

            assert resp.status == 200
            assert body == {"success": True, "progress": {"status": "processing", "percent": 50}}
        finally:
            await client.close()

    asyncio.run(scenario())
|
||||
Reference in New Issue
Block a user