mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
feat(metadata): implement model data hydration and enhance metadata handling across services, fixes #547
This commit is contained in:
@@ -30,6 +30,7 @@ from ...services.use_cases import (
|
|||||||
from ...services.websocket_manager import WebSocketManager
|
from ...services.websocket_manager import WebSocketManager
|
||||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||||
from ...utils.file_utils import calculate_sha256
|
from ...utils.file_utils import calculate_sha256
|
||||||
|
from ...utils.metadata_manager import MetadataManager
|
||||||
|
|
||||||
|
|
||||||
class ModelPageView:
|
class ModelPageView:
|
||||||
@@ -244,6 +245,8 @@ class ModelManagementHandler:
|
|||||||
if not model_data.get("sha256"):
|
if not model_data.get("sha256"):
|
||||||
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
|
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
|
||||||
|
|
||||||
|
await MetadataManager.hydrate_model_data(model_data)
|
||||||
|
|
||||||
success, error = await self._metadata_sync.fetch_and_update_model(
|
success, error = await self._metadata_sync.fetch_and_update_model(
|
||||||
sha256=model_data["sha256"],
|
sha256=model_data["sha256"],
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import logging
|
|||||||
from typing import Any, Dict, Optional, Protocol, Sequence
|
from typing import Any, Dict, Optional, Protocol, Sequence
|
||||||
|
|
||||||
from ..metadata_sync_service import MetadataSyncService
|
from ..metadata_sync_service import MetadataSyncService
|
||||||
|
from ...utils.metadata_manager import MetadataManager
|
||||||
|
|
||||||
|
|
||||||
class MetadataRefreshProgressReporter(Protocol):
|
class MetadataRefreshProgressReporter(Protocol):
|
||||||
@@ -70,6 +71,7 @@ class BulkMetadataRefreshUseCase:
|
|||||||
for model in to_process:
|
for model in to_process:
|
||||||
try:
|
try:
|
||||||
original_name = model.get("model_name")
|
original_name = model.get("model_name")
|
||||||
|
await MetadataManager.hydrate_model_data(model)
|
||||||
result, _ = await self._metadata_sync.fetch_and_update_model(
|
result, _ = await self._metadata_sync.fetch_and_update_model(
|
||||||
sha256=model["sha256"],
|
sha256=model["sha256"],
|
||||||
file_path=model["file_path"],
|
file_path=model["file_path"],
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||||
|
|
||||||
from ..recipes.constants import GEN_PARAM_KEYS
|
from ..recipes.constants import GEN_PARAM_KEYS
|
||||||
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
||||||
@@ -105,6 +105,7 @@ class MetadataUpdater:
|
|||||||
async def update_cache_func(old_path, new_path, metadata):
|
async def update_cache_func(old_path, new_path, metadata):
|
||||||
return await scanner.update_single_model_cache(old_path, new_path, metadata)
|
return await scanner.update_single_model_cache(old_path, new_path, metadata)
|
||||||
|
|
||||||
|
await MetadataManager.hydrate_model_data(model_data)
|
||||||
success, error = await _get_metadata_sync_service().fetch_and_update_model(
|
success, error = await _get_metadata_sync_service().fetch_and_update_model(
|
||||||
sha256=model_hash,
|
sha256=model_hash,
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
@@ -185,16 +186,16 @@ class MetadataUpdater:
|
|||||||
if is_supported:
|
if is_supported:
|
||||||
local_images_paths.append(file_path)
|
local_images_paths.append(file_path)
|
||||||
|
|
||||||
|
await MetadataManager.hydrate_model_data(model)
|
||||||
|
civitai_data = model.setdefault('civitai', {})
|
||||||
|
|
||||||
# Check if metadata update is needed (no civitai field or empty images)
|
# Check if metadata update is needed (no civitai field or empty images)
|
||||||
needs_update = not model.get('civitai') or not model.get('civitai', {}).get('images')
|
needs_update = not civitai_data or not civitai_data.get('images')
|
||||||
|
|
||||||
if needs_update and local_images_paths:
|
if needs_update and local_images_paths:
|
||||||
logger.debug(f"Found {len(local_images_paths)} local example images for {model.get('model_name')}, updating metadata")
|
logger.debug(f"Found {len(local_images_paths)} local example images for {model.get('model_name')}, updating metadata")
|
||||||
|
|
||||||
# Create or get civitai field
|
# Create or get civitai field
|
||||||
if not model.get('civitai'):
|
|
||||||
model['civitai'] = {}
|
|
||||||
|
|
||||||
# Create images array
|
# Create images array
|
||||||
images = []
|
images = []
|
||||||
|
|
||||||
@@ -229,16 +230,13 @@ class MetadataUpdater:
|
|||||||
images.append(image_entry)
|
images.append(image_entry)
|
||||||
|
|
||||||
# Update the model's civitai.images field
|
# Update the model's civitai.images field
|
||||||
model['civitai']['images'] = images
|
civitai_data['images'] = images
|
||||||
|
|
||||||
# Save metadata to .metadata.json file
|
# Save metadata to .metadata.json file
|
||||||
file_path = model.get('file_path')
|
file_path = model.get('file_path')
|
||||||
try:
|
try:
|
||||||
# Create a copy of model data without 'folder' field
|
|
||||||
model_copy = model.copy()
|
model_copy = model.copy()
|
||||||
model_copy.pop('folder', None)
|
model_copy.pop('folder', None)
|
||||||
|
|
||||||
# Write metadata to file
|
|
||||||
await MetadataManager.save_metadata(file_path, model_copy)
|
await MetadataManager.save_metadata(file_path, model_copy)
|
||||||
logger.info(f"Saved metadata for {model.get('model_name')}")
|
logger.info(f"Saved metadata for {model.get('model_name')}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -271,16 +269,13 @@ class MetadataUpdater:
|
|||||||
tuple: (regular_images, custom_images) - Both image arrays
|
tuple: (regular_images, custom_images) - Both image arrays
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Ensure civitai field exists in model_data
|
await MetadataManager.hydrate_model_data(model_data)
|
||||||
if not model_data.get('civitai'):
|
civitai_data = model_data.setdefault('civitai', {})
|
||||||
model_data['civitai'] = {}
|
custom_images = civitai_data.get('customImages')
|
||||||
|
|
||||||
# Ensure customImages array exists
|
if not isinstance(custom_images, list):
|
||||||
if not model_data['civitai'].get('customImages'):
|
custom_images = []
|
||||||
model_data['civitai']['customImages'] = []
|
civitai_data['customImages'] = custom_images
|
||||||
|
|
||||||
# Get current customImages array
|
|
||||||
custom_images = model_data['civitai']['customImages']
|
|
||||||
|
|
||||||
# Add new image entry for each imported file
|
# Add new image entry for each imported file
|
||||||
for path_tuple in newly_imported_paths:
|
for path_tuple in newly_imported_paths:
|
||||||
@@ -338,11 +333,8 @@ class MetadataUpdater:
|
|||||||
file_path = model_data.get('file_path')
|
file_path = model_data.get('file_path')
|
||||||
if file_path:
|
if file_path:
|
||||||
try:
|
try:
|
||||||
# Create a copy of model data without 'folder' field
|
|
||||||
model_copy = model_data.copy()
|
model_copy = model_data.copy()
|
||||||
model_copy.pop('folder', None)
|
model_copy.pop('folder', None)
|
||||||
|
|
||||||
# Write metadata to file
|
|
||||||
await MetadataManager.save_metadata(file_path, model_copy)
|
await MetadataManager.save_metadata(file_path, model_copy)
|
||||||
logger.info(f"Saved metadata for {model_data.get('model_name')}")
|
logger.info(f"Saved metadata for {model_data.get('model_name')}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -353,7 +345,7 @@ class MetadataUpdater:
|
|||||||
await scanner.update_single_model_cache(file_path, file_path, model_data)
|
await scanner.update_single_model_cache(file_path, file_path, model_data)
|
||||||
|
|
||||||
# Get regular images array (might be None)
|
# Get regular images array (might be None)
|
||||||
regular_images = model_data['civitai'].get('images', [])
|
regular_images = civitai_data.get('images', [])
|
||||||
|
|
||||||
# Return both image arrays
|
# Return both image arrays
|
||||||
return regular_images, custom_images
|
return regular_images, custom_images
|
||||||
|
|||||||
@@ -475,15 +475,17 @@ class ExampleImagesProcessor:
|
|||||||
'error': f"Model with hash {model_hash} not found in cache"
|
'error': f"Model with hash {model_hash} not found in cache"
|
||||||
}, status=404)
|
}, status=404)
|
||||||
|
|
||||||
# Check if model has custom images
|
await MetadataManager.hydrate_model_data(model_data)
|
||||||
if not model_data.get('civitai', {}).get('customImages'):
|
civitai_data = model_data.setdefault('civitai', {})
|
||||||
|
custom_images = civitai_data.get('customImages')
|
||||||
|
|
||||||
|
if not isinstance(custom_images, list) or not custom_images:
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
'success': False,
|
'success': False,
|
||||||
'error': f"Model has no custom images"
|
'error': f"Model has no custom images"
|
||||||
}, status=404)
|
}, status=404)
|
||||||
|
|
||||||
# Find the custom image with matching short_id
|
# Find the custom image with matching short_id
|
||||||
custom_images = model_data['civitai']['customImages']
|
|
||||||
matching_image = None
|
matching_image = None
|
||||||
new_custom_images = []
|
new_custom_images = []
|
||||||
|
|
||||||
@@ -527,17 +529,15 @@ class ExampleImagesProcessor:
|
|||||||
logger.warning(f"File for custom example with id {short_id} not found, but metadata will still be updated")
|
logger.warning(f"File for custom example with id {short_id} not found, but metadata will still be updated")
|
||||||
|
|
||||||
# Update metadata
|
# Update metadata
|
||||||
model_data['civitai']['customImages'] = new_custom_images
|
civitai_data['customImages'] = new_custom_images
|
||||||
|
model_data.setdefault('civitai', {})['customImages'] = new_custom_images
|
||||||
|
|
||||||
# Save updated metadata to file
|
# Save updated metadata to file
|
||||||
file_path = model_data.get('file_path')
|
file_path = model_data.get('file_path')
|
||||||
if file_path:
|
if file_path:
|
||||||
try:
|
try:
|
||||||
# Create a copy of model data without 'folder' field
|
|
||||||
model_copy = model_data.copy()
|
model_copy = model_data.copy()
|
||||||
model_copy.pop('folder', None)
|
model_copy.pop('folder', None)
|
||||||
|
|
||||||
# Write metadata to file
|
|
||||||
await MetadataManager.save_metadata(file_path, model_copy)
|
await MetadataManager.save_metadata(file_path, model_copy)
|
||||||
logger.debug(f"Saved updated metadata for {model_data.get('model_name')}")
|
logger.debug(f"Saved updated metadata for {model_data.get('model_name')}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -551,7 +551,7 @@ class ExampleImagesProcessor:
|
|||||||
await scanner.update_single_model_cache(file_path, file_path, model_data)
|
await scanner.update_single_model_cache(file_path, file_path, model_data)
|
||||||
|
|
||||||
# Get regular images array (might be None)
|
# Get regular images array (might be None)
|
||||||
regular_images = model_data['civitai'].get('images', [])
|
regular_images = civitai_data.get('images', [])
|
||||||
|
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
'success': True,
|
'success': True,
|
||||||
@@ -568,4 +568,4 @@ class ExampleImagesProcessor:
|
|||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from datetime import datetime
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Optional, Type, Union
|
from typing import Any, Dict, Optional, Type, Union
|
||||||
|
|
||||||
from .models import BaseModelMetadata, LoraMetadata
|
from .models import BaseModelMetadata, LoraMetadata
|
||||||
from .file_utils import normalize_path, find_preview_file, calculate_sha256
|
from .file_utils import normalize_path, find_preview_file, calculate_sha256
|
||||||
@@ -53,6 +53,70 @@ class MetadataManager:
|
|||||||
error_type = "Invalid JSON" if isinstance(e, json.JSONDecodeError) else "Parse error"
|
error_type = "Invalid JSON" if isinstance(e, json.JSONDecodeError) else "Parse error"
|
||||||
logger.error(f"{error_type} in metadata file: {metadata_path}. Error: {str(e)}. Skipping model to preserve existing data.")
|
logger.error(f"{error_type} in metadata file: {metadata_path}. Error: {str(e)}. Skipping model to preserve existing data.")
|
||||||
return None, True # should_skip = True
|
return None, True # should_skip = True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def load_metadata_payload(file_path: str) -> Dict:
|
||||||
|
"""
|
||||||
|
Load metadata and return it as a dictionary, including any unknown fields.
|
||||||
|
Falls back to reading the raw JSON file if parsing into a model class fails.
|
||||||
|
"""
|
||||||
|
|
||||||
|
payload: Dict = {}
|
||||||
|
metadata_obj, should_skip = await MetadataManager.load_metadata(file_path)
|
||||||
|
|
||||||
|
if metadata_obj:
|
||||||
|
payload = metadata_obj.to_dict()
|
||||||
|
unknown_fields = getattr(metadata_obj, "_unknown_fields", None)
|
||||||
|
if isinstance(unknown_fields, dict):
|
||||||
|
payload.update(unknown_fields)
|
||||||
|
else:
|
||||||
|
if not should_skip:
|
||||||
|
metadata_path = (
|
||||||
|
file_path
|
||||||
|
if file_path.endswith(".metadata.json")
|
||||||
|
else f"{os.path.splitext(file_path)[0]}.metadata.json"
|
||||||
|
)
|
||||||
|
if os.path.exists(metadata_path):
|
||||||
|
try:
|
||||||
|
with open(metadata_path, "r", encoding="utf-8") as handle:
|
||||||
|
raw = json.load(handle)
|
||||||
|
if isinstance(raw, dict):
|
||||||
|
payload = raw
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to parse metadata file %s while loading payload",
|
||||||
|
metadata_path,
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.warning("Failed to read metadata file %s: %s", metadata_path, exc)
|
||||||
|
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
payload = {}
|
||||||
|
|
||||||
|
if file_path:
|
||||||
|
payload.setdefault("file_path", normalize_path(file_path))
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def hydrate_model_data(model_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Replace the provided model data with the authoritative payload from disk.
|
||||||
|
Preserves the cached folder entry if present.
|
||||||
|
"""
|
||||||
|
|
||||||
|
file_path = model_data.get("file_path")
|
||||||
|
if not file_path:
|
||||||
|
return model_data
|
||||||
|
|
||||||
|
folder = model_data.get("folder")
|
||||||
|
payload = await MetadataManager.load_metadata_payload(file_path)
|
||||||
|
if folder is not None:
|
||||||
|
payload["folder"] = folder
|
||||||
|
|
||||||
|
model_data.clear()
|
||||||
|
model_data.update(payload)
|
||||||
|
return model_data
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def save_metadata(path: str, metadata: Union[BaseModelMetadata, Dict]) -> bool:
|
async def save_metadata(path: str, metadata: Union[BaseModelMetadata, Dict]) -> bool:
|
||||||
|
|||||||
@@ -16,10 +16,12 @@ from aiohttp.test_utils import TestClient, TestServer
|
|||||||
from py.config import config
|
from py.config import config
|
||||||
from py.routes.base_model_routes import BaseModelRoutes
|
from py.routes.base_model_routes import BaseModelRoutes
|
||||||
from py.services import model_file_service
|
from py.services import model_file_service
|
||||||
|
from py.services.metadata_sync_service import MetadataSyncService
|
||||||
from py.services.model_file_service import AutoOrganizeResult
|
from py.services.model_file_service import AutoOrganizeResult
|
||||||
from py.services.service_registry import ServiceRegistry
|
from py.services.service_registry import ServiceRegistry
|
||||||
from py.services.websocket_manager import ws_manager
|
from py.services.websocket_manager import ws_manager
|
||||||
from py.utils.exif_utils import ExifUtils
|
from py.utils.exif_utils import ExifUtils
|
||||||
|
from py.utils.metadata_manager import MetadataManager
|
||||||
|
|
||||||
|
|
||||||
class DummyRoutes(BaseModelRoutes):
|
class DummyRoutes(BaseModelRoutes):
|
||||||
@@ -197,6 +199,116 @@ def test_replace_preview_writes_file_and_updates_cache(
|
|||||||
asyncio.run(scenario())
|
asyncio.run(scenario())
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_civitai_hydrates_metadata_before_sync(
|
||||||
|
mock_service,
|
||||||
|
mock_scanner,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path: Path,
|
||||||
|
):
|
||||||
|
model_path = tmp_path / "hydrate.safetensors"
|
||||||
|
model_path.write_bytes(b"model")
|
||||||
|
metadata_path = tmp_path / "hydrate.metadata.json"
|
||||||
|
|
||||||
|
existing_metadata = {
|
||||||
|
"file_path": str(model_path),
|
||||||
|
"sha256": "abc123",
|
||||||
|
"model_name": "Hydrated",
|
||||||
|
"preview_url": "keep/me.png",
|
||||||
|
"civitai": {
|
||||||
|
"id": 99,
|
||||||
|
"modelId": 42,
|
||||||
|
"images": [{"url": "https://example.com/existing.png", "type": "image"}],
|
||||||
|
"customImages": [{"id": "old-id", "url": "", "type": "image"}],
|
||||||
|
"trainedWords": ["keep"],
|
||||||
|
},
|
||||||
|
"custom_field": "preserve",
|
||||||
|
}
|
||||||
|
metadata_path.write_text(json.dumps(existing_metadata), encoding="utf-8")
|
||||||
|
|
||||||
|
minimal_cache_entry = {
|
||||||
|
"file_path": str(model_path),
|
||||||
|
"sha256": "abc123",
|
||||||
|
"folder": "some/folder",
|
||||||
|
"civitai": {"id": 99, "modelId": 42},
|
||||||
|
}
|
||||||
|
mock_scanner._cache.raw_data = [minimal_cache_entry]
|
||||||
|
|
||||||
|
class FakeMetadata:
|
||||||
|
def __init__(self, payload: dict) -> None:
|
||||||
|
self._payload = payload
|
||||||
|
self._unknown_fields = {"legacy_field": "legacy"}
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return self._payload.copy()
|
||||||
|
|
||||||
|
async def fake_load_metadata(path: str, *_args, **_kwargs):
|
||||||
|
assert path == str(model_path)
|
||||||
|
return FakeMetadata(existing_metadata), False
|
||||||
|
|
||||||
|
async def fake_save_metadata(path: str, metadata: dict) -> bool:
|
||||||
|
save_calls.append((path, json.loads(json.dumps(metadata))))
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def fake_fetch_and_update_model(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
sha256: str,
|
||||||
|
file_path: str,
|
||||||
|
model_data: dict,
|
||||||
|
update_cache_func,
|
||||||
|
):
|
||||||
|
captured["model_data"] = json.loads(json.dumps(model_data))
|
||||||
|
to_save = model_data.copy()
|
||||||
|
to_save.pop("folder", None)
|
||||||
|
await MetadataManager.save_metadata(
|
||||||
|
os.path.splitext(file_path)[0] + ".metadata.json",
|
||||||
|
to_save,
|
||||||
|
)
|
||||||
|
await update_cache_func(file_path, file_path, model_data)
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
save_calls: list[tuple[str, dict]] = []
|
||||||
|
captured: dict[str, dict] = {}
|
||||||
|
|
||||||
|
monkeypatch.setattr(MetadataManager, "load_metadata", staticmethod(fake_load_metadata))
|
||||||
|
monkeypatch.setattr(MetadataManager, "save_metadata", staticmethod(fake_save_metadata))
|
||||||
|
monkeypatch.setattr(MetadataSyncService, "fetch_and_update_model", fake_fetch_and_update_model)
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
client = await create_test_client(mock_service)
|
||||||
|
try:
|
||||||
|
response = await client.post(
|
||||||
|
"/api/lm/test-models/fetch-civitai",
|
||||||
|
json={"file_path": str(model_path)},
|
||||||
|
)
|
||||||
|
payload = await response.json()
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
assert payload["success"] is True
|
||||||
|
assert captured["model_data"]["custom_field"] == "preserve"
|
||||||
|
assert captured["model_data"]["civitai"]["images"][0]["url"] == "https://example.com/existing.png"
|
||||||
|
assert captured["model_data"]["civitai"]["trainedWords"] == ["keep"]
|
||||||
|
assert captured["model_data"]["civitai"]["id"] == 99
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
asyncio.run(scenario())
|
||||||
|
|
||||||
|
assert save_calls, "Metadata save should be invoked"
|
||||||
|
saved_path, saved_payload = save_calls[0]
|
||||||
|
assert saved_path == str(metadata_path)
|
||||||
|
assert saved_payload["custom_field"] == "preserve"
|
||||||
|
assert saved_payload["civitai"]["images"][0]["url"] == "https://example.com/existing.png"
|
||||||
|
assert saved_payload["civitai"]["trainedWords"] == ["keep"]
|
||||||
|
assert saved_payload["civitai"]["id"] == 99
|
||||||
|
assert saved_payload["legacy_field"] == "legacy"
|
||||||
|
|
||||||
|
assert mock_scanner.updated_models
|
||||||
|
updated_metadata = mock_scanner.updated_models[-1]["metadata"]
|
||||||
|
assert updated_metadata["custom_field"] == "preserve"
|
||||||
|
assert updated_metadata["civitai"]["customImages"][0]["id"] == "old-id"
|
||||||
|
|
||||||
|
|
||||||
def test_download_model_invokes_download_manager(
|
def test_download_model_invokes_download_manager(
|
||||||
mock_service,
|
mock_service,
|
||||||
download_manager_stub,
|
download_manager_stub,
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from py.utils.example_images_processor import (
|
|||||||
ExampleImagesImportError,
|
ExampleImagesImportError,
|
||||||
ExampleImagesValidationError,
|
ExampleImagesValidationError,
|
||||||
)
|
)
|
||||||
|
from py.utils.metadata_manager import MetadataManager
|
||||||
from tests.conftest import MockModelService, MockScanner
|
from tests.conftest import MockModelService, MockScanner
|
||||||
|
|
||||||
|
|
||||||
@@ -155,7 +156,9 @@ async def test_auto_organize_use_case_rejects_when_running() -> None:
|
|||||||
await use_case.execute(file_paths=None, progress_callback=None)
|
await use_case.execute(file_paths=None, progress_callback=None)
|
||||||
|
|
||||||
|
|
||||||
async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
|
async def test_bulk_metadata_refresh_emits_progress_and_updates_cache(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
scanner = MockScanner()
|
scanner = MockScanner()
|
||||||
scanner._cache.raw_data = [
|
scanner._cache.raw_data = [
|
||||||
{
|
{
|
||||||
@@ -170,6 +173,25 @@ async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
|
|||||||
settings = StubSettings()
|
settings = StubSettings()
|
||||||
progress = ProgressCollector()
|
progress = ProgressCollector()
|
||||||
|
|
||||||
|
hydration_calls: list[str] = []
|
||||||
|
|
||||||
|
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
hydration_calls.append(model_data.get("file_path", ""))
|
||||||
|
model_data.clear()
|
||||||
|
model_data.update(
|
||||||
|
{
|
||||||
|
"file_path": "model1.safetensors",
|
||||||
|
"sha256": "hash",
|
||||||
|
"from_civitai": True,
|
||||||
|
"model_name": "Demo",
|
||||||
|
"extra": "value",
|
||||||
|
"civitai": {"images": [{"url": "existing.png", "type": "image"}]},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return model_data
|
||||||
|
|
||||||
|
monkeypatch.setattr(MetadataManager, "hydrate_model_data", staticmethod(fake_hydrate))
|
||||||
|
|
||||||
use_case = BulkMetadataRefreshUseCase(
|
use_case = BulkMetadataRefreshUseCase(
|
||||||
service=service,
|
service=service,
|
||||||
metadata_sync=metadata_sync,
|
metadata_sync=metadata_sync,
|
||||||
@@ -183,6 +205,9 @@ async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
|
|||||||
assert progress.events[0]["status"] == "started"
|
assert progress.events[0]["status"] == "started"
|
||||||
assert progress.events[-1]["status"] == "completed"
|
assert progress.events[-1]["status"] == "completed"
|
||||||
assert metadata_sync.calls
|
assert metadata_sync.calls
|
||||||
|
assert metadata_sync.calls[0]["model_data"]["extra"] == "value"
|
||||||
|
assert scanner._cache.raw_data[0]["extra"] == "value"
|
||||||
|
assert hydration_calls == ["model1.safetensors"]
|
||||||
assert scanner._cache.resort_calls == 1
|
assert scanner._cache.resort_calls == 1
|
||||||
|
|
||||||
|
|
||||||
@@ -314,4 +339,4 @@ async def test_import_example_images_use_case_propagates_generic_error() -> None
|
|||||||
request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]})
|
request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]})
|
||||||
|
|
||||||
with pytest.raises(ExampleImagesImportError):
|
with pytest.raises(ExampleImagesImportError):
|
||||||
await use_case.execute(request)
|
await use_case.execute(request)
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
@@ -30,7 +32,23 @@ def patch_metadata_manager(monkeypatch: pytest.MonkeyPatch):
|
|||||||
saved.append((path, metadata.copy()))
|
saved.append((path, metadata.copy()))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
class SimpleMetadata:
|
||||||
|
def __init__(self, payload: Dict[str, Any]) -> None:
|
||||||
|
self._payload = payload
|
||||||
|
self._unknown_fields: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return self._payload.copy()
|
||||||
|
|
||||||
|
async def fake_load(path: str, *_args: Any, **_kwargs: Any):
|
||||||
|
metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json"
|
||||||
|
if os.path.exists(metadata_path):
|
||||||
|
data = json.loads(Path(metadata_path).read_text(encoding="utf-8"))
|
||||||
|
return SimpleMetadata(data), False
|
||||||
|
return None, False
|
||||||
|
|
||||||
monkeypatch.setattr(metadata_module.MetadataManager, "save_metadata", staticmethod(fake_save))
|
monkeypatch.setattr(metadata_module.MetadataManager, "save_metadata", staticmethod(fake_save))
|
||||||
|
monkeypatch.setattr(metadata_module.MetadataManager, "load_metadata", staticmethod(fake_load))
|
||||||
return saved
|
return saved
|
||||||
|
|
||||||
|
|
||||||
@@ -64,10 +82,80 @@ async def test_update_metadata_after_import_enriches_entries(monkeypatch: pytest
|
|||||||
assert custom[0]["hasMeta"] is True
|
assert custom[0]["hasMeta"] is True
|
||||||
assert custom[0]["type"] == "image"
|
assert custom[0]["type"] == "image"
|
||||||
|
|
||||||
assert patch_metadata_manager[0][0] == str(model_file)
|
assert Path(patch_metadata_manager[0][0]) == model_file
|
||||||
assert scanner.updates
|
assert scanner.updates
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_metadata_after_import_preserves_existing_metadata(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
patch_metadata_manager,
|
||||||
|
):
|
||||||
|
model_hash = "b" * 64
|
||||||
|
model_file = tmp_path / "preserve.safetensors"
|
||||||
|
model_file.write_text("content", encoding="utf-8")
|
||||||
|
metadata_path = tmp_path / "preserve.metadata.json"
|
||||||
|
|
||||||
|
existing_payload = {
|
||||||
|
"model_name": "Example",
|
||||||
|
"file_path": str(model_file),
|
||||||
|
"civitai": {
|
||||||
|
"id": 42,
|
||||||
|
"modelId": 88,
|
||||||
|
"name": "Example",
|
||||||
|
"trainedWords": ["foo"],
|
||||||
|
"images": [{"url": "https://example.com/default.png", "type": "image"}],
|
||||||
|
"customImages": [
|
||||||
|
{"id": "existing-id", "type": "image", "url": "", "nsfwLevel": 0}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"extraField": "keep-me",
|
||||||
|
}
|
||||||
|
metadata_path.write_text(json.dumps(existing_payload), encoding="utf-8")
|
||||||
|
|
||||||
|
model_data = {
|
||||||
|
"sha256": model_hash,
|
||||||
|
"model_name": "Example",
|
||||||
|
"file_path": str(model_file),
|
||||||
|
"civitai": {
|
||||||
|
"id": 42,
|
||||||
|
"modelId": 88,
|
||||||
|
"name": "Example",
|
||||||
|
"trainedWords": ["foo"],
|
||||||
|
"customImages": [],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
scanner = StubScanner([model_data])
|
||||||
|
|
||||||
|
image_path = tmp_path / "new.png"
|
||||||
|
image_path.write_bytes(b"fakepng")
|
||||||
|
|
||||||
|
monkeypatch.setattr(metadata_module.ExifUtils, "extract_image_metadata", staticmethod(lambda _path: None))
|
||||||
|
monkeypatch.setattr(metadata_module.MetadataUpdater, "_parse_image_metadata", staticmethod(lambda payload: None))
|
||||||
|
|
||||||
|
regular, custom = await metadata_module.MetadataUpdater.update_metadata_after_import(
|
||||||
|
model_hash,
|
||||||
|
model_data,
|
||||||
|
scanner,
|
||||||
|
[(str(image_path), "new-id")],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert regular == existing_payload["civitai"]["images"]
|
||||||
|
assert any(entry["id"] == "new-id" for entry in custom)
|
||||||
|
|
||||||
|
saved_path, saved_payload = patch_metadata_manager[-1]
|
||||||
|
assert Path(saved_path) == model_file
|
||||||
|
assert saved_payload["extraField"] == "keep-me"
|
||||||
|
assert saved_payload["civitai"]["images"] == existing_payload["civitai"]["images"]
|
||||||
|
assert saved_payload["civitai"]["trainedWords"] == ["foo"]
|
||||||
|
assert {entry["id"] for entry in saved_payload["civitai"]["customImages"]} == {"existing-id", "new-id"}
|
||||||
|
|
||||||
|
assert scanner.updates
|
||||||
|
updated_metadata = scanner.updates[-1][2]
|
||||||
|
assert updated_metadata["civitai"]["images"] == existing_payload["civitai"]["images"]
|
||||||
|
assert {entry["id"] for entry in updated_metadata["civitai"]["customImages"]} == {"existing-id", "new-id"}
|
||||||
|
|
||||||
async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||||
model_hash = "b" * 64
|
model_hash = "b" * 64
|
||||||
model_file = tmp_path / "model.safetensors"
|
model_file = tmp_path / "model.safetensors"
|
||||||
@@ -79,6 +167,16 @@ async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.Monke
|
|||||||
async def fetch_and_update_model(self, **_kwargs):
|
async def fetch_and_update_model(self, **_kwargs):
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
|
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
model_data["hydrated"] = True
|
||||||
|
return model_data
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
metadata_module.MetadataManager,
|
||||||
|
"hydrate_model_data",
|
||||||
|
staticmethod(fake_hydrate),
|
||||||
|
)
|
||||||
|
|
||||||
monkeypatch.setattr(metadata_module, "_metadata_sync_service", StubMetadataSync())
|
monkeypatch.setattr(metadata_module, "_metadata_sync_service", StubMetadataSync())
|
||||||
|
|
||||||
result = await metadata_module.MetadataUpdater.refresh_model_metadata(
|
result = await metadata_module.MetadataUpdater.refresh_model_metadata(
|
||||||
@@ -89,6 +187,7 @@ async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.Monke
|
|||||||
{"refreshed_models": set(), "errors": [], "last_error": None},
|
{"refreshed_models": set(), "errors": [], "last_error": None},
|
||||||
)
|
)
|
||||||
assert result is True
|
assert result is True
|
||||||
|
assert cache_item["hydrated"] is True
|
||||||
|
|
||||||
|
|
||||||
async def test_update_metadata_from_local_examples_generates_entries(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
async def test_update_metadata_from_local_examples_generates_entries(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||||
@@ -112,4 +211,4 @@ async def test_update_metadata_from_local_examples_generates_entries(monkeypatch
|
|||||||
str(model_dir),
|
str(model_dir),
|
||||||
)
|
)
|
||||||
assert success is True
|
assert success is True
|
||||||
assert model_data["civitai"]["images"]
|
assert model_data["civitai"]["images"]
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
@@ -8,7 +9,9 @@ from typing import Any, Dict, Tuple
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from py.services.settings_manager import get_settings_manager
|
from py.services.settings_manager import get_settings_manager
|
||||||
|
from py.utils import example_images_metadata as metadata_module
|
||||||
from py.utils import example_images_processor as processor_module
|
from py.utils import example_images_processor as processor_module
|
||||||
|
from py.utils.example_images_paths import get_model_folder
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@@ -22,6 +25,27 @@ def restore_settings() -> None:
|
|||||||
manager.settings.update(original)
|
manager.settings.update(original)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def patch_metadata_loader(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
class SimpleMetadata:
|
||||||
|
def __init__(self, payload: Dict[str, Any]) -> None:
|
||||||
|
self._payload = payload
|
||||||
|
self._unknown_fields: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return self._payload.copy()
|
||||||
|
|
||||||
|
async def fake_load(path: str, *_args: Any, **_kwargs: Any):
|
||||||
|
metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json"
|
||||||
|
if os.path.exists(metadata_path):
|
||||||
|
data = json.loads(Path(metadata_path).read_text(encoding="utf-8"))
|
||||||
|
return SimpleMetadata(data), False
|
||||||
|
return None, False
|
||||||
|
|
||||||
|
monkeypatch.setattr(processor_module.MetadataManager, "load_metadata", staticmethod(fake_load))
|
||||||
|
monkeypatch.setattr(metadata_module.MetadataManager, "load_metadata", staticmethod(fake_load))
|
||||||
|
|
||||||
|
|
||||||
def test_get_file_extension_from_magic_bytes() -> None:
|
def test_get_file_extension_from_magic_bytes() -> None:
|
||||||
jpg_bytes = b"\xff\xd8\xff" + b"rest"
|
jpg_bytes = b"\xff\xd8\xff" + b"rest"
|
||||||
ext = processor_module.ExampleImagesProcessor._get_file_extension_from_content_or_headers(
|
ext = processor_module.ExampleImagesProcessor._get_file_extension_from_content_or_headers(
|
||||||
@@ -146,3 +170,88 @@ async def test_import_images_raises_when_model_not_found(monkeypatch: pytest.Mon
|
|||||||
|
|
||||||
with pytest.raises(processor_module.ExampleImagesImportError):
|
with pytest.raises(processor_module.ExampleImagesImportError):
|
||||||
await processor_module.ExampleImagesProcessor.import_images("a" * 64, [str(tmp_path / "missing.png")])
|
await processor_module.ExampleImagesProcessor.import_images("a" * 64, [str(tmp_path / "missing.png")])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_custom_image_preserves_existing_metadata(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
||||||
|
settings_manager = get_settings_manager()
|
||||||
|
settings_manager.settings["example_images_path"] = str(tmp_path / "examples")
|
||||||
|
|
||||||
|
model_hash = "c" * 64
|
||||||
|
model_file = tmp_path / "keep.safetensors"
|
||||||
|
model_file.write_text("content", encoding="utf-8")
|
||||||
|
metadata_path = tmp_path / "keep.metadata.json"
|
||||||
|
|
||||||
|
existing_metadata = {
|
||||||
|
"model_name": "Keep",
|
||||||
|
"file_path": str(model_file),
|
||||||
|
"civitai": {
|
||||||
|
"images": [{"url": "https://example.com/default.png", "type": "image"}],
|
||||||
|
"customImages": [{"id": "existing-id", "url": "", "type": "image"}],
|
||||||
|
"trainedWords": ["foo"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
metadata_path.write_text(json.dumps(existing_metadata), encoding="utf-8")
|
||||||
|
|
||||||
|
model_data = {
|
||||||
|
"sha256": model_hash,
|
||||||
|
"model_name": "Keep",
|
||||||
|
"file_path": str(model_file),
|
||||||
|
"civitai": {
|
||||||
|
"customImages": [{"id": "existing-id", "url": "", "type": "image"}],
|
||||||
|
"trainedWords": ["foo"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
class Scanner(StubScanner):
|
||||||
|
def has_hash(self, hash_value: str) -> bool:
|
||||||
|
return hash_value == model_hash
|
||||||
|
|
||||||
|
scanner = Scanner([model_data])
|
||||||
|
|
||||||
|
async def _return_scanner(cls=None):
|
||||||
|
return scanner
|
||||||
|
|
||||||
|
monkeypatch.setattr(processor_module.ServiceRegistry, "get_lora_scanner", classmethod(_return_scanner))
|
||||||
|
monkeypatch.setattr(processor_module.ServiceRegistry, "get_checkpoint_scanner", classmethod(_return_scanner))
|
||||||
|
monkeypatch.setattr(processor_module.ServiceRegistry, "get_embedding_scanner", classmethod(_return_scanner))
|
||||||
|
|
||||||
|
model_folder = get_model_folder(model_hash)
|
||||||
|
os.makedirs(model_folder, exist_ok=True)
|
||||||
|
(Path(model_folder) / "custom_existing-id.png").write_bytes(b"data")
|
||||||
|
|
||||||
|
saved: list[tuple[str, Dict[str, Any]]] = []
|
||||||
|
|
||||||
|
async def fake_save(path: str, payload: Dict[str, Any]) -> bool:
|
||||||
|
saved.append((path, payload.copy()))
|
||||||
|
return True
|
||||||
|
|
||||||
|
monkeypatch.setattr(processor_module.MetadataManager, "save_metadata", staticmethod(fake_save))
|
||||||
|
|
||||||
|
class StubRequest:
|
||||||
|
def __init__(self, payload: Dict[str, Any]) -> None:
|
||||||
|
self._payload = payload
|
||||||
|
|
||||||
|
async def json(self) -> Dict[str, Any]:
|
||||||
|
return self._payload
|
||||||
|
|
||||||
|
response = await processor_module.ExampleImagesProcessor.delete_custom_image(
|
||||||
|
StubRequest({"model_hash": model_hash, "short_id": "existing-id"})
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
body = json.loads(response.text)
|
||||||
|
assert body["success"] is True
|
||||||
|
assert body["custom_images"] == []
|
||||||
|
assert not (Path(model_folder) / "custom_existing-id.png").exists()
|
||||||
|
|
||||||
|
saved_path, saved_payload = saved[-1]
|
||||||
|
assert saved_path == str(model_file)
|
||||||
|
assert saved_payload["civitai"]["images"] == existing_metadata["civitai"]["images"]
|
||||||
|
assert saved_payload["civitai"]["trainedWords"] == ["foo"]
|
||||||
|
assert saved_payload["civitai"]["customImages"] == []
|
||||||
|
|
||||||
|
assert scanner.updated
|
||||||
|
_, _, updated_metadata = scanner.updated[-1]
|
||||||
|
assert updated_metadata["civitai"]["images"] == existing_metadata["civitai"]["images"]
|
||||||
|
assert updated_metadata["civitai"]["customImages"] == []
|
||||||
|
|||||||
Reference in New Issue
Block a user