mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
Merge pull request #574 from willmiao/codex/add-model-name-display-setting
feat: respect model name display preference in model cache
This commit is contained in:
@@ -15,6 +15,9 @@ SUPPORTED_SORT_MODES = [
|
|||||||
('size', 'desc'),
|
('size', 'desc'),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
DISPLAY_NAME_MODES = {"model_name", "file_name"}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelCache:
|
class ModelCache:
|
||||||
"""Cache structure for model data with extensible sorting."""
|
"""Cache structure for model data with extensible sorting."""
|
||||||
@@ -22,16 +25,39 @@ class ModelCache:
|
|||||||
raw_data: List[Dict]
|
raw_data: List[Dict]
|
||||||
folders: List[str]
|
folders: List[str]
|
||||||
version_index: Dict[int, Dict] = field(default_factory=dict)
|
version_index: Dict[int, Dict] = field(default_factory=dict)
|
||||||
|
name_display_mode: str = "model_name"
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
# Cache for last sort: (sort_key, order) -> sorted list
|
# Cache for last sort: (sort_key, order) -> sorted list
|
||||||
self._last_sort: Tuple[str, str] = (None, None)
|
self._last_sort: Tuple[str, str] = (None, None)
|
||||||
self._last_sorted_data: List[Dict] = []
|
self._last_sorted_data: List[Dict] = []
|
||||||
|
self.name_display_mode = self._normalize_display_mode(self.name_display_mode)
|
||||||
# Default sort on init
|
# Default sort on init
|
||||||
asyncio.create_task(self.resort())
|
asyncio.create_task(self.resort())
|
||||||
self.rebuild_version_index()
|
self.rebuild_version_index()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_display_mode(value: Optional[str]) -> str:
|
||||||
|
if isinstance(value, str) and value in DISPLAY_NAME_MODES:
|
||||||
|
return value
|
||||||
|
return "model_name"
|
||||||
|
|
||||||
|
def _get_display_name(self, item: Dict) -> str:
|
||||||
|
"""Return the value used for name-based sorting based on display settings."""
|
||||||
|
|
||||||
|
if self.name_display_mode == "file_name":
|
||||||
|
primary = item.get("file_name", "")
|
||||||
|
fallback = item.get("model_name", "")
|
||||||
|
else:
|
||||||
|
primary = item.get("model_name", "")
|
||||||
|
fallback = item.get("file_name", "")
|
||||||
|
|
||||||
|
candidate = primary or fallback or ""
|
||||||
|
if isinstance(candidate, str):
|
||||||
|
return candidate
|
||||||
|
return str(candidate)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _normalize_version_id(value: Any) -> Optional[int]:
|
def _normalize_version_id(value: Any) -> Optional[int]:
|
||||||
"""Normalize a potential version identifier into an integer."""
|
"""Normalize a potential version identifier into an integer."""
|
||||||
@@ -101,10 +127,10 @@ class ModelCache:
|
|||||||
"""Sort data by sort_key and order"""
|
"""Sort data by sort_key and order"""
|
||||||
reverse = (order == 'desc')
|
reverse = (order == 'desc')
|
||||||
if sort_key == 'name':
|
if sort_key == 'name':
|
||||||
# Natural sort by model_name, case-insensitive
|
# Natural sort by configured display name, case-insensitive
|
||||||
return natsorted(
|
return natsorted(
|
||||||
data,
|
data,
|
||||||
key=lambda x: x['model_name'].lower(),
|
key=lambda x: self._get_display_name(x).lower(),
|
||||||
reverse=reverse
|
reverse=reverse
|
||||||
)
|
)
|
||||||
elif sort_key == 'date':
|
elif sort_key == 'date':
|
||||||
@@ -135,6 +161,20 @@ class ModelCache:
|
|||||||
self._last_sorted_data = sorted_data
|
self._last_sorted_data = sorted_data
|
||||||
return sorted_data
|
return sorted_data
|
||||||
|
|
||||||
|
async def update_name_display_mode(self, display_mode: str) -> None:
|
||||||
|
"""Update the display mode used for name sorting and refresh cached results."""
|
||||||
|
|
||||||
|
normalized = self._normalize_display_mode(display_mode)
|
||||||
|
async with self._lock:
|
||||||
|
if self.name_display_mode == normalized:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.name_display_mode = normalized
|
||||||
|
|
||||||
|
if self._last_sort[0] == 'name':
|
||||||
|
sort_key, order = self._last_sort
|
||||||
|
self._last_sorted_data = self._sort_data(self.raw_data, sort_key, order)
|
||||||
|
|
||||||
async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool:
|
async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool:
|
||||||
"""Update preview_url for a specific model in all cached data
|
"""Update preview_url for a specific model in all cached data
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from .model_lifecycle_service import delete_model_artifacts
|
|||||||
from .service_registry import ServiceRegistry
|
from .service_registry import ServiceRegistry
|
||||||
from .websocket_manager import ws_manager
|
from .websocket_manager import ws_manager
|
||||||
from .persistent_model_cache import get_persistent_cache
|
from .persistent_model_cache import get_persistent_cache
|
||||||
|
from .settings_manager import get_settings_manager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -81,6 +82,13 @@ class ModelScanner:
|
|||||||
self._is_initializing = False # Flag to track initialization state
|
self._is_initializing = False # Flag to track initialization state
|
||||||
self._excluded_models = [] # List to track excluded models
|
self._excluded_models = [] # List to track excluded models
|
||||||
self._persistent_cache = get_persistent_cache()
|
self._persistent_cache = get_persistent_cache()
|
||||||
|
self._name_display_mode = self._resolve_name_display_mode()
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = None
|
||||||
|
self._loop = loop
|
||||||
|
self.loop = loop
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
# Register this service
|
# Register this service
|
||||||
@@ -94,6 +102,7 @@ class ModelScanner:
|
|||||||
self._tags_count = {}
|
self._tags_count = {}
|
||||||
self._excluded_models = []
|
self._excluded_models = []
|
||||||
self._is_initializing = False
|
self._is_initializing = False
|
||||||
|
self._name_display_mode = self._resolve_name_display_mode()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
@@ -101,8 +110,30 @@ class ModelScanner:
|
|||||||
loop = None
|
loop = None
|
||||||
|
|
||||||
if loop and not loop.is_closed():
|
if loop and not loop.is_closed():
|
||||||
|
self._loop = loop
|
||||||
|
self.loop = loop
|
||||||
loop.create_task(self.initialize_in_background())
|
loop.create_task(self.initialize_in_background())
|
||||||
|
|
||||||
|
def _resolve_name_display_mode(self) -> str:
|
||||||
|
"""Return the configured display mode for name sorting."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
manager = get_settings_manager()
|
||||||
|
except Exception: # pragma: no cover - fallback to defaults
|
||||||
|
return "model_name"
|
||||||
|
|
||||||
|
value = manager.get("model_name_display", "model_name")
|
||||||
|
return ModelCache._normalize_display_mode(value)
|
||||||
|
|
||||||
|
async def on_model_name_display_changed(self, display_mode: str) -> None:
|
||||||
|
"""Handle updates to the model name display preference."""
|
||||||
|
|
||||||
|
normalized = ModelCache._normalize_display_mode(display_mode)
|
||||||
|
self._name_display_mode = normalized
|
||||||
|
|
||||||
|
if self._cache is not None:
|
||||||
|
await self._cache.update_name_display_mode(normalized)
|
||||||
|
|
||||||
async def _register_service(self):
|
async def _register_service(self):
|
||||||
"""Register this instance with the ServiceRegistry"""
|
"""Register this instance with the ServiceRegistry"""
|
||||||
service_name = f"{self.model_type}_scanner"
|
service_name = f"{self.model_type}_scanner"
|
||||||
@@ -211,7 +242,8 @@ class ModelScanner:
|
|||||||
if self._cache is None:
|
if self._cache is None:
|
||||||
self._cache = ModelCache(
|
self._cache = ModelCache(
|
||||||
raw_data=[],
|
raw_data=[],
|
||||||
folders=[]
|
folders=[],
|
||||||
|
name_display_mode=self._name_display_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set initializing flag to true
|
# Set initializing flag to true
|
||||||
@@ -516,7 +548,8 @@ class ModelScanner:
|
|||||||
if self._cache is None and not force_refresh:
|
if self._cache is None and not force_refresh:
|
||||||
return ModelCache(
|
return ModelCache(
|
||||||
raw_data=[],
|
raw_data=[],
|
||||||
folders=[]
|
folders=[],
|
||||||
|
name_display_mode=self._name_display_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If force refresh is requested, initialize the cache directly
|
# If force refresh is requested, initialize the cache directly
|
||||||
@@ -549,7 +582,8 @@ class ModelScanner:
|
|||||||
if self._cache is None:
|
if self._cache is None:
|
||||||
self._cache = ModelCache(
|
self._cache = ModelCache(
|
||||||
raw_data=[],
|
raw_data=[],
|
||||||
folders=[]
|
folders=[],
|
||||||
|
name_display_mode=self._name_display_mode,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
self._is_initializing = False # Unset flag
|
self._is_initializing = False # Unset flag
|
||||||
@@ -837,7 +871,8 @@ class ModelScanner:
|
|||||||
if self._cache is None:
|
if self._cache is None:
|
||||||
self._cache = ModelCache(
|
self._cache = ModelCache(
|
||||||
raw_data=list(scan_result.raw_data),
|
raw_data=list(scan_result.raw_data),
|
||||||
folders=[]
|
folders=[],
|
||||||
|
name_display_mode=self._name_display_mode,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._cache.raw_data = list(scan_result.raw_data)
|
self._cache.raw_data = list(scan_result.raw_data)
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence
|
from typing import Any, Awaitable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
||||||
|
|
||||||
from ..utils.constants import DEFAULT_PRIORITY_TAG_CONFIG
|
from ..utils.constants import DEFAULT_PRIORITY_TAG_CONFIG
|
||||||
from ..utils.settings_paths import ensure_settings_file
|
from ..utils.settings_paths import ensure_settings_file
|
||||||
@@ -465,6 +466,8 @@ class SettingsManager:
|
|||||||
self._update_active_library_entry(default_checkpoint_root=str(value))
|
self._update_active_library_entry(default_checkpoint_root=str(value))
|
||||||
elif key == 'default_embedding_root':
|
elif key == 'default_embedding_root':
|
||||||
self._update_active_library_entry(default_embedding_root=str(value))
|
self._update_active_library_entry(default_embedding_root=str(value))
|
||||||
|
elif key == 'model_name_display':
|
||||||
|
self._notify_model_name_display_change(value)
|
||||||
self._save_settings()
|
self._save_settings()
|
||||||
|
|
||||||
def delete(self, key: str) -> None:
|
def delete(self, key: str) -> None:
|
||||||
@@ -474,6 +477,81 @@ class SettingsManager:
|
|||||||
self._save_settings()
|
self._save_settings()
|
||||||
logger.info(f"Deleted setting: {key}")
|
logger.info(f"Deleted setting: {key}")
|
||||||
|
|
||||||
|
def _notify_model_name_display_change(self, value: Any) -> None:
|
||||||
|
"""Trigger cache resorting when the model name display preference updates."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .service_registry import ServiceRegistry # type: ignore
|
||||||
|
except Exception: # pragma: no cover - registry optional in some contexts
|
||||||
|
return
|
||||||
|
|
||||||
|
display_mode = value if isinstance(value, str) else "model_name"
|
||||||
|
pending: List[Tuple[Optional[asyncio.AbstractEventLoop], Awaitable[Any]]] = []
|
||||||
|
|
||||||
|
def _resolve_service_loop(service: Any) -> Optional[asyncio.AbstractEventLoop]:
|
||||||
|
loop = getattr(service, "loop", None)
|
||||||
|
if loop is None:
|
||||||
|
loop = getattr(service, "_loop", None)
|
||||||
|
return loop if isinstance(loop, asyncio.AbstractEventLoop) else None
|
||||||
|
|
||||||
|
for service_name in (
|
||||||
|
"lora_scanner",
|
||||||
|
"checkpoint_scanner",
|
||||||
|
"embedding_scanner",
|
||||||
|
"recipe_scanner",
|
||||||
|
):
|
||||||
|
service = ServiceRegistry.get_service_sync(service_name)
|
||||||
|
if not service or not hasattr(service, "on_model_name_display_changed"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = service.on_model_name_display_changed(display_mode)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive guard
|
||||||
|
logger.debug(
|
||||||
|
"Service %s failed to schedule name display update: %s",
|
||||||
|
service_name,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if asyncio.iscoroutine(result):
|
||||||
|
service_loop = _resolve_service_loop(service)
|
||||||
|
pending.append((service_loop, result))
|
||||||
|
|
||||||
|
if not pending:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = None
|
||||||
|
|
||||||
|
for service_loop, coroutine in pending:
|
||||||
|
target_loop = service_loop or loop
|
||||||
|
|
||||||
|
if target_loop is None:
|
||||||
|
try:
|
||||||
|
asyncio.run(coroutine)
|
||||||
|
except RuntimeError:
|
||||||
|
logger.debug("Skipping name display update due to missing event loop")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if loop is not None and target_loop is loop:
|
||||||
|
target_loop.create_task(coroutine)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if target_loop.is_running():
|
||||||
|
try:
|
||||||
|
asyncio.run_coroutine_threadsafe(coroutine, target_loop)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive guard
|
||||||
|
logger.debug("Failed to dispatch name display update: %s", exc)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncio.run(coroutine)
|
||||||
|
except RuntimeError:
|
||||||
|
logger.debug("Skipping name display update due to closed loop")
|
||||||
|
|
||||||
def _save_settings(self) -> None:
|
def _save_settings(self) -> None:
|
||||||
"""Save settings to file"""
|
"""Save settings to file"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
41
tests/services/test_model_cache.py
Normal file
41
tests/services/test_model_cache.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services.model_cache import ModelCache
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_name_sort_respects_file_name_display():
|
||||||
|
items = [
|
||||||
|
{"model_name": "Bravo", "file_name": "zulu", "folder": "", "size": 1, "modified": 1},
|
||||||
|
{"model_name": "Alpha", "file_name": "alpha", "folder": "", "size": 1, "modified": 1},
|
||||||
|
{"model_name": "Charlie", "file_name": "echo", "folder": "", "size": 1, "modified": 1},
|
||||||
|
]
|
||||||
|
|
||||||
|
cache = ModelCache(raw_data=items, folders=[], name_display_mode="file_name")
|
||||||
|
|
||||||
|
sorted_items = await cache.get_sorted_data("name", "asc")
|
||||||
|
|
||||||
|
assert [item["file_name"] for item in sorted_items] == [
|
||||||
|
"alpha",
|
||||||
|
"echo",
|
||||||
|
"zulu",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_name_display_mode_resorts_cached_name_order():
|
||||||
|
items = [
|
||||||
|
{"model_name": "Zulu", "file_name": "alpha", "folder": "", "size": 1, "modified": 1},
|
||||||
|
{"model_name": "Alpha", "file_name": "zulu", "folder": "", "size": 1, "modified": 1},
|
||||||
|
]
|
||||||
|
|
||||||
|
cache = ModelCache(raw_data=items, folders=[], name_display_mode="model_name")
|
||||||
|
|
||||||
|
initial = await cache.get_sorted_data("name", "asc")
|
||||||
|
assert [item["model_name"] for item in initial] == ["Alpha", "Zulu"]
|
||||||
|
|
||||||
|
await cache.update_name_display_mode("file_name")
|
||||||
|
|
||||||
|
# The cached name sort should refresh immediately based on the new mode
|
||||||
|
updated = await cache.get_sorted_data("name", "asc")
|
||||||
|
assert [item["file_name"] for item in updated] == ["alpha", "zulu"]
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
|
import threading
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from py.services import service_registry
|
||||||
from py.services.settings_manager import SettingsManager
|
from py.services.settings_manager import SettingsManager
|
||||||
from py.utils import settings_paths
|
from py.utils import settings_paths
|
||||||
|
|
||||||
@@ -100,6 +103,63 @@ def test_delete_setting(manager):
|
|||||||
assert manager.get("example") is None
|
assert manager.get("example") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_name_display_setting_notifies_scanners(tmp_path, monkeypatch):
|
||||||
|
initial = {
|
||||||
|
"libraries": {"default": {"folder_paths": {}, "default_lora_root": "", "default_checkpoint_root": "", "default_embedding_root": ""}},
|
||||||
|
"active_library": "default",
|
||||||
|
"model_name_display": "model_name",
|
||||||
|
}
|
||||||
|
|
||||||
|
manager = _create_manager_with_settings(tmp_path, monkeypatch, initial)
|
||||||
|
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
thread = threading.Thread(target=loop.run_forever, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
class DummyScanner:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = []
|
||||||
|
self.loop = loop
|
||||||
|
|
||||||
|
async def on_model_name_display_changed(self, mode: str) -> None:
|
||||||
|
self.calls.append(mode)
|
||||||
|
|
||||||
|
dummy_scanner = DummyScanner()
|
||||||
|
|
||||||
|
dispatched_loops = []
|
||||||
|
futures = []
|
||||||
|
original_run_coroutine_threadsafe = asyncio.run_coroutine_threadsafe
|
||||||
|
|
||||||
|
def tracking_run_coroutine_threadsafe(coro, target_loop):
|
||||||
|
dispatched_loops.append(target_loop)
|
||||||
|
future = original_run_coroutine_threadsafe(coro, target_loop)
|
||||||
|
futures.append(future)
|
||||||
|
return future
|
||||||
|
|
||||||
|
def fake_get_service_sync(cls, name):
|
||||||
|
return dummy_scanner if name == "lora_scanner" else None
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
service_registry.ServiceRegistry,
|
||||||
|
"get_service_sync",
|
||||||
|
classmethod(fake_get_service_sync),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(asyncio, "run_coroutine_threadsafe", tracking_run_coroutine_threadsafe)
|
||||||
|
|
||||||
|
try:
|
||||||
|
manager.set("model_name_display", "file_name")
|
||||||
|
|
||||||
|
for future in futures:
|
||||||
|
future.result(timeout=1)
|
||||||
|
|
||||||
|
assert dummy_scanner.calls == ["file_name"]
|
||||||
|
assert dispatched_loops == [dummy_scanner.loop]
|
||||||
|
finally:
|
||||||
|
loop.call_soon_threadsafe(loop.stop)
|
||||||
|
thread.join(timeout=1)
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
def test_migrates_legacy_settings_file(tmp_path, monkeypatch):
|
def test_migrates_legacy_settings_file(tmp_path, monkeypatch):
|
||||||
legacy_root = tmp_path / "legacy"
|
legacy_root = tmp_path / "legacy"
|
||||||
legacy_root.mkdir()
|
legacy_root.mkdir()
|
||||||
|
|||||||
Reference in New Issue
Block a user