From 3c0feb23ba0b4fe8441aa8fced22c97abf1a8108 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Thu, 16 Oct 2025 07:01:04 +0800 Subject: [PATCH] feat(model-cache): respect model name display preference --- py/services/model_cache.py | 44 +++++++++++++++++++- py/services/model_scanner.py | 37 ++++++++++++++--- py/services/settings_manager.py | 53 +++++++++++++++++++++++++ tests/services/test_model_cache.py | 41 +++++++++++++++++++ tests/services/test_settings_manager.py | 33 +++++++++++++++ 5 files changed, 201 insertions(+), 7 deletions(-) create mode 100644 tests/services/test_model_cache.py diff --git a/py/services/model_cache.py b/py/services/model_cache.py index e3c94cee..b5ecc47f 100644 --- a/py/services/model_cache.py +++ b/py/services/model_cache.py @@ -15,6 +15,9 @@ SUPPORTED_SORT_MODES = [ ('size', 'desc'), ] +DISPLAY_NAME_MODES = {"model_name", "file_name"} + + @dataclass class ModelCache: """Cache structure for model data with extensible sorting.""" @@ -22,16 +25,39 @@ class ModelCache: raw_data: List[Dict] folders: List[str] version_index: Dict[int, Dict] = field(default_factory=dict) + name_display_mode: str = "model_name" def __post_init__(self): self._lock = asyncio.Lock() # Cache for last sort: (sort_key, order) -> sorted list self._last_sort: Tuple[str, str] = (None, None) self._last_sorted_data: List[Dict] = [] + self.name_display_mode = self._normalize_display_mode(self.name_display_mode) # Default sort on init asyncio.create_task(self.resort()) self.rebuild_version_index() + @staticmethod + def _normalize_display_mode(value: Optional[str]) -> str: + if isinstance(value, str) and value in DISPLAY_NAME_MODES: + return value + return "model_name" + + def _get_display_name(self, item: Dict) -> str: + """Return the value used for name-based sorting based on display settings.""" + + if self.name_display_mode == "file_name": + primary = item.get("file_name", "") + fallback = item.get("model_name", "") + else: + primary = item.get("model_name", "") + fallback = item.get("file_name", "") + + candidate = primary or fallback or "" + if isinstance(candidate, str): + return candidate + return str(candidate) + @staticmethod def _normalize_version_id(value: Any) -> Optional[int]: """Normalize a potential version identifier into an integer.""" @@ -101,10 +127,10 @@ class ModelCache: """Sort data by sort_key and order""" reverse = (order == 'desc') if sort_key == 'name': - # Natural sort by model_name, case-insensitive + # Natural sort by configured display name, case-insensitive return natsorted( data, - key=lambda x: x['model_name'].lower(), + key=lambda x: self._get_display_name(x).lower(), reverse=reverse ) elif sort_key == 'date': @@ -135,6 +161,20 @@ class ModelCache: self._last_sorted_data = sorted_data return sorted_data + async def update_name_display_mode(self, display_mode: str) -> None: + """Update the display mode used for name sorting and refresh cached results.""" + + normalized = self._normalize_display_mode(display_mode) + async with self._lock: + if self.name_display_mode == normalized: + return + + self.name_display_mode = normalized + + if self._last_sort[0] == 'name': + sort_key, order = self._last_sort + self._last_sorted_data = self._sort_data(self.raw_data, sort_key, order) + async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool: """Update preview_url for a specific model in all cached data diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index eedd953b..5c7389a9 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -18,6 +18,7 @@ from .model_lifecycle_service import delete_model_artifacts from .service_registry import ServiceRegistry from .websocket_manager import ws_manager from .persistent_model_cache import get_persistent_cache +from .settings_manager import get_settings_manager logger = logging.getLogger(__name__) @@ -81,6 +82,7 @@ class ModelScanner: self._is_initializing = False # Flag to track initialization state self._excluded_models = [] # List to track excluded models self._persistent_cache = get_persistent_cache() + self._name_display_mode = self._resolve_name_display_mode() self._initialized = True # Register this service @@ -94,6 +96,7 @@ class ModelScanner: self._tags_count = {} self._excluded_models = [] self._is_initializing = False + self._name_display_mode = self._resolve_name_display_mode() try: loop = asyncio.get_running_loop() @@ -102,7 +105,27 @@ class ModelScanner: if loop and not loop.is_closed(): loop.create_task(self.initialize_in_background()) - + + def _resolve_name_display_mode(self) -> str: + """Return the configured display mode for name sorting.""" + + try: + manager = get_settings_manager() + except Exception: # pragma: no cover - fallback to defaults + return "model_name" + + value = manager.get("model_name_display", "model_name") + return ModelCache._normalize_display_mode(value) + + async def on_model_name_display_changed(self, display_mode: str) -> None: + """Handle updates to the model name display preference.""" + + normalized = ModelCache._normalize_display_mode(display_mode) + self._name_display_mode = normalized + + if self._cache is not None: + await self._cache.update_name_display_mode(normalized) + async def _register_service(self): """Register this instance with the ServiceRegistry""" service_name = f"{self.model_type}_scanner" @@ -211,7 +234,8 @@ class ModelScanner: if self._cache is None: self._cache = ModelCache( raw_data=[], - folders=[] + folders=[], + name_display_mode=self._name_display_mode, ) # Set initializing flag to true @@ -516,7 +540,8 @@ class ModelScanner: if self._cache is None and not force_refresh: return ModelCache( raw_data=[], - folders=[] + folders=[], + name_display_mode=self._name_display_mode, ) # If force refresh is requested, initialize the cache directly @@ -549,7 +574,8 @@ class ModelScanner: if self._cache is None: self._cache = ModelCache( raw_data=[], - folders=[] + folders=[], + name_display_mode=self._name_display_mode, ) finally: self._is_initializing = False # Unset flag @@ -837,7 +863,8 @@ class ModelScanner: if self._cache is None: self._cache = ModelCache( raw_data=list(scan_result.raw_data), - folders=[] + folders=[], + name_display_mode=self._name_display_mode, ) else: self._cache.raw_data = list(scan_result.raw_data) diff --git a/py/services/settings_manager.py b/py/services/settings_manager.py index 13d6ea0a..3e0475f0 100644 --- a/py/services/settings_manager.py +++ b/py/services/settings_manager.py @@ -1,3 +1,4 @@ +import asyncio import copy import json import os @@ -465,6 +466,8 @@ class SettingsManager: self._update_active_library_entry(default_checkpoint_root=str(value)) elif key == 'default_embedding_root': self._update_active_library_entry(default_embedding_root=str(value)) + elif key == 'model_name_display': + self._notify_model_name_display_change(value) self._save_settings() def delete(self, key: str) -> None: @@ -474,6 +477,56 @@ class SettingsManager: self._save_settings() logger.info(f"Deleted setting: {key}") + def _notify_model_name_display_change(self, value: Any) -> None: + """Trigger cache resorting when the model name display preference updates.""" + + try: + from .service_registry import ServiceRegistry # type: ignore + except Exception: # pragma: no cover - registry optional in some contexts + return + + display_mode = value if isinstance(value, str) else "model_name" + coroutines = [] + + for service_name in ( + "lora_scanner", + "checkpoint_scanner", + "embedding_scanner", + "recipe_scanner", + ): + service = ServiceRegistry.get_service_sync(service_name) + if not service or not hasattr(service, "on_model_name_display_changed"): + continue + + try: + result = service.on_model_name_display_changed(display_mode) + except Exception as exc: # pragma: no cover - defensive guard + logger.debug( + "Service %s failed to schedule name display update: %s", + service_name, + exc, + ) + continue + + if asyncio.iscoroutine(result): + coroutines.append(result) + + if not coroutines: + return + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + for coroutine in coroutines: + try: + asyncio.run(coroutine) + except RuntimeError: + # If event loop is already running in another thread, skip execution + logger.debug("Skipping name display update due to running loop") + else: + for coroutine in coroutines: + loop.create_task(coroutine) + def _save_settings(self) -> None: """Save settings to file""" try: diff --git a/tests/services/test_model_cache.py b/tests/services/test_model_cache.py new file mode 100644 index 00000000..cc06ca21 --- /dev/null +++ b/tests/services/test_model_cache.py @@ -0,0 +1,41 @@ +import pytest + +from py.services.model_cache import ModelCache + + +@pytest.mark.asyncio +async def test_name_sort_respects_file_name_display(): + items = [ + {"model_name": "Bravo", "file_name": "zulu", "folder": "", "size": 1, "modified": 1}, + {"model_name": "Alpha", "file_name": "alpha", "folder": "", "size": 1, "modified": 1}, + {"model_name": "Charlie", "file_name": "echo", "folder": "", "size": 1, "modified": 1}, + ] + + cache = ModelCache(raw_data=items, folders=[], name_display_mode="file_name") + + sorted_items = await cache.get_sorted_data("name", "asc") + + assert [item["file_name"] for item in sorted_items] == [ + "alpha", + "echo", + "zulu", + ] + + +@pytest.mark.asyncio +async def test_update_name_display_mode_resorts_cached_name_order(): + items = [ + {"model_name": "Zulu", "file_name": "alpha", "folder": "", "size": 1, "modified": 1}, + {"model_name": "Alpha", "file_name": "zulu", "folder": "", "size": 1, "modified": 1}, + ] + + cache = ModelCache(raw_data=items, folders=[], name_display_mode="model_name") + + initial = await cache.get_sorted_data("name", "asc") + assert [item["model_name"] for item in initial] == ["Alpha", "Zulu"] + + await cache.update_name_display_mode("file_name") + + # The cached name sort should refresh immediately based on the new mode + updated = await cache.get_sorted_data("name", "asc") + assert [item["file_name"] for item in updated] == ["alpha", "zulu"] diff --git a/tests/services/test_settings_manager.py b/tests/services/test_settings_manager.py index 3b8ecf83..a6f1c0ed 100644 --- a/tests/services/test_settings_manager.py +++ b/tests/services/test_settings_manager.py @@ -4,6 +4,7 @@ import os import pytest +from py.services import service_registry from py.services.settings_manager import SettingsManager from py.utils import settings_paths @@ -100,6 +101,38 @@ def test_delete_setting(manager): assert manager.get("example") is None +def test_model_name_display_setting_notifies_scanners(tmp_path, monkeypatch): + initial = { + "libraries": {"default": {"folder_paths": {}, "default_lora_root": "", "default_checkpoint_root": "", "default_embedding_root": ""}}, + "active_library": "default", + "model_name_display": "model_name", + } + + manager = _create_manager_with_settings(tmp_path, monkeypatch, initial) + + class DummyScanner: + def __init__(self): + self.calls = [] + + async def on_model_name_display_changed(self, mode: str) -> None: + self.calls.append(mode) + + dummy_scanner = DummyScanner() + + def fake_get_service_sync(cls, name): + return dummy_scanner if name == "lora_scanner" else None + + monkeypatch.setattr( + service_registry.ServiceRegistry, + "get_service_sync", + classmethod(fake_get_service_sync), + ) + + manager.set("model_name_display", "file_name") + + assert dummy_scanner.calls == ["file_name"] + + def test_migrates_legacy_settings_file(tmp_path, monkeypatch): legacy_root = tmp_path / "legacy" legacy_root.mkdir()