Merge pull request #574 from willmiao/codex/add-model-name-display-setting

feat: respect model name display preference in model cache
This commit is contained in:
pixelpaws
2025-10-16 09:21:19 +08:00
committed by GitHub
5 changed files with 262 additions and 8 deletions

View File

@@ -15,6 +15,9 @@ SUPPORTED_SORT_MODES = [
('size', 'desc'),
]
DISPLAY_NAME_MODES = {"model_name", "file_name"}
@dataclass
class ModelCache:
"""Cache structure for model data with extensible sorting."""
@@ -22,16 +25,39 @@ class ModelCache:
raw_data: List[Dict]
folders: List[str]
version_index: Dict[int, Dict] = field(default_factory=dict)
name_display_mode: str = "model_name"
def __post_init__(self):
self._lock = asyncio.Lock()
# Cache for last sort: (sort_key, order) -> sorted list
self._last_sort: Tuple[str, str] = (None, None)
self._last_sorted_data: List[Dict] = []
self.name_display_mode = self._normalize_display_mode(self.name_display_mode)
# Default sort on init
asyncio.create_task(self.resort())
self.rebuild_version_index()
@staticmethod
def _normalize_display_mode(value: Optional[str]) -> str:
if isinstance(value, str) and value in DISPLAY_NAME_MODES:
return value
return "model_name"
def _get_display_name(self, item: Dict) -> str:
"""Return the value used for name-based sorting based on display settings."""
if self.name_display_mode == "file_name":
primary = item.get("file_name", "")
fallback = item.get("model_name", "")
else:
primary = item.get("model_name", "")
fallback = item.get("file_name", "")
candidate = primary or fallback or ""
if isinstance(candidate, str):
return candidate
return str(candidate)
@staticmethod
def _normalize_version_id(value: Any) -> Optional[int]:
"""Normalize a potential version identifier into an integer."""
@@ -101,10 +127,10 @@ class ModelCache:
"""Sort data by sort_key and order"""
reverse = (order == 'desc')
if sort_key == 'name':
# Natural sort by model_name, case-insensitive
# Natural sort by configured display name, case-insensitive
return natsorted(
data,
key=lambda x: x['model_name'].lower(),
key=lambda x: self._get_display_name(x).lower(),
reverse=reverse
)
elif sort_key == 'date':
@@ -135,6 +161,20 @@ class ModelCache:
self._last_sorted_data = sorted_data
return sorted_data
async def update_name_display_mode(self, display_mode: str) -> None:
"""Update the display mode used for name sorting and refresh cached results."""
normalized = self._normalize_display_mode(display_mode)
async with self._lock:
if self.name_display_mode == normalized:
return
self.name_display_mode = normalized
if self._last_sort[0] == 'name':
sort_key, order = self._last_sort
self._last_sorted_data = self._sort_data(self.raw_data, sort_key, order)
async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool:
"""Update preview_url for a specific model in all cached data

View File

@@ -18,6 +18,7 @@ from .model_lifecycle_service import delete_model_artifacts
from .service_registry import ServiceRegistry
from .websocket_manager import ws_manager
from .persistent_model_cache import get_persistent_cache
from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__)
@@ -81,6 +82,13 @@ class ModelScanner:
self._is_initializing = False # Flag to track initialization state
self._excluded_models = [] # List to track excluded models
self._persistent_cache = get_persistent_cache()
self._name_display_mode = self._resolve_name_display_mode()
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
self._loop = loop
self.loop = loop
self._initialized = True
# Register this service
@@ -94,6 +102,7 @@ class ModelScanner:
self._tags_count = {}
self._excluded_models = []
self._is_initializing = False
self._name_display_mode = self._resolve_name_display_mode()
try:
loop = asyncio.get_running_loop()
@@ -101,8 +110,30 @@ class ModelScanner:
loop = None
if loop and not loop.is_closed():
self._loop = loop
self.loop = loop
loop.create_task(self.initialize_in_background())
def _resolve_name_display_mode(self) -> str:
"""Return the configured display mode for name sorting."""
try:
manager = get_settings_manager()
except Exception: # pragma: no cover - fallback to defaults
return "model_name"
value = manager.get("model_name_display", "model_name")
return ModelCache._normalize_display_mode(value)
async def on_model_name_display_changed(self, display_mode: str) -> None:
"""Handle updates to the model name display preference."""
normalized = ModelCache._normalize_display_mode(display_mode)
self._name_display_mode = normalized
if self._cache is not None:
await self._cache.update_name_display_mode(normalized)
async def _register_service(self):
"""Register this instance with the ServiceRegistry"""
service_name = f"{self.model_type}_scanner"
@@ -211,7 +242,8 @@ class ModelScanner:
if self._cache is None:
self._cache = ModelCache(
raw_data=[],
folders=[]
folders=[],
name_display_mode=self._name_display_mode,
)
# Set initializing flag to true
@@ -516,7 +548,8 @@ class ModelScanner:
if self._cache is None and not force_refresh:
return ModelCache(
raw_data=[],
folders=[]
folders=[],
name_display_mode=self._name_display_mode,
)
# If force refresh is requested, initialize the cache directly
@@ -549,7 +582,8 @@ class ModelScanner:
if self._cache is None:
self._cache = ModelCache(
raw_data=[],
folders=[]
folders=[],
name_display_mode=self._name_display_mode,
)
finally:
self._is_initializing = False # Unset flag
@@ -837,7 +871,8 @@ class ModelScanner:
if self._cache is None:
self._cache = ModelCache(
raw_data=list(scan_result.raw_data),
folders=[]
folders=[],
name_display_mode=self._name_display_mode,
)
else:
self._cache.raw_data = list(scan_result.raw_data)

View File

@@ -1,10 +1,11 @@
import asyncio
import copy
import json
import os
import logging
from datetime import datetime, timezone
from threading import Lock
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence
from typing import Any, Awaitable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
from ..utils.constants import DEFAULT_PRIORITY_TAG_CONFIG
from ..utils.settings_paths import ensure_settings_file
@@ -465,6 +466,8 @@ class SettingsManager:
self._update_active_library_entry(default_checkpoint_root=str(value))
elif key == 'default_embedding_root':
self._update_active_library_entry(default_embedding_root=str(value))
elif key == 'model_name_display':
self._notify_model_name_display_change(value)
self._save_settings()
def delete(self, key: str) -> None:
@@ -474,6 +477,81 @@ class SettingsManager:
self._save_settings()
logger.info(f"Deleted setting: {key}")
def _notify_model_name_display_change(self, value: Any) -> None:
"""Trigger cache resorting when the model name display preference updates."""
try:
from .service_registry import ServiceRegistry # type: ignore
except Exception: # pragma: no cover - registry optional in some contexts
return
display_mode = value if isinstance(value, str) else "model_name"
pending: List[Tuple[Optional[asyncio.AbstractEventLoop], Awaitable[Any]]] = []
def _resolve_service_loop(service: Any) -> Optional[asyncio.AbstractEventLoop]:
loop = getattr(service, "loop", None)
if loop is None:
loop = getattr(service, "_loop", None)
return loop if isinstance(loop, asyncio.AbstractEventLoop) else None
for service_name in (
"lora_scanner",
"checkpoint_scanner",
"embedding_scanner",
"recipe_scanner",
):
service = ServiceRegistry.get_service_sync(service_name)
if not service or not hasattr(service, "on_model_name_display_changed"):
continue
try:
result = service.on_model_name_display_changed(display_mode)
except Exception as exc: # pragma: no cover - defensive guard
logger.debug(
"Service %s failed to schedule name display update: %s",
service_name,
exc,
)
continue
if asyncio.iscoroutine(result):
service_loop = _resolve_service_loop(service)
pending.append((service_loop, result))
if not pending:
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
for service_loop, coroutine in pending:
target_loop = service_loop or loop
if target_loop is None:
try:
asyncio.run(coroutine)
except RuntimeError:
logger.debug("Skipping name display update due to missing event loop")
continue
if loop is not None and target_loop is loop:
target_loop.create_task(coroutine)
continue
if target_loop.is_running():
try:
asyncio.run_coroutine_threadsafe(coroutine, target_loop)
except Exception as exc: # pragma: no cover - defensive guard
logger.debug("Failed to dispatch name display update: %s", exc)
continue
try:
asyncio.run(coroutine)
except RuntimeError:
logger.debug("Skipping name display update due to closed loop")
def _save_settings(self) -> None:
"""Save settings to file"""
try:

View File

@@ -0,0 +1,41 @@
import pytest
from py.services.model_cache import ModelCache
@pytest.mark.asyncio
async def test_name_sort_respects_file_name_display():
items = [
{"model_name": "Bravo", "file_name": "zulu", "folder": "", "size": 1, "modified": 1},
{"model_name": "Alpha", "file_name": "alpha", "folder": "", "size": 1, "modified": 1},
{"model_name": "Charlie", "file_name": "echo", "folder": "", "size": 1, "modified": 1},
]
cache = ModelCache(raw_data=items, folders=[], name_display_mode="file_name")
sorted_items = await cache.get_sorted_data("name", "asc")
assert [item["file_name"] for item in sorted_items] == [
"alpha",
"echo",
"zulu",
]
@pytest.mark.asyncio
async def test_update_name_display_mode_resorts_cached_name_order():
items = [
{"model_name": "Zulu", "file_name": "alpha", "folder": "", "size": 1, "modified": 1},
{"model_name": "Alpha", "file_name": "zulu", "folder": "", "size": 1, "modified": 1},
]
cache = ModelCache(raw_data=items, folders=[], name_display_mode="model_name")
initial = await cache.get_sorted_data("name", "asc")
assert [item["model_name"] for item in initial] == ["Alpha", "Zulu"]
await cache.update_name_display_mode("file_name")
# The cached name sort should refresh immediately based on the new mode
updated = await cache.get_sorted_data("name", "asc")
assert [item["file_name"] for item in updated] == ["alpha", "zulu"]

View File

@@ -1,9 +1,12 @@
import asyncio
import copy
import threading
import json
import os
import pytest
from py.services import service_registry
from py.services.settings_manager import SettingsManager
from py.utils import settings_paths
@@ -100,6 +103,63 @@ def test_delete_setting(manager):
assert manager.get("example") is None
def test_model_name_display_setting_notifies_scanners(tmp_path, monkeypatch):
initial = {
"libraries": {"default": {"folder_paths": {}, "default_lora_root": "", "default_checkpoint_root": "", "default_embedding_root": ""}},
"active_library": "default",
"model_name_display": "model_name",
}
manager = _create_manager_with_settings(tmp_path, monkeypatch, initial)
loop = asyncio.new_event_loop()
thread = threading.Thread(target=loop.run_forever, daemon=True)
thread.start()
class DummyScanner:
def __init__(self):
self.calls = []
self.loop = loop
async def on_model_name_display_changed(self, mode: str) -> None:
self.calls.append(mode)
dummy_scanner = DummyScanner()
dispatched_loops = []
futures = []
original_run_coroutine_threadsafe = asyncio.run_coroutine_threadsafe
def tracking_run_coroutine_threadsafe(coro, target_loop):
dispatched_loops.append(target_loop)
future = original_run_coroutine_threadsafe(coro, target_loop)
futures.append(future)
return future
def fake_get_service_sync(cls, name):
return dummy_scanner if name == "lora_scanner" else None
monkeypatch.setattr(
service_registry.ServiceRegistry,
"get_service_sync",
classmethod(fake_get_service_sync),
)
monkeypatch.setattr(asyncio, "run_coroutine_threadsafe", tracking_run_coroutine_threadsafe)
try:
manager.set("model_name_display", "file_name")
for future in futures:
future.result(timeout=1)
assert dummy_scanner.calls == ["file_name"]
assert dispatched_loops == [dummy_scanner.loop]
finally:
loop.call_soon_threadsafe(loop.stop)
thread.join(timeout=1)
loop.close()
def test_migrates_legacy_settings_file(tmp_path, monkeypatch):
legacy_root = tmp_path / "legacy"
legacy_root.mkdir()