feat(download-history): track downloaded model versions

This commit is contained in:
Will Miao
2026-04-03 16:13:14 +08:00
parent 4f599aeced
commit 33a7f07558
9 changed files with 881 additions and 18 deletions

View File

@@ -438,6 +438,46 @@ async def fake_metadata_archive_manager_factory():
return FakeMetadataArchiveManager()
class FakeDownloadHistoryService:
def __init__(self, downloaded_by_type=None):
self.downloaded_by_type = downloaded_by_type or {}
self.marked_downloaded: list[tuple] = []
self.marked_not_downloaded: list[tuple] = []
async def has_been_downloaded(self, model_type, version_id):
return version_id in self.downloaded_by_type.get(model_type, set())
async def get_downloaded_version_ids(self, model_type, model_id):
entries = self.downloaded_by_type.get(model_type, {})
if isinstance(entries, dict):
return sorted(entries.get(model_id, set()))
return []
async def get_downloaded_version_ids_bulk(self, model_type, model_ids):
entries = self.downloaded_by_type.get(model_type, {})
if not isinstance(entries, dict):
return {}
return {
model_id: set(entries.get(model_id, set()))
for model_id in model_ids
if model_id in entries
}
async def mark_downloaded(
self, model_type, version_id, *, model_id=None, source="manual", file_path=None
):
self.marked_downloaded.append(
(model_type, version_id, model_id, source, file_path)
)
async def mark_not_downloaded(self, model_type, version_id):
self.marked_not_downloaded.append((model_type, version_id))
async def fake_download_history_service_factory():
return FakeDownloadHistoryService()
class RecordingRegistrar:
def __init__(self, _app):
self.registered_mapping = None
@@ -452,6 +492,7 @@ async def test_misc_routes_bind_produces_expected_handlers():
get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory,
get_downloaded_version_history_service=fake_download_history_service_factory,
)
recorded_registrars = []
@@ -578,6 +619,7 @@ async def test_get_civitai_user_models_marks_library_versions():
get_lora_scanner=lora_factory,
get_checkpoint_scanner=checkpoint_factory,
get_embedding_scanner=embedding_factory,
get_downloaded_version_history_service=lambda: fake_download_history_service_factory(),
),
metadata_provider_factory=provider_factory,
)
@@ -600,6 +642,7 @@ async def test_get_civitai_user_models_marks_library_versions():
"baseModel": "Flux.1",
"thumbnailUrl": "http://example.com/a1.jpg",
"inLibrary": False,
"hasBeenDownloaded": False,
},
{
"modelId": 1,
@@ -611,6 +654,7 @@ async def test_get_civitai_user_models_marks_library_versions():
"baseModel": "Flux.1",
"thumbnailUrl": "http://example.com/a2.jpg",
"inLibrary": True,
"hasBeenDownloaded": False,
},
{
"modelId": 2,
@@ -622,6 +666,7 @@ async def test_get_civitai_user_models_marks_library_versions():
"baseModel": None,
"thumbnailUrl": "http://example.com/e1.jpg",
"inLibrary": False,
"hasBeenDownloaded": False,
},
{
"modelId": 2,
@@ -633,6 +678,7 @@ async def test_get_civitai_user_models_marks_library_versions():
"baseModel": None,
"thumbnailUrl": None,
"inLibrary": True,
"hasBeenDownloaded": False,
},
{
"modelId": 3,
@@ -644,6 +690,7 @@ async def test_get_civitai_user_models_marks_library_versions():
"baseModel": "SDXL",
"thumbnailUrl": None,
"inLibrary": False,
"hasBeenDownloaded": False,
},
]
@@ -692,6 +739,7 @@ async def test_get_civitai_user_models_rewrites_civitai_previews():
get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory,
get_downloaded_version_history_service=fake_download_history_service_factory,
),
metadata_provider_factory=provider_factory,
)
@@ -727,6 +775,7 @@ async def test_get_civitai_user_models_requires_username():
get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory,
get_downloaded_version_history_service=fake_download_history_service_factory,
),
metadata_provider_factory=provider_factory,
)
@@ -760,6 +809,7 @@ def test_ensure_handler_mapping_caches_result():
get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory,
get_downloaded_version_history_service=fake_download_history_service_factory,
),
metadata_provider_factory=fake_metadata_provider_factory,
metadata_archive_manager_factory=fake_metadata_archive_manager_factory,
@@ -802,6 +852,7 @@ async def test_check_model_exists_returns_local_versions():
get_lora_scanner=lora_factory,
get_checkpoint_scanner=checkpoint_factory,
get_embedding_scanner=embedding_factory,
get_downloaded_version_history_service=fake_download_history_service_factory,
),
metadata_provider_factory=fake_metadata_provider_factory,
)
@@ -811,10 +862,94 @@ async def test_check_model_exists_returns_local_versions():
assert payload["success"] is True
assert payload["modelType"] == "lora"
assert payload["versions"] == versions
assert payload["versions"] == [
{"versionId": 11, "name": "v1", "fileName": "model-one", "hasBeenDownloaded": True},
{"versionId": 12, "name": "v2", "fileName": "model-two", "hasBeenDownloaded": True},
]
assert lora_scanner.version_calls == [5]
@pytest.mark.asyncio
async def test_check_model_exists_returns_download_history_when_file_missing():
history_service = FakeDownloadHistoryService({"checkpoint": {999}})
async def history_factory():
return history_service
handler = ModelLibraryHandler(
ServiceRegistryAdapter(
get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory,
get_downloaded_version_history_service=history_factory,
),
metadata_provider_factory=fake_metadata_provider_factory,
)
response = await handler.check_model_exists(
FakeRequest(query={"modelId": "5", "modelVersionId": "999"})
)
payload = json.loads(response.text)
assert payload == {
"success": True,
"exists": False,
"modelType": "checkpoint",
"hasBeenDownloaded": True,
}
@pytest.mark.asyncio
async def test_model_version_download_status_endpoints():
history_service = FakeDownloadHistoryService({"lora": {123}})
async def history_factory():
return history_service
handler = ModelLibraryHandler(
ServiceRegistryAdapter(
get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory,
get_downloaded_version_history_service=history_factory,
),
metadata_provider_factory=fake_metadata_provider_factory,
)
get_response = await handler.get_model_version_download_status(
FakeRequest(query={"modelType": "lora", "modelVersionId": "123"})
)
get_payload = json.loads(get_response.text)
assert get_payload == {
"success": True,
"modelType": "lora",
"modelVersionId": 123,
"hasBeenDownloaded": True,
}
set_response = await handler.set_model_version_download_status(
FakeRequest(
json_data={
"modelType": "checkpoint",
"modelVersionId": 456,
"modelId": 78,
"downloaded": True,
"filePath": "/tmp/model.safetensors",
}
)
)
set_payload = json.loads(set_response.text)
assert set_payload == {
"success": True,
"modelType": "checkpoint",
"modelVersionId": 456,
"hasBeenDownloaded": True,
}
assert history_service.marked_downloaded == [
("checkpoint", 456, 78, "manual", "/tmp/model.safetensors")
]
def test_create_handler_set_uses_provided_dependencies():
recorded_handlers: list[dict] = []
@@ -845,6 +980,7 @@ def test_create_handler_set_uses_provided_dependencies():
get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory,
get_downloaded_version_history_service=fake_download_history_service_factory,
),
metadata_provider_factory=fake_metadata_provider_factory,
metadata_archive_manager_factory=fake_metadata_archive_manager_factory,