mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-07-05 00:41:17 -03:00
feat(agent): add LLM-powered metadata enrichment system with AgentCLI and PostProcessor
Introduce an agent skill framework for LLM-driven metadata enrichment: - AgentCLI (py/agent_cli/): in-process wrappers around internal services using standard relative imports, eliminating the need for sys.path hacks - LLMService: centralized BYOK (bring-your-own-key) LLM client supporting OpenAI, Ollama, and custom OpenAI-compatible endpoints - PostProcessor: deterministic engine that applies LLM output via AgentCLI (replaces old handler.py + _BASE_MODEL_ALIASES approach) - SkillRegistry: filesystem-based skill discovery (skill.yaml + prompt.md) - AgentService: orchestrates skill execution with WebSocket progress - Frontend AgentManager: WebSocket listeners, skill execution, config UI - Context menu entries (single + bulk) for "Enrich Metadata (Agent)" - Settings UI for AI Provider configuration (BYOK) - Full i18n support across 9 locales Bug fixes found during review: - aiohttp.web.json_response: status_code= -> status= - settings_modal cancelEditApiKey: wrong argument position - AgentManager.isLlmConfigured: allow Ollama without API key - PostProcessor._merge_tags: lowercase all tags to match TagUpdateService
This commit is contained in:
0
tests/agent_cli/__init__.py
Normal file
0
tests/agent_cli/__init__.py
Normal file
317
tests/agent_cli/test_agent_cli.py
Normal file
317
tests/agent_cli/test_agent_cli.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""Tests for the AgentCLI module (py/agent_cli/).
|
||||
|
||||
All tests mock the underlying services (scanner, MetadataManager, downloader)
|
||||
since the AgentCLI is a thin delegation layer.
|
||||
|
||||
Mock targets must match where imports are resolved inside each function
|
||||
(lazy imports via ``from X import Y`` inside function body).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.agent_cli import (
|
||||
list_base_models,
|
||||
read_metadata,
|
||||
apply_metadata_updates,
|
||||
download_preview,
|
||||
refresh_cache,
|
||||
)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# Helpers
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class MockCache:
|
||||
def __init__(self, raw_data: list[dict] | None = None):
|
||||
self.raw_data = raw_data or []
|
||||
|
||||
|
||||
class MockScanner:
|
||||
"""Simulates a ModelScanner for testing."""
|
||||
|
||||
def __init__(self, raw_data: list[dict] | None = None):
|
||||
self._raw_data = raw_data or []
|
||||
self.update_single_model_cache = mock.AsyncMock(return_value=True)
|
||||
|
||||
async def get_cached_data(self):
|
||||
return MockCache(self._raw_data)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# list_base_models -- imports ServiceRegistry internally
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestListBaseModels:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_cache(self):
|
||||
scanner = MockScanner([])
|
||||
with mock.patch(
|
||||
"py.services.service_registry.ServiceRegistry",
|
||||
get_lora_scanner=mock.AsyncMock(return_value=scanner),
|
||||
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
|
||||
get_embedding_scanner=mock.AsyncMock(return_value=None),
|
||||
):
|
||||
result = await list_base_models()
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merges_all_scanners(self):
|
||||
lora_scanner = MockScanner([
|
||||
{"base_model": "SDXL 1.0"},
|
||||
{"base_model": "Flux.1 D"},
|
||||
{"base_model": "SDXL 1.0"},
|
||||
])
|
||||
ckpt_scanner = MockScanner([
|
||||
{"base_model": "SDXL 1.0"},
|
||||
{"base_model": "SD 1.5"},
|
||||
])
|
||||
with mock.patch(
|
||||
"py.services.service_registry.ServiceRegistry",
|
||||
get_lora_scanner=mock.AsyncMock(return_value=lora_scanner),
|
||||
get_checkpoint_scanner=mock.AsyncMock(return_value=ckpt_scanner),
|
||||
get_embedding_scanner=mock.AsyncMock(return_value=None),
|
||||
):
|
||||
result = await list_base_models()
|
||||
assert result == ["SDXL 1.0", "Flux.1 D", "SD 1.5"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_limit(self):
|
||||
scanner = MockScanner([
|
||||
{"base_model": "A"}, {"base_model": "B"}, {"base_model": "C"},
|
||||
])
|
||||
with mock.patch(
|
||||
"py.services.service_registry.ServiceRegistry",
|
||||
get_lora_scanner=mock.AsyncMock(return_value=scanner),
|
||||
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
|
||||
get_embedding_scanner=mock.AsyncMock(return_value=None),
|
||||
):
|
||||
result = await list_base_models(limit=2)
|
||||
assert result == ["A", "B"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_scanners_return_none(self):
|
||||
with mock.patch(
|
||||
"py.services.service_registry.ServiceRegistry",
|
||||
get_lora_scanner=mock.AsyncMock(return_value=None),
|
||||
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
|
||||
get_embedding_scanner=mock.AsyncMock(return_value=None),
|
||||
):
|
||||
result = await list_base_models()
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_empty_or_missing_base_model(self):
|
||||
scanner = MockScanner([
|
||||
{"base_model": "SDXL 1.0"},
|
||||
{"file_name": "foo.safetensors"}, # no base_model key
|
||||
{"base_model": ""}, # empty
|
||||
])
|
||||
with mock.patch(
|
||||
"py.services.service_registry.ServiceRegistry",
|
||||
get_lora_scanner=mock.AsyncMock(return_value=scanner),
|
||||
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
|
||||
get_embedding_scanner=mock.AsyncMock(return_value=None),
|
||||
):
|
||||
result = await list_base_models()
|
||||
assert result == ["SDXL 1.0"]
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# read_metadata -- imports MetadataManager from py.utils.metadata_manager
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestReadMetadata:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delegates_to_metadata_manager(self):
|
||||
fake = {"file_name": "test", "base_model": "SDXL 1.0"}
|
||||
with mock.patch("py.utils.metadata_manager.MetadataManager") as mm:
|
||||
mm.load_metadata_payload = mock.AsyncMock(return_value=fake)
|
||||
result = await read_metadata("/p.safetensors")
|
||||
assert result == fake
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_returns_empty_dict(self):
|
||||
with mock.patch("py.utils.metadata_manager.MetadataManager") as mm:
|
||||
mm.load_metadata_payload = mock.AsyncMock(side_effect=ValueError("x"))
|
||||
result = await read_metadata("/p.safetensors")
|
||||
assert result == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_coerces_to_empty_dict(self):
|
||||
with mock.patch("py.utils.metadata_manager.MetadataManager") as mm:
|
||||
mm.load_metadata_payload = mock.AsyncMock(return_value=None)
|
||||
result = await read_metadata("/p.safetensors")
|
||||
assert result == {}
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# apply_metadata_updates -- uses read_metadata + MetadataManager.save_metadata
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestApplyMetadataUpdates:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updates_field(self):
|
||||
with (
|
||||
mock.patch("py.agent_cli.read_metadata") as mock_read,
|
||||
mock.patch("py.utils.metadata_manager.MetadataManager") as mm,
|
||||
):
|
||||
mock_read.return_value = {"base_model": "", "tags": []}
|
||||
mm.save_metadata = mock.AsyncMock(return_value=True)
|
||||
updated = await apply_metadata_updates(
|
||||
"/p.safetensors", {"base_model": "Flux.1 D"}
|
||||
)
|
||||
assert updated == ["base_model"]
|
||||
mm.save_metadata.assert_awaited_once_with(
|
||||
"/p.safetensors", {"base_model": "Flux.1 D", "tags": []},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noop_when_value_unchanged(self):
|
||||
with (
|
||||
mock.patch("py.agent_cli.read_metadata") as mock_read,
|
||||
mock.patch("py.utils.metadata_manager.MetadataManager") as mm,
|
||||
):
|
||||
mock_read.return_value = {"base_model": "Flux.1 D"}
|
||||
updated = await apply_metadata_updates(
|
||||
"/p.safetensors", {"base_model": "Flux.1 D"}
|
||||
)
|
||||
assert updated == []
|
||||
mm.save_metadata.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_fields(self):
|
||||
with (
|
||||
mock.patch("py.agent_cli.read_metadata") as mock_read,
|
||||
mock.patch("py.utils.metadata_manager.MetadataManager") as mm,
|
||||
):
|
||||
mm.save_metadata = mock.AsyncMock(return_value=True)
|
||||
mock_read.return_value = {
|
||||
"base_model": "", "modelDescription": "", "tags": [],
|
||||
}
|
||||
updated = await apply_metadata_updates(
|
||||
"/p.safetensors",
|
||||
{"base_model": "SDXL 1.0", "modelDescription": "A", "tags": ["flux"]},
|
||||
)
|
||||
assert sorted(updated) == sorted(["base_model", "modelDescription", "tags"])
|
||||
saved = mm.save_metadata.call_args[0][1]
|
||||
assert saved["base_model"] == "SDXL 1.0"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_updates_noop(self):
|
||||
with (
|
||||
mock.patch("py.agent_cli.read_metadata"),
|
||||
mock.patch("py.utils.metadata_manager.MetadataManager") as mm,
|
||||
):
|
||||
updated = await apply_metadata_updates("/p.safetensors", {})
|
||||
assert updated == []
|
||||
mm.save_metadata.assert_not_called()
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# download_preview -- imports get_downloader + ExifUtils
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestDownloadPreview:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_url_returns_false(self, tmp_path):
|
||||
mp = tmp_path / "m.safetensors"
|
||||
mp.write_bytes(b"fake")
|
||||
assert await download_preview(str(mp), "") is False
|
||||
assert await download_preview(str(mp), " ") is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_download_and_optimise(self, tmp_path):
|
||||
mp = tmp_path / "t.safetensors"
|
||||
mp.write_bytes(b"fake")
|
||||
with (
|
||||
mock.patch("py.services.downloader.get_downloader") as get_dl,
|
||||
mock.patch("py.utils.exif_utils.ExifUtils") as exif,
|
||||
):
|
||||
dl = mock.AsyncMock()
|
||||
dl.download_to_memory = mock.AsyncMock(return_value=(True, b"raw", {}))
|
||||
get_dl.return_value = dl
|
||||
exif.optimize_image.return_value = (b"optimized_webp", {})
|
||||
result = await download_preview(str(mp), "https://ex.com/i.png")
|
||||
assert result is True
|
||||
assert (tmp_path / "t.webp").exists()
|
||||
assert (tmp_path / "t.webp").read_bytes() == b"optimized_webp"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_failure_returns_false(self, tmp_path):
|
||||
mp = tmp_path / "t.safetensors"
|
||||
mp.write_bytes(b"fake")
|
||||
with mock.patch("py.services.downloader.get_downloader") as get_dl:
|
||||
dl = mock.AsyncMock()
|
||||
dl.download_to_memory = mock.AsyncMock(return_value=(False, None, {}))
|
||||
dl.download_file = mock.AsyncMock(return_value=(False, None))
|
||||
get_dl.return_value = dl
|
||||
result = await download_preview(str(mp), "https://ex.com/i.png")
|
||||
assert result is False
|
||||
assert not (tmp_path / "t.webp").exists()
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# refresh_cache -- uses _find_scanner_for_model (ServiceRegistry)
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestRefreshCache:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_found_and_refreshed(self):
|
||||
scanner = MockScanner([{"file_path": "/some/path.safetensors"}])
|
||||
with (
|
||||
mock.patch(
|
||||
"py.services.service_registry.ServiceRegistry",
|
||||
get_lora_scanner=mock.AsyncMock(return_value=scanner),
|
||||
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
|
||||
get_embedding_scanner=mock.AsyncMock(return_value=None),
|
||||
),
|
||||
mock.patch("py.agent_cli.read_metadata") as mock_read,
|
||||
):
|
||||
mock_read.return_value = {"base_model": "SDXL 1.0"}
|
||||
result = await refresh_cache("/some/path.safetensors")
|
||||
assert result is True
|
||||
scanner.update_single_model_cache.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_in_any_scanner(self):
|
||||
scanner = MockScanner([])
|
||||
with mock.patch(
|
||||
"py.services.service_registry.ServiceRegistry",
|
||||
get_lora_scanner=mock.AsyncMock(return_value=scanner),
|
||||
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
|
||||
get_embedding_scanner=mock.AsyncMock(return_value=None),
|
||||
):
|
||||
result = await refresh_cache("/nonexistent/path.safetensors")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_metadata_returns_false(self):
|
||||
scanner = MockScanner([{"file_path": "/some/path.safetensors"}])
|
||||
with (
|
||||
mock.patch(
|
||||
"py.services.service_registry.ServiceRegistry",
|
||||
get_lora_scanner=mock.AsyncMock(return_value=scanner),
|
||||
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
|
||||
get_embedding_scanner=mock.AsyncMock(return_value=None),
|
||||
),
|
||||
mock.patch("py.agent_cli.read_metadata") as mock_read,
|
||||
):
|
||||
mock_read.return_value = {}
|
||||
result = await refresh_cache("/some/path.safetensors")
|
||||
assert result is False
|
||||
237
tests/services/test_llm_service.py
Normal file
237
tests/services/test_llm_service.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Tests for the LLMService."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.llm_service import LLMService
|
||||
from py.services.errors import LLMNotConfiguredError, LLMRateLimitError, LLMResponseError
|
||||
|
||||
|
||||
class MockSettings:
|
||||
"""Minimal settings mock for LLMService tests."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._data = {
|
||||
"llm_enabled": False,
|
||||
"llm_provider": "openai",
|
||||
"llm_api_key": "",
|
||||
"llm_api_base": "",
|
||||
"llm_model": "",
|
||||
}
|
||||
self._data.update(kwargs)
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self._data.get(key, default)
|
||||
|
||||
|
||||
class MockResponse:
|
||||
"""Mock aiohttp response."""
|
||||
|
||||
def __init__(self, status, json_data=None, text_data="", headers=None):
|
||||
self.status = status
|
||||
self._json_data = json_data
|
||||
self._text_data = text_data
|
||||
self.headers = headers or {}
|
||||
|
||||
async def json(self):
|
||||
return self._json_data
|
||||
|
||||
async def text(self):
|
||||
return self._text_data
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
class MockSession:
|
||||
"""Mock aiohttp ClientSession."""
|
||||
|
||||
def __init__(self, response):
|
||||
self._response = response
|
||||
self.closed = False
|
||||
|
||||
def post(self, url, json=None, headers=None):
|
||||
self.last_url = url
|
||||
self.last_json = json
|
||||
self.last_headers = headers
|
||||
return self._response
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_service():
|
||||
"""Create an LLMService with mock settings."""
|
||||
LLMService.reset_instance()
|
||||
settings = MockSettings(
|
||||
llm_enabled=True,
|
||||
llm_provider="openai",
|
||||
llm_api_key="sk-test-key",
|
||||
llm_api_base="",
|
||||
llm_model="gpt-4o-mini",
|
||||
)
|
||||
return LLMService(settings)
|
||||
|
||||
|
||||
class TestLLMServiceConfiguration:
|
||||
def test_is_configured_when_enabled_with_key_and_model(self, llm_service):
|
||||
assert llm_service.is_configured() is True
|
||||
|
||||
def test_not_configured_when_disabled(self):
|
||||
settings = MockSettings(
|
||||
llm_enabled=False, llm_api_key="sk-test", llm_model="gpt-4o"
|
||||
)
|
||||
service = LLMService(settings)
|
||||
# Lenient: model + API key is treated as configured even without
|
||||
# the toggle, because the user clearly intends to use the feature.
|
||||
assert service.is_configured() is True
|
||||
|
||||
def test_not_configured_without_model(self):
|
||||
settings = MockSettings(llm_enabled=True, llm_api_key="sk-test", llm_model="")
|
||||
service = LLMService(settings)
|
||||
assert service.is_configured() is False
|
||||
|
||||
def test_not_configured_without_api_key_for_openai(self):
|
||||
settings = MockSettings(llm_enabled=True, llm_api_key="", llm_model="gpt-4o")
|
||||
service = LLMService(settings)
|
||||
assert service.is_configured() is False
|
||||
|
||||
def test_ollama_configured_without_api_key(self):
|
||||
settings = MockSettings(
|
||||
llm_enabled=True, llm_provider="ollama", llm_api_key="", llm_model="llama3"
|
||||
)
|
||||
service = LLMService(settings)
|
||||
assert service.is_configured() is True
|
||||
|
||||
def test_resolve_api_base_openai_default(self, llm_service):
|
||||
assert llm_service._resolve_api_base("openai", "") == "https://api.openai.com/v1"
|
||||
|
||||
def test_resolve_api_base_ollama_default(self, llm_service):
|
||||
assert llm_service._resolve_api_base("ollama", "") == "http://localhost:11434/v1"
|
||||
|
||||
def test_resolve_api_base_custom_override(self, llm_service):
|
||||
assert llm_service._resolve_api_base("custom", "https://my.api.com/v1/") == "https://my.api.com/v1"
|
||||
|
||||
def test_ensure_configured_raises_when_disabled(self):
|
||||
settings = MockSettings(llm_enabled=False)
|
||||
service = LLMService(settings)
|
||||
with pytest.raises(LLMNotConfiguredError):
|
||||
service._ensure_configured()
|
||||
|
||||
def test_ensure_configured_raises_without_model(self):
|
||||
settings = MockSettings(llm_enabled=True, llm_api_key="sk-test", llm_model="")
|
||||
service = LLMService(settings)
|
||||
with pytest.raises(LLMNotConfiguredError):
|
||||
service._ensure_configured()
|
||||
|
||||
|
||||
class TestLLMServiceChatCompletion:
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_success(self, llm_service):
|
||||
mock_response = MockResponse(
|
||||
200,
|
||||
json_data={
|
||||
"choices": [{"message": {"content": "Hello!"}}],
|
||||
"usage": {"total_tokens": 10},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await llm_service.chat_completion(
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
)
|
||||
|
||||
assert result["content"] == "Hello!"
|
||||
assert result["usage"]["total_tokens"] == 10
|
||||
assert result["model"] == "gpt-4o-mini"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_raises_on_not_configured(self):
|
||||
settings = MockSettings(llm_enabled=False)
|
||||
service = LLMService(settings)
|
||||
with pytest.raises(LLMNotConfiguredError):
|
||||
await service.chat_completion(messages=[])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_raises_on_http_error(self, llm_service):
|
||||
mock_response = MockResponse(500, text_data="Internal Server Error")
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(LLMResponseError, match="HTTP 500"):
|
||||
await llm_service.chat_completion(messages=[])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_raises_on_rate_limit(self, llm_service):
|
||||
mock_response = MockResponse(429, text_data="Rate limited", headers={"Retry-After": "0"})
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(LLMRateLimitError):
|
||||
await llm_service.chat_completion(
|
||||
messages=[], retry_on_rate_limit=False
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_raises_on_bad_response_structure(self, llm_service):
|
||||
mock_response = MockResponse(200, json_data={"unexpected": "data"})
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(LLMResponseError, match="Unexpected LLM response"):
|
||||
await llm_service.chat_completion(messages=[])
|
||||
|
||||
|
||||
class TestLLMServiceChatCompletionJson:
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_json_parses_json(self, llm_service):
|
||||
mock_response = MockResponse(
|
||||
200,
|
||||
json_data={
|
||||
"choices": [{"message": {"content": '{"key": "value"}'}}],
|
||||
"usage": {},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await llm_service.chat_completion_json(
|
||||
system_prompt="You are helpful.",
|
||||
user_prompt="Return JSON.",
|
||||
)
|
||||
|
||||
assert result == {"key": "value"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_json_raises_on_non_json(self, llm_service):
|
||||
# First attempt: non-JSON; second attempt (retry): also non-JSON
|
||||
mock_response = MockResponse(
|
||||
200,
|
||||
json_data={
|
||||
"choices": [{"message": {"content": "not json at all"}}],
|
||||
"usage": {},
|
||||
},
|
||||
)
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(LLMResponseError, match="could not be parsed as JSON"):
|
||||
await llm_service.chat_completion_json(
|
||||
system_prompt="test",
|
||||
user_prompt="test",
|
||||
)
|
||||
313
tests/services/test_post_processor.py
Normal file
313
tests/services/test_post_processor.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""Tests for the PostProcessor (py/services/agent/post_processor.py).
|
||||
|
||||
PostProcessor delegates all I/O to AgentCLI — these tests mock AgentCLI
|
||||
functions and verify the business logic (conditions, merges, dispatch).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.agent.post_processor import PostProcessor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def processor():
|
||||
return PostProcessor()
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# process() — routing
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestProcessDispatch:
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_skill_returns_error(self, processor):
|
||||
result = await processor.process(
|
||||
skill_name="nonexistent",
|
||||
model_path="/p.safetensors",
|
||||
llm_output={},
|
||||
metadata={},
|
||||
)
|
||||
assert result["success"] is False
|
||||
assert "nonexistent" in result["errors"][0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enrich_hf_metadata_routes_correctly(self, processor):
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview") as mock_dl,
|
||||
mock.patch("py.agent_cli.refresh_cache") as mock_ref,
|
||||
):
|
||||
mock_apply.return_value = ["metadata_source"]
|
||||
mock_dl.return_value = False
|
||||
|
||||
result = await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output={},
|
||||
metadata={"from_civitai": True},
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# enrich_hf_metadata — field-level logic
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestEnrichHfMetadata:
|
||||
"""Business logic tests for the enrich_hf_metadata post-processor."""
|
||||
|
||||
MIN_LLM_OUTPUT = {
|
||||
"base_model": "",
|
||||
"trigger_words": [],
|
||||
"description": "",
|
||||
"tags": [],
|
||||
"preview_url": "",
|
||||
"confidence": "low",
|
||||
}
|
||||
|
||||
# -- base_model ------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_model_overwrites_empty(self, processor):
|
||||
"""Empty current base_model → new value is applied."""
|
||||
llm = {**self.MIN_LLM_OUTPUT, "base_model": "Flux.1 D"}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={"base_model": ""},
|
||||
)
|
||||
applied = mock_apply.call_args[0][1]
|
||||
assert applied["base_model"] == "Flux.1 D"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_model_does_not_overwrite_existing_civitai(self, processor):
|
||||
"""Existing base_model from CivitAI → not overwritten."""
|
||||
llm = {**self.MIN_LLM_OUTPUT, "base_model": "Flux.1 D"}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={"base_model": "SDXL 1.0", "from_civitai": True},
|
||||
)
|
||||
# apply IS called (metadata_source, llm_enriched_at) but base_model not in it
|
||||
applied = mock_apply.call_args[0][1]
|
||||
assert "base_model" not in applied
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_model_overwrites_existing_hf_model(self, processor):
|
||||
"""Existing base_model from HF → overwritten (LLM is more reliable)."""
|
||||
llm = {**self.MIN_LLM_OUTPUT, "base_model": "Flux.1 D"}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={"base_model": "SD 1.5", "from_civitai": False},
|
||||
)
|
||||
applied = mock_apply.call_args[0][1]
|
||||
assert applied["base_model"] == "Flux.1 D"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_model_skipped_when_llm_empty(self, processor):
|
||||
"""LLM returns empty base_model → nothing written."""
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=self.MIN_LLM_OUTPUT,
|
||||
metadata={"base_model": ""},
|
||||
)
|
||||
applied = mock_apply.call_args[0][1]
|
||||
assert "base_model" not in applied
|
||||
|
||||
# -- trigger_words ---------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_words_merged(self, processor):
|
||||
"""New trigger words written when current list is empty."""
|
||||
llm = {**self.MIN_LLM_OUTPUT, "trigger_words": ["trigger1", "trigger2"]}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={"trainedWords": []},
|
||||
)
|
||||
applied = mock_apply.call_args[0][1]
|
||||
assert applied["trainedWords"] == ["trigger1", "trigger2"]
|
||||
|
||||
# -- description -----------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_description_set_when_empty(self, processor):
|
||||
llm = {**self.MIN_LLM_OUTPUT, "description": "A model description"}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={"modelDescription": ""},
|
||||
)
|
||||
assert "modelDescription" in mock_apply.call_args[0][1]
|
||||
|
||||
# -- tags ------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_merged_and_deduplicated(self, processor):
|
||||
llm = {**self.MIN_LLM_OUTPUT, "tags": ["flux", "lora", "STYLE"]}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={"tags": ["anime"], "from_civitai": False},
|
||||
)
|
||||
merged = mock_apply.call_args[0][1]["tags"]
|
||||
assert "anime" in merged
|
||||
assert "flux" in merged
|
||||
assert "style" in merged # lowercased
|
||||
# "lora" and "STYLE" → "lora" and "style"
|
||||
assert len(merged) == 4 # anime, flux, lora, style
|
||||
|
||||
# -- metadata_source & llm_enriched_at --------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_fields_always_set(self, processor):
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=self.MIN_LLM_OUTPUT,
|
||||
metadata={},
|
||||
)
|
||||
applied = mock_apply.call_args[0][1]
|
||||
assert applied["metadata_source"] == "agent:enrich_hf_metadata"
|
||||
assert "llm_enriched_at" in applied
|
||||
|
||||
# -- preview download ------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preview_downloaded_when_url_provided(self, processor):
|
||||
llm = {**self.MIN_LLM_OUTPUT, "preview_url": "https://ex.com/img.png"}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview") as mock_dl,
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
mock_dl.return_value = True
|
||||
result = await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={},
|
||||
)
|
||||
assert result["preview_downloaded"] is True
|
||||
mock_dl.assert_awaited_once_with("/p.safetensors", "https://ex.com/img.png")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preview_skipped_when_exists(self, processor):
|
||||
"""If current_preview file exists on disk, skip download."""
|
||||
llm = {**self.MIN_LLM_OUTPUT, "preview_url": "https://ex.com/img.png"}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates"),
|
||||
mock.patch("py.agent_cli.download_preview") as mock_dl,
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
mock.patch("os.path.exists", return_value=True),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={"preview_url": "/existing/preview.webp"},
|
||||
)
|
||||
mock_dl.assert_not_called()
|
||||
|
||||
# -- cache refresh ---------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_refreshed_when_updates_applied(self, processor):
|
||||
llm = {**self.MIN_LLM_OUTPUT, "base_model": "Flux.1 D"}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates", return_value=["base_model"]),
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache") as mock_ref,
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={"base_model": ""},
|
||||
)
|
||||
mock_ref.assert_awaited_once_with("/p.safetensors")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_not_refreshed_when_nothing_changed(self, processor):
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates", return_value=[]),
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache") as mock_ref,
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=self.MIN_LLM_OUTPUT,
|
||||
metadata={"base_model": ""},
|
||||
)
|
||||
mock_ref.assert_not_called()
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# Unit: _merge_tags
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestMergeTags:
|
||||
def test_deduplicates_case_insensitive(self):
|
||||
existing = ["anime", "Flux"]
|
||||
new = ["flux", "LORA", "anime"]
|
||||
result = PostProcessor._merge_tags(existing, new)
|
||||
# All tags are lowercased (matching TagUpdateService behaviour)
|
||||
assert result == ["anime", "flux", "lora"]
|
||||
91
tests/services/test_skill_registry.py
Normal file
91
tests/services/test_skill_registry.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Tests for the SkillRegistry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.agent.skill_registry import SkillRegistry
|
||||
from py.services.agent.skill_definition import SkillDefinition, SkillPermissions
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def registry():
|
||||
"""Create a SkillRegistry with the real skills directory."""
|
||||
SkillRegistry.reset_instance()
|
||||
reg = SkillRegistry()
|
||||
reg._discover()
|
||||
return reg
|
||||
|
||||
|
||||
class TestSkillRegistryDiscovery:
|
||||
def test_discovers_enrich_hf_metadata_skill(self, registry):
|
||||
skills = registry.list_skills()
|
||||
assert len(skills) >= 1
|
||||
skill = registry.get_skill("enrich_hf_metadata")
|
||||
assert skill is not None
|
||||
assert skill.name == "enrich_hf_metadata"
|
||||
assert skill.llm_required is True
|
||||
|
||||
def test_skill_has_correct_model_type_filter(self, registry):
|
||||
skill = registry.get_skill("enrich_hf_metadata")
|
||||
assert skill.model_type_filter == ["lora", "checkpoint", "embedding"]
|
||||
|
||||
def test_skill_has_permissions(self, registry):
|
||||
skill = registry.get_skill("enrich_hf_metadata")
|
||||
assert skill.permissions.write_metadata is True
|
||||
assert skill.permissions.write_previews is True
|
||||
assert "huggingface.co" in skill.permissions.network_domains
|
||||
|
||||
def test_get_skill_returns_none_for_unknown(self, registry):
|
||||
assert registry.get_skill("nonexistent_skill") is None
|
||||
|
||||
|
||||
class TestSkillRegistryLoading:
|
||||
def test_load_prompt_returns_content(self, registry):
|
||||
prompt = registry.load_prompt("enrich_hf_metadata")
|
||||
assert isinstance(prompt, str)
|
||||
assert len(prompt) > 100
|
||||
assert "base_model" in prompt
|
||||
assert "trigger_words" in prompt
|
||||
|
||||
def test_load_prompt_raises_for_unknown_skill(self, registry):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
registry.load_prompt("nonexistent")
|
||||
|
||||
def test_load_handler_raises_when_handler_missing(self, registry):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
registry.load_handler("enrich_hf_metadata")
|
||||
|
||||
|
||||
class TestSkillDefinition:
|
||||
def test_applies_to_model_type_with_filter(self):
|
||||
sd = SkillDefinition(
|
||||
name="test",
|
||||
title="Test",
|
||||
description="",
|
||||
llm_required=False,
|
||||
model_type_filter=["lora"],
|
||||
)
|
||||
assert sd.applies_to_model_type("lora") is True
|
||||
assert sd.applies_to_model_type("checkpoint") is False
|
||||
|
||||
def test_applies_to_model_type_without_filter(self):
|
||||
sd = SkillDefinition(
|
||||
name="test",
|
||||
title="Test",
|
||||
description="",
|
||||
llm_required=False,
|
||||
model_type_filter=None,
|
||||
)
|
||||
assert sd.applies_to_model_type("lora") is True
|
||||
assert sd.applies_to_model_type("checkpoint") is True
|
||||
|
||||
|
||||
class TestSkillPermissions:
|
||||
def test_defaults(self):
|
||||
sp = SkillPermissions()
|
||||
assert sp.write_metadata is True
|
||||
assert sp.write_previews is True
|
||||
assert sp.network_domains == ()
|
||||
Reference in New Issue
Block a user