mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 14:42:11 -03:00
feat(misc): add civitai user model lookup
This commit is contained in:
@@ -6,6 +6,7 @@ from aiohttp import web
|
||||
|
||||
from py.routes.handlers.misc_handlers import (
|
||||
LoraCodeHandler,
|
||||
ModelLibraryHandler,
|
||||
NodeRegistry,
|
||||
NodeRegistryHandler,
|
||||
ServiceRegistryAdapter,
|
||||
@@ -266,10 +267,34 @@ async def fake_scanner_factory():
|
||||
return FakeScanner()
|
||||
|
||||
|
||||
class FakeExistenceScanner:
|
||||
def __init__(self, existing=None):
|
||||
self._existing = set(existing or [])
|
||||
|
||||
async def check_model_version_exists(self, version_id):
|
||||
return version_id in self._existing
|
||||
|
||||
async def get_model_versions_by_id(self, _model_id):
|
||||
return []
|
||||
|
||||
|
||||
class FakeMetadataProvider:
|
||||
async def get_model_versions(self, _model_id):
|
||||
return {"modelVersions": [], "name": "", "type": "lora"}
|
||||
|
||||
async def get_user_models(self, _username):
|
||||
return []
|
||||
|
||||
|
||||
class FakeUserModelsProvider(FakeMetadataProvider):
|
||||
def __init__(self, models):
|
||||
self.models = models
|
||||
self.received_usernames: list[str] = []
|
||||
|
||||
async def get_user_models(self, username):
|
||||
self.received_usernames.append(username)
|
||||
return self.models
|
||||
|
||||
|
||||
async def fake_metadata_provider_factory():
|
||||
return FakeMetadataProvider()
|
||||
@@ -339,6 +364,167 @@ async def test_misc_routes_bind_produces_expected_handlers():
|
||||
assert set(mapping.keys()) == expected_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_civitai_user_models_filters_versions():
|
||||
models = [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "Model A",
|
||||
"type": "LORA",
|
||||
"tags": ["style"],
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 100,
|
||||
"name": "v1",
|
||||
"baseModel": "Flux.1",
|
||||
"images": [{"url": "http://example.com/a1.jpg"}],
|
||||
},
|
||||
{
|
||||
"id": 101,
|
||||
"name": "v2",
|
||||
"baseModel": "Flux.1",
|
||||
"images": [{"url": "http://example.com/a2.jpg"}],
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "Embedding",
|
||||
"type": "TextualInversion",
|
||||
"tags": ["embedding"],
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 200,
|
||||
"name": "v1",
|
||||
"baseModel": None,
|
||||
"images": [{"url": "http://example.com/e1.jpg"}],
|
||||
},
|
||||
{
|
||||
"id": 202,
|
||||
"name": "v2",
|
||||
"baseModel": None,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"name": "Checkpoint",
|
||||
"type": "Checkpoint",
|
||||
"tags": ["checkpoint"],
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 300,
|
||||
"name": "v1",
|
||||
"baseModel": "SDXL",
|
||||
"images": [],
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": 4,
|
||||
"name": "Unsupported",
|
||||
"type": "Other",
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 400,
|
||||
"name": "v1",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
provider = FakeUserModelsProvider(models)
|
||||
|
||||
async def provider_factory():
|
||||
return provider
|
||||
|
||||
lora_scanner = FakeExistenceScanner({101})
|
||||
checkpoint_scanner = FakeExistenceScanner()
|
||||
embedding_scanner = FakeExistenceScanner({202})
|
||||
|
||||
async def lora_factory():
|
||||
return lora_scanner
|
||||
|
||||
async def checkpoint_factory():
|
||||
return checkpoint_scanner
|
||||
|
||||
async def embedding_factory():
|
||||
return embedding_scanner
|
||||
|
||||
handler = ModelLibraryHandler(
|
||||
ServiceRegistryAdapter(
|
||||
get_lora_scanner=lora_factory,
|
||||
get_checkpoint_scanner=checkpoint_factory,
|
||||
get_embedding_scanner=embedding_factory,
|
||||
),
|
||||
metadata_provider_factory=provider_factory,
|
||||
)
|
||||
|
||||
response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"}))
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is True
|
||||
assert payload["username"] == "pixel"
|
||||
assert payload["versions"] == [
|
||||
{
|
||||
"modelId": 1,
|
||||
"versionId": 100,
|
||||
"modelName": "Model A",
|
||||
"versionName": "v1",
|
||||
"type": "LORA",
|
||||
"tags": ["style"],
|
||||
"baseModel": "Flux.1",
|
||||
"thumbnailUrl": "http://example.com/a1.jpg",
|
||||
},
|
||||
{
|
||||
"modelId": 2,
|
||||
"versionId": 200,
|
||||
"modelName": "Embedding",
|
||||
"versionName": "v1",
|
||||
"type": "TextualInversion",
|
||||
"tags": ["embedding"],
|
||||
"baseModel": None,
|
||||
"thumbnailUrl": "http://example.com/e1.jpg",
|
||||
},
|
||||
{
|
||||
"modelId": 3,
|
||||
"versionId": 300,
|
||||
"modelName": "Checkpoint",
|
||||
"versionName": "v1",
|
||||
"type": "Checkpoint",
|
||||
"tags": ["checkpoint"],
|
||||
"baseModel": "SDXL",
|
||||
"thumbnailUrl": None,
|
||||
},
|
||||
]
|
||||
|
||||
assert provider.received_usernames == ["pixel"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_civitai_user_models_requires_username():
|
||||
provider = FakeUserModelsProvider([])
|
||||
|
||||
async def provider_factory():
|
||||
return provider
|
||||
|
||||
handler = ModelLibraryHandler(
|
||||
ServiceRegistryAdapter(
|
||||
get_lora_scanner=fake_scanner_factory,
|
||||
get_checkpoint_scanner=fake_scanner_factory,
|
||||
get_embedding_scanner=fake_scanner_factory,
|
||||
),
|
||||
metadata_provider_factory=provider_factory,
|
||||
)
|
||||
|
||||
response = await handler.get_civitai_user_models(FakeRequest())
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert response.status == 400
|
||||
assert payload["success"] is False
|
||||
assert "username" in payload["error"].lower()
|
||||
|
||||
|
||||
def test_ensure_handler_mapping_caches_result():
|
||||
call_records = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user