feat(model-cache): respect model name display preference

This commit is contained in:
pixelpaws
2025-10-16 07:01:04 +08:00
parent 3627840fe9
commit 3c0feb23ba
5 changed files with 201 additions and 7 deletions

View File

@@ -15,6 +15,9 @@ SUPPORTED_SORT_MODES = [
('size', 'desc'),
]
DISPLAY_NAME_MODES = {"model_name", "file_name"}
@dataclass
class ModelCache:
"""Cache structure for model data with extensible sorting."""
@@ -22,16 +25,39 @@ class ModelCache:
raw_data: List[Dict]
folders: List[str]
version_index: Dict[int, Dict] = field(default_factory=dict)
name_display_mode: str = "model_name"
def __post_init__(self):
self._lock = asyncio.Lock()
# Cache for last sort: (sort_key, order) -> sorted list
self._last_sort: Tuple[str, str] = (None, None)
self._last_sorted_data: List[Dict] = []
self.name_display_mode = self._normalize_display_mode(self.name_display_mode)
# Default sort on init
asyncio.create_task(self.resort())
self.rebuild_version_index()
@staticmethod
def _normalize_display_mode(value: Optional[str]) -> str:
if isinstance(value, str) and value in DISPLAY_NAME_MODES:
return value
return "model_name"
def _get_display_name(self, item: Dict) -> str:
"""Return the value used for name-based sorting based on display settings."""
if self.name_display_mode == "file_name":
primary = item.get("file_name", "")
fallback = item.get("model_name", "")
else:
primary = item.get("model_name", "")
fallback = item.get("file_name", "")
candidate = primary or fallback or ""
if isinstance(candidate, str):
return candidate
return str(candidate)
@staticmethod
def _normalize_version_id(value: Any) -> Optional[int]:
"""Normalize a potential version identifier into an integer."""
@@ -101,10 +127,10 @@ class ModelCache:
"""Sort data by sort_key and order"""
reverse = (order == 'desc')
if sort_key == 'name':
# Natural sort by model_name, case-insensitive
# Natural sort by configured display name, case-insensitive
return natsorted(
data,
key=lambda x: x['model_name'].lower(),
key=lambda x: self._get_display_name(x).lower(),
reverse=reverse
)
elif sort_key == 'date':
@@ -135,6 +161,20 @@ class ModelCache:
self._last_sorted_data = sorted_data
return sorted_data
async def update_name_display_mode(self, display_mode: str) -> None:
"""Update the display mode used for name sorting and refresh cached results."""
normalized = self._normalize_display_mode(display_mode)
async with self._lock:
if self.name_display_mode == normalized:
return
self.name_display_mode = normalized
if self._last_sort[0] == 'name':
sort_key, order = self._last_sort
self._last_sorted_data = self._sort_data(self.raw_data, sort_key, order)
async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool:
"""Update preview_url for a specific model in all cached data

View File

@@ -18,6 +18,7 @@ from .model_lifecycle_service import delete_model_artifacts
from .service_registry import ServiceRegistry
from .websocket_manager import ws_manager
from .persistent_model_cache import get_persistent_cache
from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__)
@@ -81,6 +82,7 @@ class ModelScanner:
self._is_initializing = False # Flag to track initialization state
self._excluded_models = [] # List to track excluded models
self._persistent_cache = get_persistent_cache()
self._name_display_mode = self._resolve_name_display_mode()
self._initialized = True
# Register this service
@@ -94,6 +96,7 @@ class ModelScanner:
self._tags_count = {}
self._excluded_models = []
self._is_initializing = False
self._name_display_mode = self._resolve_name_display_mode()
try:
loop = asyncio.get_running_loop()
@@ -102,7 +105,27 @@ class ModelScanner:
if loop and not loop.is_closed():
loop.create_task(self.initialize_in_background())
def _resolve_name_display_mode(self) -> str:
"""Return the configured display mode for name sorting."""
try:
manager = get_settings_manager()
except Exception: # pragma: no cover - fallback to defaults
return "model_name"
value = manager.get("model_name_display", "model_name")
return ModelCache._normalize_display_mode(value)
async def on_model_name_display_changed(self, display_mode: str) -> None:
"""Handle updates to the model name display preference."""
normalized = ModelCache._normalize_display_mode(display_mode)
self._name_display_mode = normalized
if self._cache is not None:
await self._cache.update_name_display_mode(normalized)
async def _register_service(self):
"""Register this instance with the ServiceRegistry"""
service_name = f"{self.model_type}_scanner"
@@ -211,7 +234,8 @@ class ModelScanner:
if self._cache is None:
self._cache = ModelCache(
raw_data=[],
folders=[]
folders=[],
name_display_mode=self._name_display_mode,
)
# Set initializing flag to true
@@ -516,7 +540,8 @@ class ModelScanner:
if self._cache is None and not force_refresh:
return ModelCache(
raw_data=[],
folders=[]
folders=[],
name_display_mode=self._name_display_mode,
)
# If force refresh is requested, initialize the cache directly
@@ -549,7 +574,8 @@ class ModelScanner:
if self._cache is None:
self._cache = ModelCache(
raw_data=[],
folders=[]
folders=[],
name_display_mode=self._name_display_mode,
)
finally:
self._is_initializing = False # Unset flag
@@ -837,7 +863,8 @@ class ModelScanner:
if self._cache is None:
self._cache = ModelCache(
raw_data=list(scan_result.raw_data),
folders=[]
folders=[],
name_display_mode=self._name_display_mode,
)
else:
self._cache.raw_data = list(scan_result.raw_data)

View File

@@ -1,3 +1,4 @@
import asyncio
import copy
import json
import os
@@ -465,6 +466,8 @@ class SettingsManager:
self._update_active_library_entry(default_checkpoint_root=str(value))
elif key == 'default_embedding_root':
self._update_active_library_entry(default_embedding_root=str(value))
elif key == 'model_name_display':
self._notify_model_name_display_change(value)
self._save_settings()
def delete(self, key: str) -> None:
@@ -474,6 +477,56 @@ class SettingsManager:
self._save_settings()
logger.info(f"Deleted setting: {key}")
def _notify_model_name_display_change(self, value: Any) -> None:
"""Trigger cache resorting when the model name display preference updates."""
try:
from .service_registry import ServiceRegistry # type: ignore
except Exception: # pragma: no cover - registry optional in some contexts
return
display_mode = value if isinstance(value, str) else "model_name"
coroutines = []
for service_name in (
"lora_scanner",
"checkpoint_scanner",
"embedding_scanner",
"recipe_scanner",
):
service = ServiceRegistry.get_service_sync(service_name)
if not service or not hasattr(service, "on_model_name_display_changed"):
continue
try:
result = service.on_model_name_display_changed(display_mode)
except Exception as exc: # pragma: no cover - defensive guard
logger.debug(
"Service %s failed to schedule name display update: %s",
service_name,
exc,
)
continue
if asyncio.iscoroutine(result):
coroutines.append(result)
if not coroutines:
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
for coroutine in coroutines:
try:
asyncio.run(coroutine)
except RuntimeError:
# If event loop is already running in another thread, skip execution
logger.debug("Skipping name display update due to running loop")
else:
for coroutine in coroutines:
loop.create_task(coroutine)
def _save_settings(self) -> None:
"""Save settings to file"""
try:

View File

@@ -0,0 +1,41 @@
import pytest
from py.services.model_cache import ModelCache
@pytest.mark.asyncio
async def test_name_sort_respects_file_name_display():
items = [
{"model_name": "Bravo", "file_name": "zulu", "folder": "", "size": 1, "modified": 1},
{"model_name": "Alpha", "file_name": "alpha", "folder": "", "size": 1, "modified": 1},
{"model_name": "Charlie", "file_name": "echo", "folder": "", "size": 1, "modified": 1},
]
cache = ModelCache(raw_data=items, folders=[], name_display_mode="file_name")
sorted_items = await cache.get_sorted_data("name", "asc")
assert [item["file_name"] for item in sorted_items] == [
"alpha",
"echo",
"zulu",
]
@pytest.mark.asyncio
async def test_update_name_display_mode_resorts_cached_name_order():
items = [
{"model_name": "Zulu", "file_name": "alpha", "folder": "", "size": 1, "modified": 1},
{"model_name": "Alpha", "file_name": "zulu", "folder": "", "size": 1, "modified": 1},
]
cache = ModelCache(raw_data=items, folders=[], name_display_mode="model_name")
initial = await cache.get_sorted_data("name", "asc")
assert [item["model_name"] for item in initial] == ["Alpha", "Zulu"]
await cache.update_name_display_mode("file_name")
# The cached name sort should refresh immediately based on the new mode
updated = await cache.get_sorted_data("name", "asc")
assert [item["file_name"] for item in updated] == ["alpha", "zulu"]

View File

@@ -4,6 +4,7 @@ import os
import pytest
from py.services import service_registry
from py.services.settings_manager import SettingsManager
from py.utils import settings_paths
@@ -100,6 +101,38 @@ def test_delete_setting(manager):
assert manager.get("example") is None
def test_model_name_display_setting_notifies_scanners(tmp_path, monkeypatch):
initial = {
"libraries": {"default": {"folder_paths": {}, "default_lora_root": "", "default_checkpoint_root": "", "default_embedding_root": ""}},
"active_library": "default",
"model_name_display": "model_name",
}
manager = _create_manager_with_settings(tmp_path, monkeypatch, initial)
class DummyScanner:
def __init__(self):
self.calls = []
async def on_model_name_display_changed(self, mode: str) -> None:
self.calls.append(mode)
dummy_scanner = DummyScanner()
def fake_get_service_sync(cls, name):
return dummy_scanner if name == "lora_scanner" else None
monkeypatch.setattr(
service_registry.ServiceRegistry,
"get_service_sync",
classmethod(fake_get_service_sync),
)
manager.set("model_name_display", "file_name")
assert dummy_scanner.calls == ["file_name"]
def test_migrates_legacy_settings_file(tmp_path, monkeypatch):
legacy_root = tmp_path / "legacy"
legacy_root.mkdir()