feat(misc): add civitai user model lookup

This commit is contained in:
pixelpaws
2025-10-09 11:49:41 +08:00
parent d997eaa429
commit 8cf762ffd3
6 changed files with 372 additions and 7 deletions

View File

@@ -6,6 +6,7 @@ from aiohttp import web
from py.routes.handlers.misc_handlers import (
LoraCodeHandler,
ModelLibraryHandler,
NodeRegistry,
NodeRegistryHandler,
ServiceRegistryAdapter,
@@ -266,10 +267,34 @@ async def fake_scanner_factory():
return FakeScanner()
class FakeExistenceScanner:
def __init__(self, existing=None):
self._existing = set(existing or [])
async def check_model_version_exists(self, version_id):
return version_id in self._existing
async def get_model_versions_by_id(self, _model_id):
return []
class FakeMetadataProvider:
async def get_model_versions(self, _model_id):
return {"modelVersions": [], "name": "", "type": "lora"}
async def get_user_models(self, _username):
return []
class FakeUserModelsProvider(FakeMetadataProvider):
def __init__(self, models):
self.models = models
self.received_usernames: list[str] = []
async def get_user_models(self, username):
self.received_usernames.append(username)
return self.models
async def fake_metadata_provider_factory():
return FakeMetadataProvider()
@@ -339,6 +364,167 @@ async def test_misc_routes_bind_produces_expected_handlers():
assert set(mapping.keys()) == expected_names
@pytest.mark.asyncio
async def test_get_civitai_user_models_filters_versions():
models = [
{
"id": 1,
"name": "Model A",
"type": "LORA",
"tags": ["style"],
"modelVersions": [
{
"id": 100,
"name": "v1",
"baseModel": "Flux.1",
"images": [{"url": "http://example.com/a1.jpg"}],
},
{
"id": 101,
"name": "v2",
"baseModel": "Flux.1",
"images": [{"url": "http://example.com/a2.jpg"}],
},
],
},
{
"id": 2,
"name": "Embedding",
"type": "TextualInversion",
"tags": ["embedding"],
"modelVersions": [
{
"id": 200,
"name": "v1",
"baseModel": None,
"images": [{"url": "http://example.com/e1.jpg"}],
},
{
"id": 202,
"name": "v2",
"baseModel": None,
},
],
},
{
"id": 3,
"name": "Checkpoint",
"type": "Checkpoint",
"tags": ["checkpoint"],
"modelVersions": [
{
"id": 300,
"name": "v1",
"baseModel": "SDXL",
"images": [],
}
],
},
{
"id": 4,
"name": "Unsupported",
"type": "Other",
"modelVersions": [
{
"id": 400,
"name": "v1",
}
],
},
]
provider = FakeUserModelsProvider(models)
async def provider_factory():
return provider
lora_scanner = FakeExistenceScanner({101})
checkpoint_scanner = FakeExistenceScanner()
embedding_scanner = FakeExistenceScanner({202})
async def lora_factory():
return lora_scanner
async def checkpoint_factory():
return checkpoint_scanner
async def embedding_factory():
return embedding_scanner
handler = ModelLibraryHandler(
ServiceRegistryAdapter(
get_lora_scanner=lora_factory,
get_checkpoint_scanner=checkpoint_factory,
get_embedding_scanner=embedding_factory,
),
metadata_provider_factory=provider_factory,
)
response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"}))
payload = json.loads(response.text)
assert payload["success"] is True
assert payload["username"] == "pixel"
assert payload["versions"] == [
{
"modelId": 1,
"versionId": 100,
"modelName": "Model A",
"versionName": "v1",
"type": "LORA",
"tags": ["style"],
"baseModel": "Flux.1",
"thumbnailUrl": "http://example.com/a1.jpg",
},
{
"modelId": 2,
"versionId": 200,
"modelName": "Embedding",
"versionName": "v1",
"type": "TextualInversion",
"tags": ["embedding"],
"baseModel": None,
"thumbnailUrl": "http://example.com/e1.jpg",
},
{
"modelId": 3,
"versionId": 300,
"modelName": "Checkpoint",
"versionName": "v1",
"type": "Checkpoint",
"tags": ["checkpoint"],
"baseModel": "SDXL",
"thumbnailUrl": None,
},
]
assert provider.received_usernames == ["pixel"]
@pytest.mark.asyncio
async def test_get_civitai_user_models_requires_username():
provider = FakeUserModelsProvider([])
async def provider_factory():
return provider
handler = ModelLibraryHandler(
ServiceRegistryAdapter(
get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory,
),
metadata_provider_factory=provider_factory,
)
response = await handler.get_civitai_user_models(FakeRequest())
payload = json.loads(response.text)
assert response.status == 400
assert payload["success"] is False
assert "username" in payload["error"].lower()
def test_ensure_handler_mapping_caches_result():
call_records = []