Compare commits

...

2 Commits

Author SHA1 Message Date
Will Miao
727d0ef043 feat(misc): add model download status aggregation 2026-04-03 22:17:09 +08:00
Will Miao
9344d86332 test(misc): cover model existence download status 2026-04-03 22:16:09 +08:00
4 changed files with 97 additions and 3 deletions

View File

@@ -896,18 +896,49 @@ class ModelLibraryHandler:
model_type = None
versions = []
downloaded_version_ids = []
history_service = await self._get_download_history_service()
if lora_versions:
model_type = "lora"
versions = self._with_downloaded_flag(lora_versions)
downloaded_version_ids = await history_service.get_downloaded_version_ids(
model_type,
model_id,
)
elif checkpoint_versions:
model_type = "checkpoint"
versions = self._with_downloaded_flag(checkpoint_versions)
downloaded_version_ids = await history_service.get_downloaded_version_ids(
model_type,
model_id,
)
elif embedding_versions:
model_type = "embedding"
versions = self._with_downloaded_flag(embedding_versions)
downloaded_version_ids = await history_service.get_downloaded_version_ids(
model_type,
model_id,
)
else:
for candidate_type in ("lora", "checkpoint", "embedding"):
candidate_downloaded_version_ids = (
await history_service.get_downloaded_version_ids(
candidate_type,
model_id,
)
)
if candidate_downloaded_version_ids:
model_type = candidate_type
downloaded_version_ids = candidate_downloaded_version_ids
break
return web.json_response(
{"success": True, "modelType": model_type, "versions": versions}
{
"success": True,
"modelType": model_type,
"versions": versions,
"downloadedVersionIds": downloaded_version_ids,
}
)
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Failed to check model existence: %s", exc, exc_info=True)
@@ -962,7 +993,10 @@ class ModelLibraryHandler:
self, request: web.Request
) -> web.Response:
try:
data = await request.json()
if request.method == "GET":
data = request.query
else:
data = await request.json()
model_type, _ = await self._get_scanner_for_type(data.get("modelType"))
if not model_type:
return web.json_response(
@@ -979,6 +1013,13 @@ class ModelLibraryHandler:
)
downloaded = data.get("downloaded")
if isinstance(downloaded, str):
normalized_downloaded = downloaded.strip().lower()
if normalized_downloaded in {"true", "1"}:
downloaded = True
elif normalized_downloaded in {"false", "0"}:
downloaded = False
if not isinstance(downloaded, bool):
return web.json_response(
{"success": False, "error": "Parameter downloaded must be a boolean"},

View File

@@ -47,6 +47,11 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
"/api/lm/model-version-download-status",
"set_model_version_download_status",
),
RouteDefinition(
"GET",
"/api/lm/set-model-version-download-status",
"set_model_version_download_status",
),
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),
RouteDefinition(
"POST", "/api/lm/download-metadata-archive", "download_metadata_archive"

View File

@@ -1,6 +1,8 @@
# serializer version: 1
# name: TestModelLibraryHandlerSnapshots.test_check_model_exists_empty_response
dict({
'downloadedVersionIds': list([
]),
'modelType': None,
'success': True,
'versions': list([

View File

@@ -23,9 +23,10 @@ from py.routes.misc_routes import MiscRoutes
class FakeRequest:
def __init__(self, *, json_data=None, query=None):
def __init__(self, *, json_data=None, query=None, method="POST"):
self._json_data = json_data or {}
self.query = query or {}
self.method = method
async def json(self):
return self._json_data
@@ -869,6 +870,32 @@ async def test_check_model_exists_returns_local_versions():
assert lora_scanner.version_calls == [5]
@pytest.mark.asyncio
async def test_check_model_exists_model_id_only_does_not_call_metadata_provider():
async def metadata_provider_factory():
raise AssertionError("metadata provider should not be called for modelId-only checks")
handler = ModelLibraryHandler(
ServiceRegistryAdapter(
get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory,
get_downloaded_version_history_service=fake_download_history_service_factory,
),
metadata_provider_factory=metadata_provider_factory,
)
response = await handler.check_model_exists(FakeRequest(query={"modelId": "5"}))
payload = json.loads(response.text)
assert payload == {
"success": True,
"modelType": None,
"versions": [],
"downloadedVersionIds": [],
}
@pytest.mark.asyncio
async def test_check_model_exists_returns_download_history_when_file_missing():
history_service = FakeDownloadHistoryService({"checkpoint": {999}})
@@ -949,6 +976,25 @@ async def test_model_version_download_status_endpoints():
("checkpoint", 456, 78, "manual", "/tmp/model.safetensors")
]
set_get_response = await handler.set_model_version_download_status(
FakeRequest(
method="GET",
query={
"modelType": "embedding",
"modelVersionId": "789",
"modelId": "12",
"downloaded": "false",
},
)
)
set_get_payload = json.loads(set_get_response.text)
assert set_get_payload == {
"success": True,
"modelType": "embedding",
"modelVersionId": 789,
"hasBeenDownloaded": False,
}
def test_create_handler_set_uses_provided_dependencies():
recorded_handlers: list[dict] = []