mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-07-05 17:01:16 -03:00
refactor(agent): extract shared scanner iteration into _find_model_entry
_Previous_ _find_scanner_for_model and identify_model_type contained ~25 lines of identical scanner-iteration + path-matching logic. Factor it into _find_model_entry() so a new scanner type or edge-case fix can't drift apart.
This commit is contained in:
@@ -41,14 +41,12 @@ SCANNER_TYPE_MAP: dict[str, str] = {
|
||||
SCANNER_GETTER_NAMES = tuple(SCANNER_TYPE_MAP.keys())
|
||||
|
||||
|
||||
async def _find_scanner_for_model(
|
||||
async def _find_model_entry(
|
||||
model_path: str,
|
||||
) -> tuple[object, object] | tuple[None, None]:
|
||||
"""Find the (scanner, cache_entry) responsible for *model_path*.
|
||||
|
||||
Iterates all known scanner types and returns the first one whose cache
|
||||
contains the given path. Returns ``(None, None)`` when no scanner
|
||||
claims the model.
|
||||
) -> tuple[object, object, str | None] | tuple[None, None, None]:
|
||||
"""Iterate all scanners and return the first (scanner, entry, getter_name)
|
||||
that owns *model_path*. Returns ``(None, None, None)`` when no scanner
|
||||
claims it.
|
||||
"""
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
@@ -64,47 +62,31 @@ async def _find_scanner_for_model(
|
||||
cache = await scanner.get_cached_data()
|
||||
for entry in cache.raw_data:
|
||||
if os.path.normpath(entry.get("file_path", "")) == normalized:
|
||||
return scanner, entry
|
||||
return scanner, entry, getter_name
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"Scanner %s check failed for %s: %s",
|
||||
getter_name,
|
||||
model_path,
|
||||
exc,
|
||||
getter_name, model_path, exc,
|
||||
)
|
||||
return None, None
|
||||
return None, None, None
|
||||
|
||||
|
||||
async def _find_scanner_for_model(
|
||||
model_path: str,
|
||||
) -> tuple[object, object] | tuple[None, None]:
|
||||
"""Find the (scanner, cache_entry) responsible for *model_path*."""
|
||||
scanner, entry, _ = await _find_model_entry(model_path)
|
||||
return scanner, entry
|
||||
|
||||
|
||||
async def identify_model_type(model_path: str) -> str:
|
||||
"""Determine the model type (``\"lora\"``, ``\"checkpoint\"``, or
|
||||
``\"embedding\"``) for *model_path*.
|
||||
|
||||
Iterates all known scanners; the first scanner that claims the path
|
||||
determines the type. Falls back to ``\"lora\"`` when unknown.
|
||||
Falls back to ``\"lora\"`` when unknown.
|
||||
"""
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
normalized = os.path.normpath(model_path)
|
||||
for getter_name in SCANNER_GETTER_NAMES:
|
||||
getter = getattr(ServiceRegistry, getter_name, None)
|
||||
if getter is None:
|
||||
continue
|
||||
try:
|
||||
scanner = await getter()
|
||||
if scanner is None:
|
||||
continue
|
||||
cache = await scanner.get_cached_data()
|
||||
for entry in cache.raw_data:
|
||||
if os.path.normpath(entry.get("file_path", "")) == normalized:
|
||||
return SCANNER_TYPE_MAP[getter_name]
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"identify_model_type scanner %s error for %s: %s",
|
||||
getter_name,
|
||||
model_path,
|
||||
exc,
|
||||
)
|
||||
return "lora"
|
||||
_, _, getter_name = await _find_model_entry(model_path)
|
||||
return SCANNER_TYPE_MAP[getter_name] if getter_name else "lora"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user