feat(agent): fix extract_relevant_section false positives, add validation pipeline audit

- extract_relevant_section: raise token threshold >3, verify anchor
  sections contain basename, require 2+ heading token overlaps, skip
  TOC-style headings (markdown links), verify heading section size
- metadata_constructor: parse repo_id,model_name.safetensors format
  so model_path basename matches real filename
- config: replace hardcoded SUPPORTED_BASE_MODELS with dynamic
  init_supported_base_models() using production list_base_models()
- preprocessing_auditor: new Phase 1.5 audit module — fetches each
  README, runs extract_relevant_section + clean_readme_for_llm,
  records stats and flags, saves raw READMEs for cross-reference
- run_validation: integrate audit phase, add --audit-only mode,
  add LLM config consistency check, add ComfyUI root to sys.path
- report_generator: add Preprocessing Audit and Config Warnings
  sections to both markdown and JSON reports
This commit is contained in:
Will Miao
2026-07-05 11:18:48 +08:00
parent dd3aa97d0a
commit 8fb00998a7
6 changed files with 891 additions and 86 deletions

View File

@@ -675,8 +675,10 @@ def extract_relevant_section(
lines = readme_content.split("\n")
n = len(lines)
basename_lower = model_basename.lower()
# Tokens from the basename split on common separators
tokens = {t for t in re.split(r"[_\-.\s]+", basename_lower) if len(t) > 2}
# Tokens from the basename split on common separators.
# Exclude tokens of length ≤ 3 — 2-3 char tokens (e.g. "cry", "myjs")
# are too short to discriminate between different models in collection repos.
tokens = {t for t in re.split(r"[_\-.\s]+", basename_lower) if len(t) > 3}
# ------------------------------------------------------------------
# Strategy 1: Find a download link containing the basename
@@ -695,7 +697,13 @@ def extract_relevant_section(
if m:
aid = m.group(1).lower()
if any(token in aid for token in tokens):
return _extract_section(lines, idx, context_lines)
section = _extract_section(lines, idx, context_lines)
# Verify the extracted section actually mentions the model —
# short anchor IDs can coincidentally match tokens from
# unrelated models (e.g. "myjs" matching a different LoRA).
if basename_lower in section.lower():
return section
# False positive — continue searching
# ------------------------------------------------------------------
# Strategy 3: Find an HTML or markdown heading with overlapping tokens
@@ -712,10 +720,28 @@ def extract_relevant_section(
if mm:
heading_text = mm.group(1)
if heading_text:
# Skip TOC-style entries where the heading text is a markdown
# link or bullet list item (e.g. "### - [model_name](url)").
# These are table-of-contents entries, not real section headers.
stripped = heading_text.strip()
if stripped.startswith("- [") or re.match(r"^\[.+?\]\(.+?\)", stripped):
continue
heading_lower = heading_text.lower()
# Check if any token appears in the heading
if any(token in heading_lower for token in tokens):
return _extract_section(lines, idx, context_lines)
# Require at least 2 token overlaps, or the full basename as a
# substring of the heading. A single 4-5 char token match is
# too weak — e.g. "devil" matching "dante_devil_may_cry" when
# the actual model is "vergil_devil_may_cry", or "image" matching
# a table-of-contents heading.
matching = [t for t in tokens if t in heading_lower]
if len(matching) >= 2 or basename_lower in heading_lower:
section = _extract_section(lines, idx, context_lines)
# Verify the section contains the model name — headings in
# TOC areas can match tokens but produce a tiny irrelevant
# section (e.g. "### - [z_image_turbo](url)" matched by
# tokens "lora" and "turbo").
if basename_lower in section.lower() or len(section) > max(500, context_lines * 20):
return section
# ------------------------------------------------------------------
# Fallback: return FULL readme

View File

@@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
_DEFAULT_MODELS_FILE = os.path.expanduser(
"~/Documents/hf_lora_models.txt"
"~/Documents/hf_lora_models_with_safetensors.txt"
)
_DEFAULT_SETTINGS_PATH = os.path.expanduser(
"~/.config/ComfyUI-LoRA-Manager/settings.json"
@@ -36,8 +36,16 @@ CIVITAI_MODEL_TAGS: List[str] = [
"buildings", "objects", "assets", "animal", "action",
]
# Base models recognised as valid values.
SUPPORTED_BASE_MODELS: List[str] = [
# ---------------------------------------------------------------------------
# Base model resolution — dynamically fetched from production code
# ---------------------------------------------------------------------------
# Module-level cache — populated by init_supported_base_models().
# Falls back to a comprehensive hardcoded list when the live fetch fails.
SUPPORTED_BASE_MODELS: List[str] = []
# Fallback base models when the production list_base_models() is unavailable.
_FALLBACK_BASE_MODELS: List[str] = [
"SD 1.4", "SD 1.5", "SD 1.5 LCM", "SD 1.5 Hyper",
"SD 2.0", "SD 2.1",
"SD 3", "SD 3.5", "SD 3.5 Medium", "SD 3.5 Large", "SD 3.5 Large Turbo",
@@ -60,6 +68,33 @@ SUPPORTED_BASE_MODELS: List[str] = [
"Nucleus", "Krea 2",
]
async def init_supported_base_models() -> None:
"""Populate ``SUPPORTED_BASE_MODELS`` from the production codebase.
Calls ``py.agent_cli.list_base_models()`` which merges a hardcoded
fallback with models fetched from the CivitAI API. When the call
fails (e.g. offline, API error), falls back to ``_FALLBACK_BASE_MODELS``.
Must be called from within an async event loop (i.e. during
``run_validation.main()``, not at module level).
"""
try:
from py.agent_cli import list_base_models
models = await list_base_models()
if models:
SUPPORTED_BASE_MODELS[:] = models
logger.info("Loaded %d base models from production code", len(models))
return
logger.warning("list_base_models returned empty list, using fallback")
except Exception as exc:
logger.warning("Failed to load base models from production: %s", exc)
SUPPORTED_BASE_MODELS[:] = _FALLBACK_BASE_MODELS
logger.info("Using fallback base model list (%d entries)", len(SUPPORTED_BASE_MODELS))
# Placeholder values the LLM sometimes emits that should count as "empty".
PLACEHOLDER_VALUES = frozenset({
"none", "null", "n/a", "unknown", "not available",
@@ -71,6 +106,7 @@ PLACEHOLDER_VALUES = frozenset({
# User settings loader
# ---------------------------------------------------------------------------
def load_settings(settings_path: str) -> Dict[str, Any]:
"""Load LoRA Manager settings from *settings_path*.

View File

@@ -1,7 +1,11 @@
"""Construct initial ``.metadata.json`` sidecars for HF model repos.
Each HF repo ID gets a minimal metadata file — no real model file is needed.
The enrichment pipeline reads only the sidecar.
Each HF repo + safetensors pair gets a minimal metadata file — no real model
file is needed. The enrichment pipeline reads only the sidecar.
Data format (one line per entry)::
repo_id, model_name.safetensors
"""
from __future__ import annotations
@@ -9,32 +13,64 @@ from __future__ import annotations
import json
import logging
import os
from typing import Dict, List
from typing import Any, Dict, List, Tuple
from .config import CIVITAI_MODEL_TAGS
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Data types
# ---------------------------------------------------------------------------
def load_repo_ids(path: str, max_models: int | None = None) -> List[str]:
"""Read HF repo IDs from *path* (one per line, ignoring blanks/comments)."""
# A validated entry parsed from the models file:
# (repo_id, safetensors_name)
RepoEntry = Tuple[str, str]
def load_repo_ids(path: str, max_models: int | None = None) -> List[RepoEntry]:
"""Read ``repo_id, safetensors_name`` pairs from *path*.
Format (one per line, blanks and ``#`` comments ignored)::
user/repo-name, lora_zimage_turbo_myjs_alpha01.safetensors
Returns a list of ``(repo_id, safetensors_name)`` tuples.
"""
path = os.path.expanduser(path)
if not os.path.exists(path):
raise FileNotFoundError(f"Models file not found: {path}")
repos: List[str] = []
entries: List[RepoEntry] = []
with open(path, "r", encoding="utf-8") as fh:
for line in fh:
line = line.strip()
for raw_line in fh:
line = raw_line.strip()
if not line or line.startswith("#"):
continue
repos.append(line)
# Split on the first comma
if "," not in line:
logger.warning("Skipping malformed line (no comma): %s", raw_line.rstrip())
continue
repo_id, safetensors_name = [part.strip() for part in line.split(",", 1)]
if not repo_id or not safetensors_name:
logger.warning("Skipping malformed line (empty fields): %s", raw_line.rstrip())
continue
if not safetensors_name.lower().endswith(".safetensors"):
logger.warning(
"Skipping line — safetensors_name doesn't end with .safetensors: %s",
raw_line.rstrip(),
)
continue
entries.append((repo_id, safetensors_name))
if max_models is not None and max_models > 0:
repos = repos[:max_models]
entries = entries[:max_models]
logger.info("Loaded %d HF repo IDs from %s", len(repos), path)
return repos
logger.info("Loaded %d HF repo entries from %s", len(entries), path)
return entries
def sanitize_repo_id(repo_id: str) -> str:
@@ -47,9 +83,9 @@ def build_model_dir(output_dir: str, repo_id: str) -> str:
return os.path.join(output_dir, "models", sanitize_repo_id(repo_id))
def build_model_path(model_dir: str) -> str:
"""Return a synthetic model file path (no real file will exist)."""
return os.path.join(model_dir, "model.safetensors")
def build_model_path(model_dir: str, safetensors_name: str) -> str:
"""Return the model file path using the real safetensors filename."""
return os.path.join(model_dir, safetensors_name)
def build_metadata_path(model_path: str) -> str:
@@ -58,8 +94,8 @@ def build_metadata_path(model_path: str) -> str:
This MUST match the convention used by ``MetadataManager`` /
``apply_metadata_updates``, which derives the sidecar path via
``os.path.splitext(model_path)[0] + '.metadata.json'``.
For a model file ``model.safetensors`` the sidecar is
``model.metadata.json`` — *not* ``model.safetensors.metadata.json``.
For a model file ``lora_x.safetensors`` the sidecar is
``lora_x.metadata.json`` — *not* ``lora_x.safetensors.metadata.json``.
"""
return f"{os.path.splitext(model_path)[0]}.metadata.json"
@@ -67,23 +103,32 @@ def build_metadata_path(model_path: str) -> str:
def create_initial_metadata(
output_dir: str,
repo_id: str,
safetensors_name: str,
) -> str:
"""Write a minimal ``.metadata.json`` for *repo_id*.
"""Write a minimal ``.metadata.json`` for *repo_id* + *safetensors_name*.
Args:
output_dir: Root output directory.
repo_id: HuggingFace repo identifier (``user/repo``).
safetensors_name: The specific model file name (e.g.
``lora_zimage_turbo_myjs_alpha01.safetensors``).
Returns the **model path** (the ``.safetensors`` path whose sidecar was
written). The caller passes this path to ``AgentService.execute_skill``.
The basename (filename without extension) will match the real model file,
so ``extract_relevant_section`` can reliably match against the README.
"""
model_dir = build_model_dir(output_dir, repo_id)
os.makedirs(model_dir, exist_ok=True)
model_path = build_model_path(model_dir)
model_path = build_model_path(model_dir, safetensors_name)
metadata_path = build_metadata_path(model_path)
hf_url = f"https://huggingface.co/{repo_id}"
file_name = repo_id.split("/")[-1]
file_name = safetensors_name
metadata: Dict = {
metadata: Dict[str, Any] = {
"file_name": file_name,
"model_name": file_name,
"model_name": safetensors_name,
"file_path": model_path.replace(os.sep, "/"),
"size": 0,
"modified": 0,
@@ -117,32 +162,41 @@ def create_initial_metadata(
def create_all_initial_metadata(
repos: List[str],
entries: List[RepoEntry],
output_dir: str,
*,
skip_existing: bool = True,
) -> List[str]:
"""Create initial metadata for every repo in *repos*.
) -> Tuple[List[str], List[str]]:
"""Create initial metadata for every repo entry.
Returns a list of model paths in the same order as *repos*.
``skip_existing=True`` skips repos whose metadata already exists,
allowing safe re-run.
Args:
entries: List of ``(repo_id, safetensors_name)`` tuples.
output_dir: Root output directory.
skip_existing: If True, skip repos whose metadata already exists.
Returns:
A tuple ``(model_paths, repo_ids)`` — two parallel lists in the same
order as *entries*. This keeps downstream code (enrichment runner,
evaluation engine) unchanged.
"""
model_paths: List[str] = []
for repo_id in repos:
repo_ids: List[str] = []
for repo_id, safetensors_name in entries:
model_dir = build_model_dir(output_dir, repo_id)
model_path = build_model_path(model_dir)
model_path = build_model_path(model_dir, safetensors_name)
metadata_path = build_metadata_path(model_path)
if skip_existing and os.path.exists(metadata_path):
model_paths.append(model_path)
repo_ids.append(repo_id)
continue
model_paths.append(create_initial_metadata(output_dir, repo_id))
model_paths.append(create_initial_metadata(output_dir, repo_id, safetensors_name))
repo_ids.append(repo_id)
logger.info(
"Constructed initial metadata for %d/%d repos",
len(model_paths),
len(repos),
len(entries),
)
return model_paths
return model_paths, repo_ids

View File

@@ -0,0 +1,467 @@
"""Preprocessing audit for the HF metadata enrichment validation pipeline.
Phase 1.5 — runs between Phase 1 (metadata creation) and Phase 2 (enrichment).
Audits the README preprocessing pipeline (section extraction + cleaning)
for each repo in the dataset, capturing intermediate outputs so we can
distinguish between:
(A) Preprocessing failed → LLM never saw the right content
(B) Preprocessing succeeded → LLM/prompt needs improvement
This prevents wasted effort optimizing prompts when the actual problem is
that ``extract_relevant_section`` or ``clean_readme_for_llm`` removed or
misaligned the content the LLM needed.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import re
import time
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, List, Tuple
import aiohttp
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Audit record
# ---------------------------------------------------------------------------
@dataclass
class AuditRecord:
"""Preprocessing audit for a single repo entry."""
# Identity
repo_id: str
safetensors_name: str
basename: str # filename without .safetensors
# Raw README stats
raw_readme_length: int
raw_readme_line_count: int
has_yaml_frontmatter: bool
yaml_has_base_model: bool
yaml_has_tags: bool
# Section extraction
section_extraction_activated: bool # output < 95% of input length
section_length: int
section_line_count: int
basename_in_section: bool # basename appears in extracted section text
# Cleaning
cleaned_length: int
cleaned_line_count: int
compression_pct: float # (1 - cleaned/raw) * 100
# Widget section (stripped by _strip_widget_section)
widget_section_found: bool
widget_section_length: int
# Flags (list of anomaly descriptions)
flags: List[str] = field(default_factory=list)
# Local file path to the saved raw README (for cross-reference)
readme_file: str = ""
# Staged intermediate output for report detail
raw_readme_preview: str = "" # first 200 chars
section_preview: str = "" # first 300 chars
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_HF_RAW_URL = "https://huggingface.co/{repo_id}/raw/main/README.md"
# Thresholds for flagging
_SECTION_ACTIVATION_RATIO = 0.95
_MIN_CLEANED_LENGTH = 100
_MAX_COMPRESSION_PCT = 99.0
_MIN_SECTION_LINES = 3
# ---------------------------------------------------------------------------
# Module loader — bypasses parent-package __init__ that imports ComfyUI
# ---------------------------------------------------------------------------
_readme_processor_module = None
def _load_readme_processor():
"""Import ``readme_processor`` without triggering ``folder_paths`` import.
The normal import path (``py.services.agent.skills.enrich_hf_metadata.
readme_processor``) triggers ``py.services.agent.__init__`` which
imports ``agent_service.py`` → ``py/config.py`` → ComfyUI's
``folder_paths``, which is not available in standalone mode.
"""
global _readme_processor_module
if _readme_processor_module is not None:
return _readme_processor_module
import importlib.util
_RP_PATH = os.path.join(
os.path.dirname(__file__), # tests/enrich_hf_validation/
"..", "..",
"py", "services", "agent", "skills", "enrich_hf_metadata",
"readme_processor.py",
)
rp_path = os.path.normpath(_RP_PATH)
if not os.path.exists(rp_path):
logger.error("readme_processor.py not found at %s", rp_path)
return None
spec = importlib.util.spec_from_file_location(
"readme_processor", rp_path,
)
if spec is None or spec.loader is None:
logger.error("Could not create spec for readme_processor.py")
return None
mod = importlib.util.module_from_spec(spec)
try:
spec.loader.exec_module(mod)
except Exception as exc:
logger.error("Failed to load readme_processor.py: %s", exc)
return None
_readme_processor_module = mod
return mod
# ---------------------------------------------------------------------------
# HF README fetcher
# ---------------------------------------------------------------------------
async def _fetch_readme(repo_id: str, session: aiohttp.ClientSession) -> str:
"""Fetch the raw README.md from HuggingFace."""
url = _HF_RAW_URL.format(repo_id=repo_id)
try:
async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
if resp.status == 200:
return await resp.text()
logger.warning("Failed to fetch README for %s: HTTP %d", repo_id, resp.status)
return ""
except (asyncio.TimeoutError, aiohttp.ClientError) as exc:
logger.warning("Failed to fetch README for %s: %s", repo_id, exc)
return ""
# ---------------------------------------------------------------------------
# Analysis helpers
# ---------------------------------------------------------------------------
def _has_yaml_frontmatter(text: str) -> bool:
return bool(text.strip().startswith("---"))
def _extract_yaml_field(text: str, field: str) -> bool:
"""Check if the given YAML field exists in the frontmatter."""
lines = text.split("\n")
if not lines or not lines[0].strip().startswith("---"):
return False
end = 1
while end < len(lines):
if lines[end].strip().startswith("---"):
break
end += 1
if end >= len(lines):
return False
frontmatter = "\n".join(lines[1:end])
pattern = rf"^{field}:"
return bool(re.search(pattern, frontmatter, re.MULTILINE))
def _find_widget_section_length(text: str) -> int:
"""Find the ``widget:`` YAML section and return its length (0 if none)."""
if not _has_yaml_frontmatter(text):
return 0
frontmatter_end = text.find("---", 3)
if frontmatter_end == -1:
return 0
frontmatter = text[3:frontmatter_end]
# Match widget: through to the next top-level key or frontmatter end
m = re.search(r"\nwidget:", frontmatter)
if not m:
return 0
# Length from widget: to end of frontmatter (the next \n\w+: or \n---)
return len(frontmatter[m.start():])
# ---------------------------------------------------------------------------
# Core auditor
# ---------------------------------------------------------------------------
async def run_audit(
entries: List[Tuple[str, str]],
*,
concurrency: int = 10,
readmes_dir: str | None = None,
) -> Tuple[List[AuditRecord], Dict[str, Any]]:
"""Run the preprocessing audit over all repo entries.
Args:
entries: List of ``(repo_id, safetensors_name)``.
concurrency: Max parallel fetches to HuggingFace.
readmes_dir: If set, saves each fetched README as
``{sanitized_repo_id}.md`` in this directory for offline
cross-reference against audit results.
Returns:
Tuple of ``(records, summary)`` where *summary* is a dict with
aggregate statistics.
"""
semaphore = asyncio.Semaphore(concurrency)
records: List[AuditRecord] = []
flag_counter: Dict[str, int] = {}
if readmes_dir:
os.makedirs(readmes_dir, exist_ok=True)
connector = aiohttp.TCPConnector(limit=concurrency)
async with aiohttp.ClientSession(connector=connector) as session:
tasks = [_audit_one(entry, session, semaphore, readmes_dir=readmes_dir) for entry in entries]
gathered = await asyncio.gather(*tasks, return_exceptions=True)
for entry, result in zip(entries, gathered):
if isinstance(result, Exception):
logger.error("Audit failed for %s: %s", entry[0], result)
records.append(
AuditRecord(
repo_id=entry[0],
safetensors_name=entry[1],
basename=os.path.splitext(entry[1])[0],
raw_readme_length=0,
raw_readme_line_count=0,
has_yaml_frontmatter=False,
yaml_has_base_model=False,
yaml_has_tags=False,
section_extraction_activated=False,
section_length=0,
section_line_count=0,
basename_in_section=False,
cleaned_length=0,
cleaned_line_count=0,
compression_pct=0.0,
widget_section_found=False,
widget_section_length=0,
readme_file="",
flags=[f"Audit exception: {result}"],
)
)
continue
# The continue above ensures result is AuditRecord here
assert isinstance(result, AuditRecord)
records.append(result)
for flag in result.flags:
flag_counter[flag] = flag_counter.get(flag, 0) + 1
summary = _build_summary(records, flag_counter)
return records, summary
def _sanitize_repo_id(repo_id: str) -> str:
"""Turn ``user/repo-name`` into a safe filename."""
return repo_id.replace("/", "__").replace(".", "_")
async def _audit_one(
entry: Tuple[str, str],
session: aiohttp.ClientSession,
semaphore: asyncio.Semaphore,
*,
readmes_dir: str | None = None,
) -> AuditRecord:
"""Audit a single repo entry."""
repo_id, safetensors_name = entry
basename = os.path.splitext(safetensors_name)[0]
async with semaphore:
# Import production preprocessing functions.
# Use importlib to bypass py.services.agent.__init__ which triggers
# ComfyUI's folder_paths module (not available in standalone mode).
_rp = _load_readme_processor()
if _rp is None:
return AuditRecord(
repo_id=repo_id,
safetensors_name=safetensors_name,
basename=basename,
raw_readme_length=0, raw_readme_line_count=0,
has_yaml_frontmatter=False, yaml_has_base_model=False, yaml_has_tags=False,
readme_file="",
section_extraction_activated=False, section_length=0, section_line_count=0,
basename_in_section=False, cleaned_length=0, cleaned_line_count=0,
compression_pct=0.0, widget_section_found=False, widget_section_length=0,
flags=["IMPORT_FAILED"],
)
clean_readme_for_llm = _rp.clean_readme_for_llm
extract_relevant_section = _rp.extract_relevant_section
# Step 1: Fetch the raw README
raw_text = await _fetch_readme(repo_id, session)
if not raw_text:
return AuditRecord(
repo_id=repo_id,
safetensors_name=safetensors_name,
basename=basename,
raw_readme_length=0,
raw_readme_line_count=0,
has_yaml_frontmatter=False,
yaml_has_base_model=False,
yaml_has_tags=False,
section_extraction_activated=False,
section_length=0,
section_line_count=0,
basename_in_section=False,
readme_file="",
cleaned_length=0,
cleaned_line_count=0,
compression_pct=0.0,
widget_section_found=False,
widget_section_length=0,
flags=["README_FETCH_FAILED"],
)
# Save the raw README to disk for offline cross-reference
readme_path = ""
if readmes_dir:
safe_name = _sanitize_repo_id(repo_id)
readme_path = os.path.join(readmes_dir, f"{safe_name}.md")
try:
with open(readme_path, "w", encoding="utf-8") as fh:
fh.write(raw_text)
except OSError as exc:
logger.warning("Failed to save README for %s: %s", repo_id, exc)
readme_path = ""
raw_lines = raw_text.split("\n")
raw_len = len(raw_text)
raw_line_count = len(raw_lines)
# Step 2: Analyze raw README
yaml_fm = _has_yaml_frontmatter(raw_text)
yaml_has_bm = _extract_yaml_field(raw_text, "base_model") if yaml_fm else False
yaml_has_tg = _extract_yaml_field(raw_text, "tags") if yaml_fm else False
widget_len = _find_widget_section_length(raw_text)
# Step 3: Section extraction
section = extract_relevant_section(raw_text, basename)
section_len = len(section)
section_line_count = len(section.split("\n"))
section_activated = section_len < raw_len * _SECTION_ACTIVATION_RATIO
basename_in_sec = basename.lower() in section.lower()
# Step 4: Cleaning for LLM
cleaned = clean_readme_for_llm(section)
cleaned_len = len(cleaned)
cleaned_line_count = len(cleaned.split("\n"))
compression_pct = round((1 - cleaned_len / raw_len) * 100, 1) if raw_len else 0.0
# Step 5: Flag anomalies
flags: List[str] = []
if not raw_text.strip():
flags.append("README_EMPTY")
if not yaml_fm:
flags.append("NO_YAML_FRONTMATTER")
if not section_activated:
# Check if basename is extremely short/generic (likely synthetic)
if len(basename) <= 5:
flags.append("BASENAME_TOO_SHORT_SECTION_NOT_EXPECTED")
else:
flags.append("SECTION_EXTRACTION_NOT_ACTIVATED")
elif not basename_in_sec:
flags.append("BASENAME_NOT_IN_EXTRACTED_SECTION")
if widget_len == 0:
# Not necessarily a problem — many repos lack a widget section
pass
if cleaned_len < _MIN_CLEANED_LENGTH:
flags.append("CLEANED_README_TOO_SHORT")
if compression_pct > _MAX_COMPRESSION_PCT:
flags.append("EXTREME_COMPRESSION")
if section_activated and section_line_count < _MIN_SECTION_LINES:
flags.append("SECTION_TOO_SMALL")
return AuditRecord(
repo_id=repo_id,
safetensors_name=safetensors_name,
basename=basename,
raw_readme_length=raw_len,
raw_readme_line_count=raw_line_count,
has_yaml_frontmatter=yaml_fm,
yaml_has_base_model=yaml_has_bm,
yaml_has_tags=yaml_has_tg,
section_extraction_activated=section_activated,
section_length=section_len,
section_line_count=section_line_count,
basename_in_section=basename_in_sec,
cleaned_length=cleaned_len,
cleaned_line_count=cleaned_line_count,
compression_pct=compression_pct,
widget_section_found=widget_len > 0,
widget_section_length=widget_len,
readme_file=readme_path,
flags=flags,
raw_readme_preview=raw_text[:200],
section_preview=section[:300],
)
def _build_summary(
records: List[AuditRecord],
flag_counter: Dict[str, int],
) -> Dict[str, Any]:
"""Aggregate audit statistics."""
n = len(records)
if n == 0:
return {"error": "no records", "model_count": 0}
activated = sum(1 for r in records if r.section_extraction_activated)
basename_hit = sum(1 for r in records if r.basename_in_section)
with_yaml = sum(1 for r in records if r.has_yaml_frontmatter)
with_widget = sum(1 for r in records if r.widget_section_found)
fetch_failed = sum(1 for r in records if "README_FETCH_FAILED" in r.flags)
avg_compression = round(
sum(r.compression_pct for r in records if r.raw_readme_length > 0) / max(n - fetch_failed, 1),
1,
)
avg_cleaned = round(
sum(r.cleaned_length for r in records if r.raw_readme_length > 0) / max(n - fetch_failed, 1),
)
top_flags = sorted(flag_counter.items(), key=lambda x: -x[1])[:10]
return {
"model_count": n,
"fetch_failed_count": fetch_failed,
"section_extraction_activated": activated,
"section_extraction_pct": round(activated / max(n - fetch_failed, 1) * 100, 1),
"basename_in_section": basename_hit,
"basename_in_section_pct": round(basename_hit / max(n - fetch_failed, 1) * 100, 1),
"with_yaml_frontmatter": with_yaml,
"with_yaml_frontmatter_pct": round(with_yaml / max(n - fetch_failed, 1) * 100, 1),
"with_widget_section": with_widget,
"avg_compression_pct": avg_compression,
"avg_cleaned_length": avg_cleaned,
"top_flags": top_flags,
}
def audit_records_to_serializable(records: List[AuditRecord]) -> List[Dict[str, Any]]:
"""Convert AuditRecord dataclasses to plain dicts for JSON serialization."""
return [asdict(r) for r in records]

View File

@@ -15,6 +15,7 @@ import os
from datetime import datetime
from typing import Any, Dict, List
from .config import SUPPORTED_BASE_MODELS
from .evaluation_engine import ScoreRecord
logger = logging.getLogger(__name__)
@@ -56,33 +57,12 @@ def generate_optimisation_suggestions(
for s in scores
if s["raw_values"]["base_model"]
and s["raw_values"]["base_model"] != "Unknown"
and s["raw_values"]["base_model"] not in {
"SD 1.4", "SD 1.5", "SD 1.5 LCM", "SD 1.5 Hyper",
"SD 2.0", "SD 2.1", "SD 3", "SD 3.5", "SD 3.5 Medium",
"SD 3.5 Large", "SD 3.5 Large Turbo",
"SDXL 1.0", "SDXL Lightning", "SDXL Hyper",
"Flux.1 D", "Flux.1 S", "Flux.1 Krea", "Flux.1 Kontext",
"Flux.2 D", "Flux.2 Klein 9B", "Flux.2 Klein 9B-base",
"Flux.2 Klein 4B", "Flux.2 Klein 4B-base",
"AuraFlow", "Chroma", "PixArt a", "PixArt E",
"Hunyuan 1", "Lumina", "Kolors",
"NoobAI", "Illustrious", "Pony", "Pony V7",
"HiDream", "Qwen", "ZImageTurbo", "ZImageBase",
"SVD", "LTXV", "LTXV2", "LTXV 2.3",
"CogVideoX", "Mochi",
"Wan Video", "Wan Video 1.3B t2v", "Wan Video 14B t2v",
"Wan Video 14B i2v 480p", "Wan Video 14B i2v 720p",
"Wan Video 2.2 TI2V-5B", "Wan Video 2.2 T2V-A14B",
"Wan Video 2.2 I2V-A14B",
"Wan Video 2.5 T2V", "Wan Video 2.5 I2V",
"Hunyuan Video", "Anima", "Ernie", "Ernie Turbo",
"Nucleus", "Krea 2",
}
and s["raw_values"]["base_model"] not in set(SUPPORTED_BASE_MODELS)
)
if bm_invalid > 5:
suggestions.append(
"- **base_model 含非标准值 ({} 个)**: LLM 输出了未在 `SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS` "
"中的 base model 名称。建议在 prompt 中强调 \"Use EXACTLY one name from the list\" 并在 "
"- **base_model 含非标准值 ({} 个)**: LLM 输出了未在当前生产系统的 base model 列表 "
"中的名称。建议在 prompt 中强调 \"Use EXACTLY one name from the list\" 并在 "
"`PostProcessor` 中加一层验证过滤,非标准值直接丢弃。".format(bm_invalid)
)
@@ -139,7 +119,8 @@ def generate_optimisation_suggestions(
if ut and ut.get("empty_rate_pct", 0) > 70:
suggestions.append(
"- **usage_tips 空置率极高 ({:.0f}%)**: 这是预期行为。HF 模型卡通常不包含 LoRA "
"强度/CLIP skip 等结构化参数。当前提取策略已合理。若需要可用数据," "可以考虑使用模型类型的通用默认值。".format(
"强度/CLIP skip 等结构化参数。当前提取策略已合理。若需要可用数据,"
"可以考虑使用模型类型的通用默认值。".format(
ut.get("empty_rate_pct", 0)
)
)
@@ -164,8 +145,20 @@ def generate_markdown_report(
scores: List[ScoreRecord],
output_dir: str,
duration_summary: Dict[str, Any] | None = None,
*,
audit_summary: Dict[str, Any] | None = None,
config_warnings: List[str] | None = None,
) -> str:
"""Write ``report.md`` and return its content."""
"""Write ``report.md`` and return its content.
Args:
agg: Aggregate evaluation scores.
scores: Per-model evaluation records.
output_dir: Output directory for the report file.
duration_summary: Optional timing statistics.
audit_summary: Optional preprocessing audit summary (Phase 1.5).
config_warnings: Optional LLM config consistency warnings.
"""
lines: List[str] = []
def wl(text: str = "") -> None:
lines.append(text)
@@ -178,6 +171,60 @@ def generate_markdown_report(
wl(f"Failures: **{agg.get('fail_count', 0)}**")
wl()
# ---- Preprocessing Audit Section ----
if audit_summary and audit_summary.get("model_count", 0) > 0:
wl("## Preprocessing Audit")
wl()
wl(f"| Metric | Value |")
wl(f"|--------|-------|")
wl(f"| Models audited | {audit_summary.get('model_count', 0)} |")
wl(f"| README fetch failed | {audit_summary.get('fetch_failed_count', 0)} |")
wl(f"| Section extraction activated | {_fmt_pct(audit_summary.get('section_extraction_pct', 0))} |")
wl(f"| Basename found in section | {_fmt_pct(audit_summary.get('basename_in_section_pct', 0))} |")
wl(f"| Has YAML frontmatter | {_fmt_pct(audit_summary.get('with_yaml_frontmatter_pct', 0))} |")
wl(f"| Has YAML widget section | {_fmt_pct(audit_summary.get('with_widget_section', 0))} |")
wl(f"| Avg README compression | {audit_summary.get('avg_compression_pct', 0)}% |")
wl(f"| Avg cleaned length | {audit_summary.get('avg_cleaned_length', 0)} chars |")
wl()
if audit_summary.get("top_flags"):
wl("### Audit Flags (most frequent)")
wl()
for flag, count in audit_summary["top_flags"]:
wl(f"- **{flag}**: {count}x")
wl()
wl("**Interpretation:**")
wl()
act_pct = audit_summary.get("section_extraction_pct", 0)
if act_pct < 50:
wl(
"- ⚠️ Section extraction activated for fewer than 50% of repos. "
"This may indicate the basename doesn't match README content, or the "
"repos are mostly single-model (where full README is expected)."
)
else:
wl(
"- ✅ Section extraction is working for most repos — the LLM is "
"receiving focused README sections."
)
if audit_summary.get("basename_in_section_pct", 100) < 80:
wl(
"- ⚠️ The safetensors basename was NOT found in the extracted section "
"for many repos. This could mean the section extraction matched the wrong "
"section, or the README doesn't explicitly reference the filename."
)
wl()
# ---- Config warnings ----
if config_warnings:
wl("## ⚠️ Configuration Warnings")
wl()
for w in config_warnings:
wl(f"- {w}")
wl()
# ---- Duration ----
if duration_summary:
wl("## Timing")
@@ -307,8 +354,21 @@ def save_json_report(
enrichment_results: List[Dict[str, Any]],
output_dir: str,
duration_summary: Dict[str, Any] | None = None,
*,
audit_summary: Dict[str, Any] | None = None,
config_warnings: List[str] | None = None,
) -> str:
"""Write ``report.json`` and return the path."""
"""Write ``report.json`` and return the path.
Args:
agg: Aggregate evaluation scores.
scores: Per-model evaluation records.
enrichment_results: Raw enrichment phase results.
output_dir: Output directory.
duration_summary: Optional timing statistics.
audit_summary: Optional preprocessing audit summary.
config_warnings: Optional LLM config consistency warnings.
"""
report: Dict[str, Any] = {
"metadata": {
"generated_at": datetime.now().isoformat(),
@@ -319,6 +379,11 @@ def save_json_report(
"per_model_scores": scores,
"enrichment_results": enrichment_results,
}
if audit_summary:
report["preprocessing_audit"] = audit_summary
if config_warnings:
report["config_warnings"] = config_warnings
path = os.path.join(output_dir, "report.json")
with open(path, "w", encoding="utf-8") as fh:
json.dump(report, fh, indent=2, ensure_ascii=False)

View File

@@ -3,7 +3,7 @@
Usage::
# Full run (100 models, serial, ~1-2 h)
# Full run (44 models, serial, ~1-2 h)
python -m tests.enrich_hf_validation.run_validation \\
--output /tmp/hf_enrich_validation
@@ -13,6 +13,9 @@ Usage::
# Resume from a previous partial run
python -m tests.enrich_hf_validation.run_validation --resume
# Audit preprocessing only (no LLM calls, fast)
python -m tests.enrich_hf_validation.run_validation --audit-only
# Custom settings file
python -m tests.enrich_hf_validation.run_validation \\
--settings /custom/path/settings.json
@@ -27,7 +30,7 @@ import logging
import os
import sys
import time
from typing import Any, Dict, List
from typing import Any, Dict, List, Tuple
# Ensure the project root is on sys.path so that ``from py import ...`` works.
_PROJECT_ROOT = os.path.normpath(
@@ -36,8 +39,18 @@ _PROJECT_ROOT = os.path.normpath(
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
from tests.enrich_hf_validation.config import load_settings
# Add ComfyUI root to sys.path so ``folder_paths`` can be imported.
# Project layout: ComfyUI/custom_nodes/ComfyUI-Lora-Manager/
_COMFYUI_ROOT = os.path.normpath(os.path.join(_PROJECT_ROOT, "..", ".."))
if _COMFYUI_ROOT not in sys.path:
sys.path.insert(0, _COMFYUI_ROOT)
from tests.enrich_hf_validation.config import (
init_supported_base_models,
load_settings,
)
from tests.enrich_hf_validation.metadata_constructor import (
RepoEntry,
create_all_initial_metadata,
load_repo_ids,
)
@@ -46,6 +59,10 @@ from tests.enrich_hf_validation.evaluation_engine import (
aggregate_scores,
evaluate_batch,
)
from tests.enrich_hf_validation.preprocessing_auditor import (
audit_records_to_serializable,
run_audit,
)
from tests.enrich_hf_validation.report_generator import (
generate_markdown_report,
save_json_report,
@@ -70,8 +87,8 @@ def _parse_args(argv: List[str]) -> argparse.Namespace:
)
parser.add_argument(
"--models",
default="~/Documents/hf_lora_models.txt",
help="Path to the HF repo ID list (one per line)",
default="~/Documents/hf_lora_models_with_safetensors.txt",
help="Path to the HF repo entries file (format: repo_id, model_name.safetensors per line)",
)
parser.add_argument(
"--settings",
@@ -99,6 +116,11 @@ def _parse_args(argv: List[str]) -> argparse.Namespace:
action="store_true",
help="Skip enrichment phase (evaluate existing metadata only)",
)
parser.add_argument(
"--audit-only",
action="store_true",
help="Run preprocessing audit only (no enrichment, no evaluation)",
)
parser.add_argument(
"--timeout",
type=int,
@@ -125,6 +147,117 @@ def _phase_header(label: str) -> None:
print(sep, file=sys.stderr)
# ---------------------------------------------------------------------------
# Read back LLM config after enrichment (for consistency reporting)
# ---------------------------------------------------------------------------
def _get_actual_llm_config() -> Dict[str, str]:
"""Read what LLMService is actually using, if initialized.
Only meaningful when called AFTER enrichment has started (i.e. after
``AgentService.get_instance()`` has been called).
"""
try:
from py.services.llm_service import LLMService
instance = LLMService._instance
if instance is None:
return {"status": "not initialized"}
cfg = instance._get_config()
return {
"provider": cfg.get("provider", ""),
"model": cfg.get("model", ""),
"api_base": cfg.get("api_base", ""),
}
except Exception as exc:
return {"status": f"error: {exc}"}
def _compare_llm_config(
pipeline_cfg: Dict[str, Any],
actual_cfg: Dict[str, str],
) -> List[str]:
"""Compare pipeline-loaded vs LLMService-used config.
Returns warning messages if they differ.
"""
warnings: List[str] = []
if not actual_cfg or actual_cfg.get("status", "") == "not initialized":
warnings.append(
"LLMService was not initialized during this run — cannot verify "
"config consistency."
)
return warnings
field_map = [
("llm_provider", "provider"),
("llm_model", "model"),
("llm_api_base", "api_base"),
]
for pipeline_key, llm_key in field_map:
pv = (pipeline_cfg.get(pipeline_key) or "").strip()
lv = (actual_cfg.get(llm_key) or "").strip()
if pv and lv and pv != lv:
warnings.append(
f"LLM config mismatch: --settings has '{pv}' for {pipeline_key}, "
f"but LLMService uses '{lv}'. "
f"The pipeline's --settings path ({pipeline_cfg.get('settings_path', '?')}) "
"may differ from where SettingsManager reads."
)
if not warnings and actual_cfg:
warnings.append(
"✅ LLM config matches between pipeline --settings and LLMService."
)
return warnings
# ---------------------------------------------------------------------------
# Phase 1.5: preprocessing audit
# ---------------------------------------------------------------------------
async def _run_preprocessing_audit(
entries: List[RepoEntry],
output_dir: str,
) -> Dict[str, Any]:
"""Execute the preprocessing audit and save results."""
_phase_header("Preprocessing audit")
print(f" Auditing {len(entries)} repos ...", file=sys.stderr)
readmes_dir = os.path.join(output_dir, "readmes")
t0 = time.perf_counter()
records, summary = await run_audit(entries, readmes_dir=readmes_dir)
elapsed = time.perf_counter() - t0
# Save audit data
audit_path = os.path.join(output_dir, "preprocessing_audit.json")
with open(audit_path, "w", encoding="utf-8") as fh:
json.dump(
{
"summary": summary,
"records": audit_records_to_serializable(records),
},
fh,
indent=2,
ensure_ascii=False,
)
print(f" Audit complete: {len(records)} repos in {elapsed:.0f}s", file=sys.stderr)
print(f" Section extraction activated: {summary.get('section_extraction_pct', 0)}%", file=sys.stderr)
print(f" Basename in extracted section: {summary.get('basename_in_section_pct', 0)}%", file=sys.stderr)
print(f" Avg compression: {summary.get('avg_compression_pct', 0)}%", file=sys.stderr)
print(f" Avg cleaned length: {summary.get('avg_cleaned_length', 0)} chars", file=sys.stderr)
print(f" Audit data: {audit_path}", file=sys.stderr)
if summary.get("top_flags"):
print(" Top flags:", file=sys.stderr)
for flag, count in summary["top_flags"][:5]:
print(f" - {flag}: {count}x", file=sys.stderr)
return summary
async def _run_enrichment(
model_paths: List[str],
repos: List[str],
@@ -170,7 +303,7 @@ def _collect_enriched_metadata(
errors, metadata.
"""
enriched: List[Dict[str, Any]] = []
# Build a lookup from repo_id enrichment result
# Build a lookup from repo_id to enrichment result
result_lookup: Dict[str, Dict[str, Any]] = {}
for r in results:
result_lookup[r["repo_id"]] = r
@@ -214,29 +347,43 @@ async def main(argv: List[str]) -> int:
output_dir = os.path.abspath(os.path.expanduser(args.output))
os.makedirs(output_dir, exist_ok=True)
# ---- Phase 0: Initialise shared state ----
_phase_header("Initialise")
settings = load_settings(args.settings)
logger.info(
"LLM config: provider=%s model=%s api_base=%s",
"LLM config from --settings: provider=%s model=%s api_base=%s",
settings["llm_provider"],
settings["llm_model"],
settings["llm_api_base"],
)
# Load the production base model list (replaces the old hardcoded list)
await init_supported_base_models()
# ---- Phase 1: Load repo IDs & construct initial metadata ----
_phase_header("Load repo IDs & construct initial metadata")
repos = load_repo_ids(args.models, max_models=args.sample if args.sample > 0 else None)
model_paths = create_all_initial_metadata(
repos, output_dir, skip_existing=True,
# ---- Load entries ----
_phase_header("Load repo entries & construct initial metadata")
entries = load_repo_ids(args.models, max_models=args.sample if args.sample > 0 else None)
model_paths, repo_ids = create_all_initial_metadata(
entries, output_dir, skip_existing=True,
)
print(f" {len(model_paths)} repos ready", file=sys.stderr)
# ---- Phase 1.5: Preprocessing audit ----
audit_summary: Dict[str, Any] = {}
t_start = time.perf_counter()
audit_summary = await _run_preprocessing_audit(entries, output_dir)
if args.audit_only:
total_wall = time.perf_counter() - t_start
print(f"\n Audit-only done in {total_wall:.0f}s", file=sys.stderr)
print(f" Audit data: {output_dir}/preprocessing_audit.json", file=sys.stderr)
return 0
# ---- Phase 2: Enrichment ----
enrichment_results: List[Dict[str, Any]] = []
t_start = time.perf_counter()
if not args.no_enrich:
_phase_header("Enrich metadata via LLM")
enrichment_out = await _run_enrichment(
model_paths, repos, output_dir, args.timeout, args.verbose,
model_paths, repo_ids, output_dir, args.timeout, args.verbose,
)
enrichment_results = enrichment_out["results"]
else:
@@ -246,7 +393,7 @@ async def main(argv: List[str]) -> int:
# ---- Phase 3: Evaluation ----
_phase_header("Evaluate enriched metadata")
enriched = _collect_enriched_metadata(model_paths, repos, enrichment_results)
enriched = _collect_enriched_metadata(model_paths, repo_ids, enrichment_results)
scores = evaluate_batch(enriched)
agg = aggregate_scores(scores)
print(
@@ -274,8 +421,18 @@ async def main(argv: List[str]) -> int:
"max_s": round(max(durations), 1),
}
save_json_report(agg, scores, enrichment_results, output_dir, duration_summary)
generate_markdown_report(agg, scores, output_dir, duration_summary)
# Check LLM config consistency after enrichment (LLMService is now initialized)
actual_llm_cfg = _get_actual_llm_config()
config_warnings = _compare_llm_config(settings, actual_llm_cfg)
save_json_report(
agg, scores, enrichment_results, output_dir, duration_summary,
audit_summary=audit_summary, config_warnings=config_warnings,
)
generate_markdown_report(
agg, scores, output_dir, duration_summary,
audit_summary=audit_summary, config_warnings=config_warnings,
)
# ---- Final summary ----
total_wall = time.perf_counter() - t_start