refactor(agent): extract shared scanner iteration into _find_model_entry

_Previous_ _find_scanner_for_model and identify_model_type contained ~25 lines
of identical scanner-iteration + path-matching logic.  Factor it into
_find_model_entry() so a new scanner type or edge-case fix can't drift apart.
This commit is contained in:
Will Miao
2026-07-05 18:03:57 +08:00
parent 51c0135250
commit e3e944911b

View File

@@ -41,14 +41,12 @@ SCANNER_TYPE_MAP: dict[str, str] = {
SCANNER_GETTER_NAMES = tuple(SCANNER_TYPE_MAP.keys()) SCANNER_GETTER_NAMES = tuple(SCANNER_TYPE_MAP.keys())
async def _find_scanner_for_model( async def _find_model_entry(
model_path: str, model_path: str,
) -> tuple[object, object] | tuple[None, None]: ) -> tuple[object, object, str | None] | tuple[None, None, None]:
"""Find the (scanner, cache_entry) responsible for *model_path*. """Iterate all scanners and return the first (scanner, entry, getter_name)
that owns *model_path*. Returns ``(None, None, None)`` when no scanner
Iterates all known scanner types and returns the first one whose cache claims it.
contains the given path. Returns ``(None, None)`` when no scanner
claims the model.
""" """
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
@@ -64,47 +62,31 @@ async def _find_scanner_for_model(
cache = await scanner.get_cached_data() cache = await scanner.get_cached_data()
for entry in cache.raw_data: for entry in cache.raw_data:
if os.path.normpath(entry.get("file_path", "")) == normalized: if os.path.normpath(entry.get("file_path", "")) == normalized:
return scanner, entry return scanner, entry, getter_name
except Exception as exc: except Exception as exc:
logger.debug( logger.debug(
"Scanner %s check failed for %s: %s", "Scanner %s check failed for %s: %s",
getter_name, getter_name, model_path, exc,
model_path,
exc,
) )
return None, None return None, None, None
async def _find_scanner_for_model(
model_path: str,
) -> tuple[object, object] | tuple[None, None]:
"""Find the (scanner, cache_entry) responsible for *model_path*."""
scanner, entry, _ = await _find_model_entry(model_path)
return scanner, entry
async def identify_model_type(model_path: str) -> str: async def identify_model_type(model_path: str) -> str:
"""Determine the model type (``\"lora\"``, ``\"checkpoint\"``, or """Determine the model type (``\"lora\"``, ``\"checkpoint\"``, or
``\"embedding\"``) for *model_path*. ``\"embedding\"``) for *model_path*.
Iterates all known scanners; the first scanner that claims the path Falls back to ``\"lora\"`` when unknown.
determines the type. Falls back to ``\"lora\"`` when unknown.
""" """
from ..services.service_registry import ServiceRegistry _, _, getter_name = await _find_model_entry(model_path)
return SCANNER_TYPE_MAP[getter_name] if getter_name else "lora"
normalized = os.path.normpath(model_path)
for getter_name in SCANNER_GETTER_NAMES:
getter = getattr(ServiceRegistry, getter_name, None)
if getter is None:
continue
try:
scanner = await getter()
if scanner is None:
continue
cache = await scanner.get_cached_data()
for entry in cache.raw_data:
if os.path.normpath(entry.get("file_path", "")) == normalized:
return SCANNER_TYPE_MAP[getter_name]
except Exception as exc:
logger.debug(
"identify_model_type scanner %s error for %s: %s",
getter_name,
model_path,
exc,
)
return "lora"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------