mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
fix(settings): dispatch name display updates on original loop
This commit is contained in:
@@ -83,6 +83,12 @@ class ModelScanner:
|
|||||||
self._excluded_models = [] # List to track excluded models
|
self._excluded_models = [] # List to track excluded models
|
||||||
self._persistent_cache = get_persistent_cache()
|
self._persistent_cache = get_persistent_cache()
|
||||||
self._name_display_mode = self._resolve_name_display_mode()
|
self._name_display_mode = self._resolve_name_display_mode()
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = None
|
||||||
|
self._loop = loop
|
||||||
|
self.loop = loop
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
# Register this service
|
# Register this service
|
||||||
@@ -104,6 +110,8 @@ class ModelScanner:
|
|||||||
loop = None
|
loop = None
|
||||||
|
|
||||||
if loop and not loop.is_closed():
|
if loop and not loop.is_closed():
|
||||||
|
self._loop = loop
|
||||||
|
self.loop = loop
|
||||||
loop.create_task(self.initialize_in_background())
|
loop.create_task(self.initialize_in_background())
|
||||||
|
|
||||||
def _resolve_name_display_mode(self) -> str:
|
def _resolve_name_display_mode(self) -> str:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence
|
from typing import Any, Awaitable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
||||||
|
|
||||||
from ..utils.constants import DEFAULT_PRIORITY_TAG_CONFIG
|
from ..utils.constants import DEFAULT_PRIORITY_TAG_CONFIG
|
||||||
from ..utils.settings_paths import ensure_settings_file
|
from ..utils.settings_paths import ensure_settings_file
|
||||||
@@ -486,7 +486,13 @@ class SettingsManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
display_mode = value if isinstance(value, str) else "model_name"
|
display_mode = value if isinstance(value, str) else "model_name"
|
||||||
coroutines = []
|
pending: List[Tuple[Optional[asyncio.AbstractEventLoop], Awaitable[Any]]] = []
|
||||||
|
|
||||||
|
def _resolve_service_loop(service: Any) -> Optional[asyncio.AbstractEventLoop]:
|
||||||
|
loop = getattr(service, "loop", None)
|
||||||
|
if loop is None:
|
||||||
|
loop = getattr(service, "_loop", None)
|
||||||
|
return loop if isinstance(loop, asyncio.AbstractEventLoop) else None
|
||||||
|
|
||||||
for service_name in (
|
for service_name in (
|
||||||
"lora_scanner",
|
"lora_scanner",
|
||||||
@@ -509,23 +515,42 @@ class SettingsManager:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if asyncio.iscoroutine(result):
|
if asyncio.iscoroutine(result):
|
||||||
coroutines.append(result)
|
service_loop = _resolve_service_loop(service)
|
||||||
|
pending.append((service_loop, result))
|
||||||
|
|
||||||
if not coroutines:
|
if not pending:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
for coroutine in coroutines:
|
loop = None
|
||||||
|
|
||||||
|
for service_loop, coroutine in pending:
|
||||||
|
target_loop = service_loop or loop
|
||||||
|
|
||||||
|
if target_loop is None:
|
||||||
try:
|
try:
|
||||||
asyncio.run(coroutine)
|
asyncio.run(coroutine)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
# If event loop is already running in another thread, skip execution
|
logger.debug("Skipping name display update due to missing event loop")
|
||||||
logger.debug("Skipping name display update due to running loop")
|
continue
|
||||||
else:
|
|
||||||
for coroutine in coroutines:
|
if loop is not None and target_loop is loop:
|
||||||
loop.create_task(coroutine)
|
target_loop.create_task(coroutine)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if target_loop.is_running():
|
||||||
|
try:
|
||||||
|
asyncio.run_coroutine_threadsafe(coroutine, target_loop)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive guard
|
||||||
|
logger.debug("Failed to dispatch name display update: %s", exc)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncio.run(coroutine)
|
||||||
|
except RuntimeError:
|
||||||
|
logger.debug("Skipping name display update due to closed loop")
|
||||||
|
|
||||||
def _save_settings(self) -> None:
|
def _save_settings(self) -> None:
|
||||||
"""Save settings to file"""
|
"""Save settings to file"""
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
|
import threading
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@@ -110,15 +112,30 @@ def test_model_name_display_setting_notifies_scanners(tmp_path, monkeypatch):
|
|||||||
|
|
||||||
manager = _create_manager_with_settings(tmp_path, monkeypatch, initial)
|
manager = _create_manager_with_settings(tmp_path, monkeypatch, initial)
|
||||||
|
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
thread = threading.Thread(target=loop.run_forever, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
class DummyScanner:
|
class DummyScanner:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.calls = []
|
self.calls = []
|
||||||
|
self.loop = loop
|
||||||
|
|
||||||
async def on_model_name_display_changed(self, mode: str) -> None:
|
async def on_model_name_display_changed(self, mode: str) -> None:
|
||||||
self.calls.append(mode)
|
self.calls.append(mode)
|
||||||
|
|
||||||
dummy_scanner = DummyScanner()
|
dummy_scanner = DummyScanner()
|
||||||
|
|
||||||
|
dispatched_loops = []
|
||||||
|
futures = []
|
||||||
|
original_run_coroutine_threadsafe = asyncio.run_coroutine_threadsafe
|
||||||
|
|
||||||
|
def tracking_run_coroutine_threadsafe(coro, target_loop):
|
||||||
|
dispatched_loops.append(target_loop)
|
||||||
|
future = original_run_coroutine_threadsafe(coro, target_loop)
|
||||||
|
futures.append(future)
|
||||||
|
return future
|
||||||
|
|
||||||
def fake_get_service_sync(cls, name):
|
def fake_get_service_sync(cls, name):
|
||||||
return dummy_scanner if name == "lora_scanner" else None
|
return dummy_scanner if name == "lora_scanner" else None
|
||||||
|
|
||||||
@@ -127,10 +144,20 @@ def test_model_name_display_setting_notifies_scanners(tmp_path, monkeypatch):
|
|||||||
"get_service_sync",
|
"get_service_sync",
|
||||||
classmethod(fake_get_service_sync),
|
classmethod(fake_get_service_sync),
|
||||||
)
|
)
|
||||||
|
monkeypatch.setattr(asyncio, "run_coroutine_threadsafe", tracking_run_coroutine_threadsafe)
|
||||||
|
|
||||||
manager.set("model_name_display", "file_name")
|
try:
|
||||||
|
manager.set("model_name_display", "file_name")
|
||||||
|
|
||||||
assert dummy_scanner.calls == ["file_name"]
|
for future in futures:
|
||||||
|
future.result(timeout=1)
|
||||||
|
|
||||||
|
assert dummy_scanner.calls == ["file_name"]
|
||||||
|
assert dispatched_loops == [dummy_scanner.loop]
|
||||||
|
finally:
|
||||||
|
loop.call_soon_threadsafe(loop.stop)
|
||||||
|
thread.join(timeout=1)
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
def test_migrates_legacy_settings_file(tmp_path, monkeypatch):
|
def test_migrates_legacy_settings_file(tmp_path, monkeypatch):
|
||||||
|
|||||||
Reference in New Issue
Block a user