mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-04-10 04:42:14 -03:00
feat(download-history): track downloaded model versions
This commit is contained in:
@@ -66,6 +66,27 @@ class FakePromptServer:
|
||||
instance = Instance()
|
||||
|
||||
|
||||
class FakeDownloadHistoryService:
|
||||
async def has_been_downloaded(self, _model_type, _version_id):
|
||||
return False
|
||||
|
||||
async def get_downloaded_version_ids(self, _model_type, _model_id):
|
||||
return []
|
||||
|
||||
async def get_downloaded_version_ids_bulk(self, _model_type, _model_ids):
|
||||
return {}
|
||||
|
||||
async def mark_downloaded(self, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
async def mark_not_downloaded(self, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
|
||||
async def fake_download_history_service_factory():
|
||||
return FakeDownloadHistoryService()
|
||||
|
||||
|
||||
class TestSettingsHandlerSnapshots:
|
||||
"""Snapshot tests for SettingsHandler responses."""
|
||||
|
||||
@@ -223,6 +244,7 @@ class TestModelLibraryHandlerSnapshots:
|
||||
get_lora_scanner=scanner_factory,
|
||||
get_checkpoint_scanner=scanner_factory,
|
||||
get_embedding_scanner=scanner_factory,
|
||||
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||
),
|
||||
metadata_provider_factory=lambda: None,
|
||||
)
|
||||
|
||||
@@ -438,6 +438,46 @@ async def fake_metadata_archive_manager_factory():
|
||||
return FakeMetadataArchiveManager()
|
||||
|
||||
|
||||
class FakeDownloadHistoryService:
|
||||
def __init__(self, downloaded_by_type=None):
|
||||
self.downloaded_by_type = downloaded_by_type or {}
|
||||
self.marked_downloaded: list[tuple] = []
|
||||
self.marked_not_downloaded: list[tuple] = []
|
||||
|
||||
async def has_been_downloaded(self, model_type, version_id):
|
||||
return version_id in self.downloaded_by_type.get(model_type, set())
|
||||
|
||||
async def get_downloaded_version_ids(self, model_type, model_id):
|
||||
entries = self.downloaded_by_type.get(model_type, {})
|
||||
if isinstance(entries, dict):
|
||||
return sorted(entries.get(model_id, set()))
|
||||
return []
|
||||
|
||||
async def get_downloaded_version_ids_bulk(self, model_type, model_ids):
|
||||
entries = self.downloaded_by_type.get(model_type, {})
|
||||
if not isinstance(entries, dict):
|
||||
return {}
|
||||
return {
|
||||
model_id: set(entries.get(model_id, set()))
|
||||
for model_id in model_ids
|
||||
if model_id in entries
|
||||
}
|
||||
|
||||
async def mark_downloaded(
|
||||
self, model_type, version_id, *, model_id=None, source="manual", file_path=None
|
||||
):
|
||||
self.marked_downloaded.append(
|
||||
(model_type, version_id, model_id, source, file_path)
|
||||
)
|
||||
|
||||
async def mark_not_downloaded(self, model_type, version_id):
|
||||
self.marked_not_downloaded.append((model_type, version_id))
|
||||
|
||||
|
||||
async def fake_download_history_service_factory():
|
||||
return FakeDownloadHistoryService()
|
||||
|
||||
|
||||
class RecordingRegistrar:
|
||||
def __init__(self, _app):
|
||||
self.registered_mapping = None
|
||||
@@ -452,6 +492,7 @@ async def test_misc_routes_bind_produces_expected_handlers():
|
||||
get_lora_scanner=fake_scanner_factory,
|
||||
get_checkpoint_scanner=fake_scanner_factory,
|
||||
get_embedding_scanner=fake_scanner_factory,
|
||||
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||
)
|
||||
|
||||
recorded_registrars = []
|
||||
@@ -578,6 +619,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
||||
get_lora_scanner=lora_factory,
|
||||
get_checkpoint_scanner=checkpoint_factory,
|
||||
get_embedding_scanner=embedding_factory,
|
||||
get_downloaded_version_history_service=lambda: fake_download_history_service_factory(),
|
||||
),
|
||||
metadata_provider_factory=provider_factory,
|
||||
)
|
||||
@@ -600,6 +642,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
||||
"baseModel": "Flux.1",
|
||||
"thumbnailUrl": "http://example.com/a1.jpg",
|
||||
"inLibrary": False,
|
||||
"hasBeenDownloaded": False,
|
||||
},
|
||||
{
|
||||
"modelId": 1,
|
||||
@@ -611,6 +654,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
||||
"baseModel": "Flux.1",
|
||||
"thumbnailUrl": "http://example.com/a2.jpg",
|
||||
"inLibrary": True,
|
||||
"hasBeenDownloaded": False,
|
||||
},
|
||||
{
|
||||
"modelId": 2,
|
||||
@@ -622,6 +666,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
||||
"baseModel": None,
|
||||
"thumbnailUrl": "http://example.com/e1.jpg",
|
||||
"inLibrary": False,
|
||||
"hasBeenDownloaded": False,
|
||||
},
|
||||
{
|
||||
"modelId": 2,
|
||||
@@ -633,6 +678,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
||||
"baseModel": None,
|
||||
"thumbnailUrl": None,
|
||||
"inLibrary": True,
|
||||
"hasBeenDownloaded": False,
|
||||
},
|
||||
{
|
||||
"modelId": 3,
|
||||
@@ -644,6 +690,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
||||
"baseModel": "SDXL",
|
||||
"thumbnailUrl": None,
|
||||
"inLibrary": False,
|
||||
"hasBeenDownloaded": False,
|
||||
},
|
||||
]
|
||||
|
||||
@@ -692,6 +739,7 @@ async def test_get_civitai_user_models_rewrites_civitai_previews():
|
||||
get_lora_scanner=fake_scanner_factory,
|
||||
get_checkpoint_scanner=fake_scanner_factory,
|
||||
get_embedding_scanner=fake_scanner_factory,
|
||||
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||
),
|
||||
metadata_provider_factory=provider_factory,
|
||||
)
|
||||
@@ -727,6 +775,7 @@ async def test_get_civitai_user_models_requires_username():
|
||||
get_lora_scanner=fake_scanner_factory,
|
||||
get_checkpoint_scanner=fake_scanner_factory,
|
||||
get_embedding_scanner=fake_scanner_factory,
|
||||
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||
),
|
||||
metadata_provider_factory=provider_factory,
|
||||
)
|
||||
@@ -760,6 +809,7 @@ def test_ensure_handler_mapping_caches_result():
|
||||
get_lora_scanner=fake_scanner_factory,
|
||||
get_checkpoint_scanner=fake_scanner_factory,
|
||||
get_embedding_scanner=fake_scanner_factory,
|
||||
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||
),
|
||||
metadata_provider_factory=fake_metadata_provider_factory,
|
||||
metadata_archive_manager_factory=fake_metadata_archive_manager_factory,
|
||||
@@ -802,6 +852,7 @@ async def test_check_model_exists_returns_local_versions():
|
||||
get_lora_scanner=lora_factory,
|
||||
get_checkpoint_scanner=checkpoint_factory,
|
||||
get_embedding_scanner=embedding_factory,
|
||||
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||
),
|
||||
metadata_provider_factory=fake_metadata_provider_factory,
|
||||
)
|
||||
@@ -811,10 +862,94 @@ async def test_check_model_exists_returns_local_versions():
|
||||
|
||||
assert payload["success"] is True
|
||||
assert payload["modelType"] == "lora"
|
||||
assert payload["versions"] == versions
|
||||
assert payload["versions"] == [
|
||||
{"versionId": 11, "name": "v1", "fileName": "model-one", "hasBeenDownloaded": True},
|
||||
{"versionId": 12, "name": "v2", "fileName": "model-two", "hasBeenDownloaded": True},
|
||||
]
|
||||
assert lora_scanner.version_calls == [5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_model_exists_returns_download_history_when_file_missing():
|
||||
history_service = FakeDownloadHistoryService({"checkpoint": {999}})
|
||||
|
||||
async def history_factory():
|
||||
return history_service
|
||||
|
||||
handler = ModelLibraryHandler(
|
||||
ServiceRegistryAdapter(
|
||||
get_lora_scanner=fake_scanner_factory,
|
||||
get_checkpoint_scanner=fake_scanner_factory,
|
||||
get_embedding_scanner=fake_scanner_factory,
|
||||
get_downloaded_version_history_service=history_factory,
|
||||
),
|
||||
metadata_provider_factory=fake_metadata_provider_factory,
|
||||
)
|
||||
|
||||
response = await handler.check_model_exists(
|
||||
FakeRequest(query={"modelId": "5", "modelVersionId": "999"})
|
||||
)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload == {
|
||||
"success": True,
|
||||
"exists": False,
|
||||
"modelType": "checkpoint",
|
||||
"hasBeenDownloaded": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_version_download_status_endpoints():
|
||||
history_service = FakeDownloadHistoryService({"lora": {123}})
|
||||
|
||||
async def history_factory():
|
||||
return history_service
|
||||
|
||||
handler = ModelLibraryHandler(
|
||||
ServiceRegistryAdapter(
|
||||
get_lora_scanner=fake_scanner_factory,
|
||||
get_checkpoint_scanner=fake_scanner_factory,
|
||||
get_embedding_scanner=fake_scanner_factory,
|
||||
get_downloaded_version_history_service=history_factory,
|
||||
),
|
||||
metadata_provider_factory=fake_metadata_provider_factory,
|
||||
)
|
||||
|
||||
get_response = await handler.get_model_version_download_status(
|
||||
FakeRequest(query={"modelType": "lora", "modelVersionId": "123"})
|
||||
)
|
||||
get_payload = json.loads(get_response.text)
|
||||
assert get_payload == {
|
||||
"success": True,
|
||||
"modelType": "lora",
|
||||
"modelVersionId": 123,
|
||||
"hasBeenDownloaded": True,
|
||||
}
|
||||
|
||||
set_response = await handler.set_model_version_download_status(
|
||||
FakeRequest(
|
||||
json_data={
|
||||
"modelType": "checkpoint",
|
||||
"modelVersionId": 456,
|
||||
"modelId": 78,
|
||||
"downloaded": True,
|
||||
"filePath": "/tmp/model.safetensors",
|
||||
}
|
||||
)
|
||||
)
|
||||
set_payload = json.loads(set_response.text)
|
||||
assert set_payload == {
|
||||
"success": True,
|
||||
"modelType": "checkpoint",
|
||||
"modelVersionId": 456,
|
||||
"hasBeenDownloaded": True,
|
||||
}
|
||||
assert history_service.marked_downloaded == [
|
||||
("checkpoint", 456, 78, "manual", "/tmp/model.safetensors")
|
||||
]
|
||||
|
||||
|
||||
def test_create_handler_set_uses_provided_dependencies():
|
||||
recorded_handlers: list[dict] = []
|
||||
|
||||
@@ -845,6 +980,7 @@ def test_create_handler_set_uses_provided_dependencies():
|
||||
get_lora_scanner=fake_scanner_factory,
|
||||
get_checkpoint_scanner=fake_scanner_factory,
|
||||
get_embedding_scanner=fake_scanner_factory,
|
||||
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||
),
|
||||
metadata_provider_factory=fake_metadata_provider_factory,
|
||||
metadata_archive_manager_factory=fake_metadata_archive_manager_factory,
|
||||
|
||||
Reference in New Issue
Block a user