mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-07-05 17:01:16 -03:00
feat(agent): fix extract_relevant_section false positives, add validation pipeline audit
- extract_relevant_section: raise token threshold >3, verify anchor sections contain basename, require 2+ heading token overlaps, skip TOC-style headings (markdown links), verify heading section size - metadata_constructor: parse repo_id,model_name.safetensors format so model_path basename matches real filename - config: replace hardcoded SUPPORTED_BASE_MODELS with dynamic init_supported_base_models() using production list_base_models() - preprocessing_auditor: new Phase 1.5 audit module — fetches each README, runs extract_relevant_section + clean_readme_for_llm, records stats and flags, saves raw READMEs for cross-reference - run_validation: integrate audit phase, add --audit-only mode, add LLM config consistency check, add ComfyUI root to sys.path - report_generator: add Preprocessing Audit and Config Warnings sections to both markdown and JSON reports
This commit is contained in:
@@ -675,8 +675,10 @@ def extract_relevant_section(
|
||||
lines = readme_content.split("\n")
|
||||
n = len(lines)
|
||||
basename_lower = model_basename.lower()
|
||||
# Tokens from the basename split on common separators
|
||||
tokens = {t for t in re.split(r"[_\-.\s]+", basename_lower) if len(t) > 2}
|
||||
# Tokens from the basename split on common separators.
|
||||
# Exclude tokens of length ≤ 3 — 2-3 char tokens (e.g. "cry", "myjs")
|
||||
# are too short to discriminate between different models in collection repos.
|
||||
tokens = {t for t in re.split(r"[_\-.\s]+", basename_lower) if len(t) > 3}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Strategy 1: Find a download link containing the basename
|
||||
@@ -695,7 +697,13 @@ def extract_relevant_section(
|
||||
if m:
|
||||
aid = m.group(1).lower()
|
||||
if any(token in aid for token in tokens):
|
||||
return _extract_section(lines, idx, context_lines)
|
||||
section = _extract_section(lines, idx, context_lines)
|
||||
# Verify the extracted section actually mentions the model —
|
||||
# short anchor IDs can coincidentally match tokens from
|
||||
# unrelated models (e.g. "myjs" matching a different LoRA).
|
||||
if basename_lower in section.lower():
|
||||
return section
|
||||
# False positive — continue searching
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Strategy 3: Find an HTML or markdown heading with overlapping tokens
|
||||
@@ -712,10 +720,28 @@ def extract_relevant_section(
|
||||
if mm:
|
||||
heading_text = mm.group(1)
|
||||
if heading_text:
|
||||
# Skip TOC-style entries where the heading text is a markdown
|
||||
# link or bullet list item (e.g. "### - [model_name](url)").
|
||||
# These are table-of-contents entries, not real section headers.
|
||||
stripped = heading_text.strip()
|
||||
if stripped.startswith("- [") or re.match(r"^\[.+?\]\(.+?\)", stripped):
|
||||
continue
|
||||
|
||||
heading_lower = heading_text.lower()
|
||||
# Check if any token appears in the heading
|
||||
if any(token in heading_lower for token in tokens):
|
||||
return _extract_section(lines, idx, context_lines)
|
||||
# Require at least 2 token overlaps, or the full basename as a
|
||||
# substring of the heading. A single 4-5 char token match is
|
||||
# too weak — e.g. "devil" matching "dante_devil_may_cry" when
|
||||
# the actual model is "vergil_devil_may_cry", or "image" matching
|
||||
# a table-of-contents heading.
|
||||
matching = [t for t in tokens if t in heading_lower]
|
||||
if len(matching) >= 2 or basename_lower in heading_lower:
|
||||
section = _extract_section(lines, idx, context_lines)
|
||||
# Verify the section contains the model name — headings in
|
||||
# TOC areas can match tokens but produce a tiny irrelevant
|
||||
# section (e.g. "### - [z_image_turbo](url)" matched by
|
||||
# tokens "lora" and "turbo").
|
||||
if basename_lower in section.lower() or len(section) > max(500, context_lines * 20):
|
||||
return section
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Fallback: return FULL readme
|
||||
|
||||
@@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DEFAULT_MODELS_FILE = os.path.expanduser(
|
||||
"~/Documents/hf_lora_models.txt"
|
||||
"~/Documents/hf_lora_models_with_safetensors.txt"
|
||||
)
|
||||
_DEFAULT_SETTINGS_PATH = os.path.expanduser(
|
||||
"~/.config/ComfyUI-LoRA-Manager/settings.json"
|
||||
@@ -36,8 +36,16 @@ CIVITAI_MODEL_TAGS: List[str] = [
|
||||
"buildings", "objects", "assets", "animal", "action",
|
||||
]
|
||||
|
||||
# Base models recognised as valid values.
|
||||
SUPPORTED_BASE_MODELS: List[str] = [
|
||||
# ---------------------------------------------------------------------------
|
||||
# Base model resolution — dynamically fetched from production code
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Module-level cache — populated by init_supported_base_models().
|
||||
# Falls back to a comprehensive hardcoded list when the live fetch fails.
|
||||
SUPPORTED_BASE_MODELS: List[str] = []
|
||||
|
||||
# Fallback base models when the production list_base_models() is unavailable.
|
||||
_FALLBACK_BASE_MODELS: List[str] = [
|
||||
"SD 1.4", "SD 1.5", "SD 1.5 LCM", "SD 1.5 Hyper",
|
||||
"SD 2.0", "SD 2.1",
|
||||
"SD 3", "SD 3.5", "SD 3.5 Medium", "SD 3.5 Large", "SD 3.5 Large Turbo",
|
||||
@@ -60,6 +68,33 @@ SUPPORTED_BASE_MODELS: List[str] = [
|
||||
"Nucleus", "Krea 2",
|
||||
]
|
||||
|
||||
|
||||
async def init_supported_base_models() -> None:
|
||||
"""Populate ``SUPPORTED_BASE_MODELS`` from the production codebase.
|
||||
|
||||
Calls ``py.agent_cli.list_base_models()`` which merges a hardcoded
|
||||
fallback with models fetched from the CivitAI API. When the call
|
||||
fails (e.g. offline, API error), falls back to ``_FALLBACK_BASE_MODELS``.
|
||||
|
||||
Must be called from within an async event loop (i.e. during
|
||||
``run_validation.main()``, not at module level).
|
||||
"""
|
||||
try:
|
||||
from py.agent_cli import list_base_models
|
||||
|
||||
models = await list_base_models()
|
||||
if models:
|
||||
SUPPORTED_BASE_MODELS[:] = models
|
||||
logger.info("Loaded %d base models from production code", len(models))
|
||||
return
|
||||
logger.warning("list_base_models returned empty list, using fallback")
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to load base models from production: %s", exc)
|
||||
|
||||
SUPPORTED_BASE_MODELS[:] = _FALLBACK_BASE_MODELS
|
||||
logger.info("Using fallback base model list (%d entries)", len(SUPPORTED_BASE_MODELS))
|
||||
|
||||
|
||||
# Placeholder values the LLM sometimes emits that should count as "empty".
|
||||
PLACEHOLDER_VALUES = frozenset({
|
||||
"none", "null", "n/a", "unknown", "not available",
|
||||
@@ -71,6 +106,7 @@ PLACEHOLDER_VALUES = frozenset({
|
||||
# User settings loader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_settings(settings_path: str) -> Dict[str, Any]:
|
||||
"""Load LoRA Manager settings from *settings_path*.
|
||||
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
"""Construct initial ``.metadata.json`` sidecars for HF model repos.
|
||||
|
||||
Each HF repo ID gets a minimal metadata file — no real model file is needed.
|
||||
The enrichment pipeline reads only the sidecar.
|
||||
Each HF repo + safetensors pair gets a minimal metadata file — no real model
|
||||
file is needed. The enrichment pipeline reads only the sidecar.
|
||||
|
||||
Data format (one line per entry)::
|
||||
|
||||
repo_id, model_name.safetensors
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -9,32 +13,64 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, List
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from .config import CIVITAI_MODEL_TAGS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_repo_ids(path: str, max_models: int | None = None) -> List[str]:
|
||||
"""Read HF repo IDs from *path* (one per line, ignoring blanks/comments)."""
|
||||
# A validated entry parsed from the models file:
|
||||
# (repo_id, safetensors_name)
|
||||
RepoEntry = Tuple[str, str]
|
||||
|
||||
|
||||
def load_repo_ids(path: str, max_models: int | None = None) -> List[RepoEntry]:
|
||||
"""Read ``repo_id, safetensors_name`` pairs from *path*.
|
||||
|
||||
Format (one per line, blanks and ``#`` comments ignored)::
|
||||
|
||||
user/repo-name, lora_zimage_turbo_myjs_alpha01.safetensors
|
||||
|
||||
Returns a list of ``(repo_id, safetensors_name)`` tuples.
|
||||
"""
|
||||
path = os.path.expanduser(path)
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError(f"Models file not found: {path}")
|
||||
|
||||
repos: List[str] = []
|
||||
entries: List[RepoEntry] = []
|
||||
with open(path, "r", encoding="utf-8") as fh:
|
||||
for line in fh:
|
||||
line = line.strip()
|
||||
for raw_line in fh:
|
||||
line = raw_line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
repos.append(line)
|
||||
|
||||
# Split on the first comma
|
||||
if "," not in line:
|
||||
logger.warning("Skipping malformed line (no comma): %s", raw_line.rstrip())
|
||||
continue
|
||||
|
||||
repo_id, safetensors_name = [part.strip() for part in line.split(",", 1)]
|
||||
if not repo_id or not safetensors_name:
|
||||
logger.warning("Skipping malformed line (empty fields): %s", raw_line.rstrip())
|
||||
continue
|
||||
if not safetensors_name.lower().endswith(".safetensors"):
|
||||
logger.warning(
|
||||
"Skipping line — safetensors_name doesn't end with .safetensors: %s",
|
||||
raw_line.rstrip(),
|
||||
)
|
||||
continue
|
||||
|
||||
entries.append((repo_id, safetensors_name))
|
||||
|
||||
if max_models is not None and max_models > 0:
|
||||
repos = repos[:max_models]
|
||||
entries = entries[:max_models]
|
||||
|
||||
logger.info("Loaded %d HF repo IDs from %s", len(repos), path)
|
||||
return repos
|
||||
logger.info("Loaded %d HF repo entries from %s", len(entries), path)
|
||||
return entries
|
||||
|
||||
|
||||
def sanitize_repo_id(repo_id: str) -> str:
|
||||
@@ -47,9 +83,9 @@ def build_model_dir(output_dir: str, repo_id: str) -> str:
|
||||
return os.path.join(output_dir, "models", sanitize_repo_id(repo_id))
|
||||
|
||||
|
||||
def build_model_path(model_dir: str) -> str:
|
||||
"""Return a synthetic model file path (no real file will exist)."""
|
||||
return os.path.join(model_dir, "model.safetensors")
|
||||
def build_model_path(model_dir: str, safetensors_name: str) -> str:
|
||||
"""Return the model file path using the real safetensors filename."""
|
||||
return os.path.join(model_dir, safetensors_name)
|
||||
|
||||
|
||||
def build_metadata_path(model_path: str) -> str:
|
||||
@@ -58,8 +94,8 @@ def build_metadata_path(model_path: str) -> str:
|
||||
This MUST match the convention used by ``MetadataManager`` /
|
||||
``apply_metadata_updates``, which derives the sidecar path via
|
||||
``os.path.splitext(model_path)[0] + '.metadata.json'``.
|
||||
For a model file ``model.safetensors`` the sidecar is
|
||||
``model.metadata.json`` — *not* ``model.safetensors.metadata.json``.
|
||||
For a model file ``lora_x.safetensors`` the sidecar is
|
||||
``lora_x.metadata.json`` — *not* ``lora_x.safetensors.metadata.json``.
|
||||
"""
|
||||
return f"{os.path.splitext(model_path)[0]}.metadata.json"
|
||||
|
||||
@@ -67,23 +103,32 @@ def build_metadata_path(model_path: str) -> str:
|
||||
def create_initial_metadata(
|
||||
output_dir: str,
|
||||
repo_id: str,
|
||||
safetensors_name: str,
|
||||
) -> str:
|
||||
"""Write a minimal ``.metadata.json`` for *repo_id*.
|
||||
"""Write a minimal ``.metadata.json`` for *repo_id* + *safetensors_name*.
|
||||
|
||||
Args:
|
||||
output_dir: Root output directory.
|
||||
repo_id: HuggingFace repo identifier (``user/repo``).
|
||||
safetensors_name: The specific model file name (e.g.
|
||||
``lora_zimage_turbo_myjs_alpha01.safetensors``).
|
||||
|
||||
Returns the **model path** (the ``.safetensors`` path whose sidecar was
|
||||
written). The caller passes this path to ``AgentService.execute_skill``.
|
||||
The basename (filename without extension) will match the real model file,
|
||||
so ``extract_relevant_section`` can reliably match against the README.
|
||||
"""
|
||||
model_dir = build_model_dir(output_dir, repo_id)
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
model_path = build_model_path(model_dir)
|
||||
model_path = build_model_path(model_dir, safetensors_name)
|
||||
metadata_path = build_metadata_path(model_path)
|
||||
|
||||
hf_url = f"https://huggingface.co/{repo_id}"
|
||||
file_name = repo_id.split("/")[-1]
|
||||
file_name = safetensors_name
|
||||
|
||||
metadata: Dict = {
|
||||
metadata: Dict[str, Any] = {
|
||||
"file_name": file_name,
|
||||
"model_name": file_name,
|
||||
"model_name": safetensors_name,
|
||||
"file_path": model_path.replace(os.sep, "/"),
|
||||
"size": 0,
|
||||
"modified": 0,
|
||||
@@ -117,32 +162,41 @@ def create_initial_metadata(
|
||||
|
||||
|
||||
def create_all_initial_metadata(
|
||||
repos: List[str],
|
||||
entries: List[RepoEntry],
|
||||
output_dir: str,
|
||||
*,
|
||||
skip_existing: bool = True,
|
||||
) -> List[str]:
|
||||
"""Create initial metadata for every repo in *repos*.
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""Create initial metadata for every repo entry.
|
||||
|
||||
Returns a list of model paths in the same order as *repos*.
|
||||
``skip_existing=True`` skips repos whose metadata already exists,
|
||||
allowing safe re-run.
|
||||
Args:
|
||||
entries: List of ``(repo_id, safetensors_name)`` tuples.
|
||||
output_dir: Root output directory.
|
||||
skip_existing: If True, skip repos whose metadata already exists.
|
||||
|
||||
Returns:
|
||||
A tuple ``(model_paths, repo_ids)`` — two parallel lists in the same
|
||||
order as *entries*. This keeps downstream code (enrichment runner,
|
||||
evaluation engine) unchanged.
|
||||
"""
|
||||
model_paths: List[str] = []
|
||||
for repo_id in repos:
|
||||
repo_ids: List[str] = []
|
||||
for repo_id, safetensors_name in entries:
|
||||
model_dir = build_model_dir(output_dir, repo_id)
|
||||
model_path = build_model_path(model_dir)
|
||||
model_path = build_model_path(model_dir, safetensors_name)
|
||||
metadata_path = build_metadata_path(model_path)
|
||||
|
||||
if skip_existing and os.path.exists(metadata_path):
|
||||
model_paths.append(model_path)
|
||||
repo_ids.append(repo_id)
|
||||
continue
|
||||
|
||||
model_paths.append(create_initial_metadata(output_dir, repo_id))
|
||||
model_paths.append(create_initial_metadata(output_dir, repo_id, safetensors_name))
|
||||
repo_ids.append(repo_id)
|
||||
|
||||
logger.info(
|
||||
"Constructed initial metadata for %d/%d repos",
|
||||
len(model_paths),
|
||||
len(repos),
|
||||
len(entries),
|
||||
)
|
||||
return model_paths
|
||||
return model_paths, repo_ids
|
||||
|
||||
467
tests/enrich_hf_validation/preprocessing_auditor.py
Normal file
467
tests/enrich_hf_validation/preprocessing_auditor.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""Preprocessing audit for the HF metadata enrichment validation pipeline.
|
||||
|
||||
Phase 1.5 — runs between Phase 1 (metadata creation) and Phase 2 (enrichment).
|
||||
|
||||
Audits the README preprocessing pipeline (section extraction + cleaning)
|
||||
for each repo in the dataset, capturing intermediate outputs so we can
|
||||
distinguish between:
|
||||
|
||||
(A) Preprocessing failed → LLM never saw the right content
|
||||
(B) Preprocessing succeeded → LLM/prompt needs improvement
|
||||
|
||||
This prevents wasted effort optimizing prompts when the actual problem is
|
||||
that ``extract_relevant_section`` or ``clean_readme_for_llm`` removed or
|
||||
misaligned the content the LLM needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audit record
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuditRecord:
|
||||
"""Preprocessing audit for a single repo entry."""
|
||||
|
||||
# Identity
|
||||
repo_id: str
|
||||
safetensors_name: str
|
||||
basename: str # filename without .safetensors
|
||||
|
||||
# Raw README stats
|
||||
raw_readme_length: int
|
||||
raw_readme_line_count: int
|
||||
has_yaml_frontmatter: bool
|
||||
yaml_has_base_model: bool
|
||||
yaml_has_tags: bool
|
||||
|
||||
# Section extraction
|
||||
section_extraction_activated: bool # output < 95% of input length
|
||||
section_length: int
|
||||
section_line_count: int
|
||||
basename_in_section: bool # basename appears in extracted section text
|
||||
|
||||
# Cleaning
|
||||
cleaned_length: int
|
||||
cleaned_line_count: int
|
||||
compression_pct: float # (1 - cleaned/raw) * 100
|
||||
|
||||
# Widget section (stripped by _strip_widget_section)
|
||||
widget_section_found: bool
|
||||
widget_section_length: int
|
||||
|
||||
# Flags (list of anomaly descriptions)
|
||||
flags: List[str] = field(default_factory=list)
|
||||
|
||||
# Local file path to the saved raw README (for cross-reference)
|
||||
readme_file: str = ""
|
||||
|
||||
# Staged intermediate output for report detail
|
||||
raw_readme_preview: str = "" # first 200 chars
|
||||
section_preview: str = "" # first 300 chars
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_HF_RAW_URL = "https://huggingface.co/{repo_id}/raw/main/README.md"
|
||||
|
||||
# Thresholds for flagging
|
||||
_SECTION_ACTIVATION_RATIO = 0.95
|
||||
_MIN_CLEANED_LENGTH = 100
|
||||
_MAX_COMPRESSION_PCT = 99.0
|
||||
_MIN_SECTION_LINES = 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module loader — bypasses parent-package __init__ that imports ComfyUI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_readme_processor_module = None
|
||||
|
||||
|
||||
def _load_readme_processor():
|
||||
"""Import ``readme_processor`` without triggering ``folder_paths`` import.
|
||||
|
||||
The normal import path (``py.services.agent.skills.enrich_hf_metadata.
|
||||
readme_processor``) triggers ``py.services.agent.__init__`` which
|
||||
imports ``agent_service.py`` → ``py/config.py`` → ComfyUI's
|
||||
``folder_paths``, which is not available in standalone mode.
|
||||
"""
|
||||
global _readme_processor_module
|
||||
if _readme_processor_module is not None:
|
||||
return _readme_processor_module
|
||||
|
||||
import importlib.util
|
||||
|
||||
_RP_PATH = os.path.join(
|
||||
os.path.dirname(__file__), # tests/enrich_hf_validation/
|
||||
"..", "..",
|
||||
"py", "services", "agent", "skills", "enrich_hf_metadata",
|
||||
"readme_processor.py",
|
||||
)
|
||||
rp_path = os.path.normpath(_RP_PATH)
|
||||
if not os.path.exists(rp_path):
|
||||
logger.error("readme_processor.py not found at %s", rp_path)
|
||||
return None
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"readme_processor", rp_path,
|
||||
)
|
||||
if spec is None or spec.loader is None:
|
||||
logger.error("Could not create spec for readme_processor.py")
|
||||
return None
|
||||
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
try:
|
||||
spec.loader.exec_module(mod)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to load readme_processor.py: %s", exc)
|
||||
return None
|
||||
|
||||
_readme_processor_module = mod
|
||||
return mod
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HF README fetcher
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _fetch_readme(repo_id: str, session: aiohttp.ClientSession) -> str:
|
||||
"""Fetch the raw README.md from HuggingFace."""
|
||||
url = _HF_RAW_URL.format(repo_id=repo_id)
|
||||
try:
|
||||
async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status == 200:
|
||||
return await resp.text()
|
||||
logger.warning("Failed to fetch README for %s: HTTP %d", repo_id, resp.status)
|
||||
return ""
|
||||
except (asyncio.TimeoutError, aiohttp.ClientError) as exc:
|
||||
logger.warning("Failed to fetch README for %s: %s", repo_id, exc)
|
||||
return ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Analysis helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _has_yaml_frontmatter(text: str) -> bool:
|
||||
return bool(text.strip().startswith("---"))
|
||||
|
||||
|
||||
def _extract_yaml_field(text: str, field: str) -> bool:
|
||||
"""Check if the given YAML field exists in the frontmatter."""
|
||||
lines = text.split("\n")
|
||||
if not lines or not lines[0].strip().startswith("---"):
|
||||
return False
|
||||
end = 1
|
||||
while end < len(lines):
|
||||
if lines[end].strip().startswith("---"):
|
||||
break
|
||||
end += 1
|
||||
if end >= len(lines):
|
||||
return False
|
||||
frontmatter = "\n".join(lines[1:end])
|
||||
pattern = rf"^{field}:"
|
||||
return bool(re.search(pattern, frontmatter, re.MULTILINE))
|
||||
|
||||
|
||||
def _find_widget_section_length(text: str) -> int:
|
||||
"""Find the ``widget:`` YAML section and return its length (0 if none)."""
|
||||
if not _has_yaml_frontmatter(text):
|
||||
return 0
|
||||
frontmatter_end = text.find("---", 3)
|
||||
if frontmatter_end == -1:
|
||||
return 0
|
||||
frontmatter = text[3:frontmatter_end]
|
||||
|
||||
# Match widget: through to the next top-level key or frontmatter end
|
||||
m = re.search(r"\nwidget:", frontmatter)
|
||||
if not m:
|
||||
return 0
|
||||
# Length from widget: to end of frontmatter (the next \n\w+: or \n---)
|
||||
return len(frontmatter[m.start():])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core auditor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_audit(
|
||||
entries: List[Tuple[str, str]],
|
||||
*,
|
||||
concurrency: int = 10,
|
||||
readmes_dir: str | None = None,
|
||||
) -> Tuple[List[AuditRecord], Dict[str, Any]]:
|
||||
"""Run the preprocessing audit over all repo entries.
|
||||
|
||||
Args:
|
||||
entries: List of ``(repo_id, safetensors_name)``.
|
||||
concurrency: Max parallel fetches to HuggingFace.
|
||||
readmes_dir: If set, saves each fetched README as
|
||||
``{sanitized_repo_id}.md`` in this directory for offline
|
||||
cross-reference against audit results.
|
||||
|
||||
Returns:
|
||||
Tuple of ``(records, summary)`` where *summary* is a dict with
|
||||
aggregate statistics.
|
||||
"""
|
||||
semaphore = asyncio.Semaphore(concurrency)
|
||||
records: List[AuditRecord] = []
|
||||
flag_counter: Dict[str, int] = {}
|
||||
|
||||
if readmes_dir:
|
||||
os.makedirs(readmes_dir, exist_ok=True)
|
||||
|
||||
connector = aiohttp.TCPConnector(limit=concurrency)
|
||||
async with aiohttp.ClientSession(connector=connector) as session:
|
||||
tasks = [_audit_one(entry, session, semaphore, readmes_dir=readmes_dir) for entry in entries]
|
||||
gathered = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for entry, result in zip(entries, gathered):
|
||||
if isinstance(result, Exception):
|
||||
logger.error("Audit failed for %s: %s", entry[0], result)
|
||||
records.append(
|
||||
AuditRecord(
|
||||
repo_id=entry[0],
|
||||
safetensors_name=entry[1],
|
||||
basename=os.path.splitext(entry[1])[0],
|
||||
raw_readme_length=0,
|
||||
raw_readme_line_count=0,
|
||||
has_yaml_frontmatter=False,
|
||||
yaml_has_base_model=False,
|
||||
yaml_has_tags=False,
|
||||
section_extraction_activated=False,
|
||||
section_length=0,
|
||||
section_line_count=0,
|
||||
basename_in_section=False,
|
||||
cleaned_length=0,
|
||||
cleaned_line_count=0,
|
||||
compression_pct=0.0,
|
||||
widget_section_found=False,
|
||||
widget_section_length=0,
|
||||
readme_file="",
|
||||
flags=[f"Audit exception: {result}"],
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
# The continue above ensures result is AuditRecord here
|
||||
assert isinstance(result, AuditRecord)
|
||||
records.append(result)
|
||||
for flag in result.flags:
|
||||
flag_counter[flag] = flag_counter.get(flag, 0) + 1
|
||||
|
||||
summary = _build_summary(records, flag_counter)
|
||||
return records, summary
|
||||
|
||||
|
||||
def _sanitize_repo_id(repo_id: str) -> str:
|
||||
"""Turn ``user/repo-name`` into a safe filename."""
|
||||
return repo_id.replace("/", "__").replace(".", "_")
|
||||
|
||||
|
||||
async def _audit_one(
|
||||
entry: Tuple[str, str],
|
||||
session: aiohttp.ClientSession,
|
||||
semaphore: asyncio.Semaphore,
|
||||
*,
|
||||
readmes_dir: str | None = None,
|
||||
) -> AuditRecord:
|
||||
"""Audit a single repo entry."""
|
||||
repo_id, safetensors_name = entry
|
||||
basename = os.path.splitext(safetensors_name)[0]
|
||||
|
||||
async with semaphore:
|
||||
# Import production preprocessing functions.
|
||||
# Use importlib to bypass py.services.agent.__init__ which triggers
|
||||
# ComfyUI's folder_paths module (not available in standalone mode).
|
||||
_rp = _load_readme_processor()
|
||||
if _rp is None:
|
||||
return AuditRecord(
|
||||
repo_id=repo_id,
|
||||
safetensors_name=safetensors_name,
|
||||
basename=basename,
|
||||
raw_readme_length=0, raw_readme_line_count=0,
|
||||
has_yaml_frontmatter=False, yaml_has_base_model=False, yaml_has_tags=False,
|
||||
readme_file="",
|
||||
section_extraction_activated=False, section_length=0, section_line_count=0,
|
||||
basename_in_section=False, cleaned_length=0, cleaned_line_count=0,
|
||||
compression_pct=0.0, widget_section_found=False, widget_section_length=0,
|
||||
flags=["IMPORT_FAILED"],
|
||||
)
|
||||
clean_readme_for_llm = _rp.clean_readme_for_llm
|
||||
extract_relevant_section = _rp.extract_relevant_section
|
||||
|
||||
# Step 1: Fetch the raw README
|
||||
raw_text = await _fetch_readme(repo_id, session)
|
||||
if not raw_text:
|
||||
return AuditRecord(
|
||||
repo_id=repo_id,
|
||||
safetensors_name=safetensors_name,
|
||||
basename=basename,
|
||||
raw_readme_length=0,
|
||||
raw_readme_line_count=0,
|
||||
has_yaml_frontmatter=False,
|
||||
yaml_has_base_model=False,
|
||||
yaml_has_tags=False,
|
||||
section_extraction_activated=False,
|
||||
section_length=0,
|
||||
section_line_count=0,
|
||||
basename_in_section=False,
|
||||
readme_file="",
|
||||
cleaned_length=0,
|
||||
cleaned_line_count=0,
|
||||
compression_pct=0.0,
|
||||
widget_section_found=False,
|
||||
widget_section_length=0,
|
||||
flags=["README_FETCH_FAILED"],
|
||||
)
|
||||
|
||||
# Save the raw README to disk for offline cross-reference
|
||||
readme_path = ""
|
||||
if readmes_dir:
|
||||
safe_name = _sanitize_repo_id(repo_id)
|
||||
readme_path = os.path.join(readmes_dir, f"{safe_name}.md")
|
||||
try:
|
||||
with open(readme_path, "w", encoding="utf-8") as fh:
|
||||
fh.write(raw_text)
|
||||
except OSError as exc:
|
||||
logger.warning("Failed to save README for %s: %s", repo_id, exc)
|
||||
readme_path = ""
|
||||
|
||||
raw_lines = raw_text.split("\n")
|
||||
raw_len = len(raw_text)
|
||||
raw_line_count = len(raw_lines)
|
||||
|
||||
# Step 2: Analyze raw README
|
||||
yaml_fm = _has_yaml_frontmatter(raw_text)
|
||||
yaml_has_bm = _extract_yaml_field(raw_text, "base_model") if yaml_fm else False
|
||||
yaml_has_tg = _extract_yaml_field(raw_text, "tags") if yaml_fm else False
|
||||
widget_len = _find_widget_section_length(raw_text)
|
||||
|
||||
# Step 3: Section extraction
|
||||
section = extract_relevant_section(raw_text, basename)
|
||||
section_len = len(section)
|
||||
section_line_count = len(section.split("\n"))
|
||||
section_activated = section_len < raw_len * _SECTION_ACTIVATION_RATIO
|
||||
basename_in_sec = basename.lower() in section.lower()
|
||||
|
||||
# Step 4: Cleaning for LLM
|
||||
cleaned = clean_readme_for_llm(section)
|
||||
cleaned_len = len(cleaned)
|
||||
cleaned_line_count = len(cleaned.split("\n"))
|
||||
compression_pct = round((1 - cleaned_len / raw_len) * 100, 1) if raw_len else 0.0
|
||||
|
||||
# Step 5: Flag anomalies
|
||||
flags: List[str] = []
|
||||
if not raw_text.strip():
|
||||
flags.append("README_EMPTY")
|
||||
if not yaml_fm:
|
||||
flags.append("NO_YAML_FRONTMATTER")
|
||||
if not section_activated:
|
||||
# Check if basename is extremely short/generic (likely synthetic)
|
||||
if len(basename) <= 5:
|
||||
flags.append("BASENAME_TOO_SHORT_SECTION_NOT_EXPECTED")
|
||||
else:
|
||||
flags.append("SECTION_EXTRACTION_NOT_ACTIVATED")
|
||||
elif not basename_in_sec:
|
||||
flags.append("BASENAME_NOT_IN_EXTRACTED_SECTION")
|
||||
if widget_len == 0:
|
||||
# Not necessarily a problem — many repos lack a widget section
|
||||
pass
|
||||
if cleaned_len < _MIN_CLEANED_LENGTH:
|
||||
flags.append("CLEANED_README_TOO_SHORT")
|
||||
if compression_pct > _MAX_COMPRESSION_PCT:
|
||||
flags.append("EXTREME_COMPRESSION")
|
||||
if section_activated and section_line_count < _MIN_SECTION_LINES:
|
||||
flags.append("SECTION_TOO_SMALL")
|
||||
|
||||
return AuditRecord(
|
||||
repo_id=repo_id,
|
||||
safetensors_name=safetensors_name,
|
||||
basename=basename,
|
||||
raw_readme_length=raw_len,
|
||||
raw_readme_line_count=raw_line_count,
|
||||
has_yaml_frontmatter=yaml_fm,
|
||||
yaml_has_base_model=yaml_has_bm,
|
||||
yaml_has_tags=yaml_has_tg,
|
||||
section_extraction_activated=section_activated,
|
||||
section_length=section_len,
|
||||
section_line_count=section_line_count,
|
||||
basename_in_section=basename_in_sec,
|
||||
cleaned_length=cleaned_len,
|
||||
cleaned_line_count=cleaned_line_count,
|
||||
compression_pct=compression_pct,
|
||||
widget_section_found=widget_len > 0,
|
||||
widget_section_length=widget_len,
|
||||
readme_file=readme_path,
|
||||
flags=flags,
|
||||
raw_readme_preview=raw_text[:200],
|
||||
section_preview=section[:300],
|
||||
)
|
||||
|
||||
|
||||
def _build_summary(
|
||||
records: List[AuditRecord],
|
||||
flag_counter: Dict[str, int],
|
||||
) -> Dict[str, Any]:
|
||||
"""Aggregate audit statistics."""
|
||||
n = len(records)
|
||||
if n == 0:
|
||||
return {"error": "no records", "model_count": 0}
|
||||
|
||||
activated = sum(1 for r in records if r.section_extraction_activated)
|
||||
basename_hit = sum(1 for r in records if r.basename_in_section)
|
||||
with_yaml = sum(1 for r in records if r.has_yaml_frontmatter)
|
||||
with_widget = sum(1 for r in records if r.widget_section_found)
|
||||
fetch_failed = sum(1 for r in records if "README_FETCH_FAILED" in r.flags)
|
||||
|
||||
avg_compression = round(
|
||||
sum(r.compression_pct for r in records if r.raw_readme_length > 0) / max(n - fetch_failed, 1),
|
||||
1,
|
||||
)
|
||||
avg_cleaned = round(
|
||||
sum(r.cleaned_length for r in records if r.raw_readme_length > 0) / max(n - fetch_failed, 1),
|
||||
)
|
||||
|
||||
top_flags = sorted(flag_counter.items(), key=lambda x: -x[1])[:10]
|
||||
|
||||
return {
|
||||
"model_count": n,
|
||||
"fetch_failed_count": fetch_failed,
|
||||
"section_extraction_activated": activated,
|
||||
"section_extraction_pct": round(activated / max(n - fetch_failed, 1) * 100, 1),
|
||||
"basename_in_section": basename_hit,
|
||||
"basename_in_section_pct": round(basename_hit / max(n - fetch_failed, 1) * 100, 1),
|
||||
"with_yaml_frontmatter": with_yaml,
|
||||
"with_yaml_frontmatter_pct": round(with_yaml / max(n - fetch_failed, 1) * 100, 1),
|
||||
"with_widget_section": with_widget,
|
||||
"avg_compression_pct": avg_compression,
|
||||
"avg_cleaned_length": avg_cleaned,
|
||||
"top_flags": top_flags,
|
||||
}
|
||||
|
||||
|
||||
def audit_records_to_serializable(records: List[AuditRecord]) -> List[Dict[str, Any]]:
|
||||
"""Convert AuditRecord dataclasses to plain dicts for JSON serialization."""
|
||||
return [asdict(r) for r in records]
|
||||
@@ -15,6 +15,7 @@ import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .config import SUPPORTED_BASE_MODELS
|
||||
from .evaluation_engine import ScoreRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -56,33 +57,12 @@ def generate_optimisation_suggestions(
|
||||
for s in scores
|
||||
if s["raw_values"]["base_model"]
|
||||
and s["raw_values"]["base_model"] != "Unknown"
|
||||
and s["raw_values"]["base_model"] not in {
|
||||
"SD 1.4", "SD 1.5", "SD 1.5 LCM", "SD 1.5 Hyper",
|
||||
"SD 2.0", "SD 2.1", "SD 3", "SD 3.5", "SD 3.5 Medium",
|
||||
"SD 3.5 Large", "SD 3.5 Large Turbo",
|
||||
"SDXL 1.0", "SDXL Lightning", "SDXL Hyper",
|
||||
"Flux.1 D", "Flux.1 S", "Flux.1 Krea", "Flux.1 Kontext",
|
||||
"Flux.2 D", "Flux.2 Klein 9B", "Flux.2 Klein 9B-base",
|
||||
"Flux.2 Klein 4B", "Flux.2 Klein 4B-base",
|
||||
"AuraFlow", "Chroma", "PixArt a", "PixArt E",
|
||||
"Hunyuan 1", "Lumina", "Kolors",
|
||||
"NoobAI", "Illustrious", "Pony", "Pony V7",
|
||||
"HiDream", "Qwen", "ZImageTurbo", "ZImageBase",
|
||||
"SVD", "LTXV", "LTXV2", "LTXV 2.3",
|
||||
"CogVideoX", "Mochi",
|
||||
"Wan Video", "Wan Video 1.3B t2v", "Wan Video 14B t2v",
|
||||
"Wan Video 14B i2v 480p", "Wan Video 14B i2v 720p",
|
||||
"Wan Video 2.2 TI2V-5B", "Wan Video 2.2 T2V-A14B",
|
||||
"Wan Video 2.2 I2V-A14B",
|
||||
"Wan Video 2.5 T2V", "Wan Video 2.5 I2V",
|
||||
"Hunyuan Video", "Anima", "Ernie", "Ernie Turbo",
|
||||
"Nucleus", "Krea 2",
|
||||
}
|
||||
and s["raw_values"]["base_model"] not in set(SUPPORTED_BASE_MODELS)
|
||||
)
|
||||
if bm_invalid > 5:
|
||||
suggestions.append(
|
||||
"- **base_model 含非标准值 ({} 个)**: LLM 输出了未在 `SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS` "
|
||||
"中的 base model 名称。建议在 prompt 中强调 \"Use EXACTLY one name from the list\" 并在 "
|
||||
"- **base_model 含非标准值 ({} 个)**: LLM 输出了未在当前生产系统的 base model 列表 "
|
||||
"中的名称。建议在 prompt 中强调 \"Use EXACTLY one name from the list\" 并在 "
|
||||
"`PostProcessor` 中加一层验证过滤,非标准值直接丢弃。".format(bm_invalid)
|
||||
)
|
||||
|
||||
@@ -139,7 +119,8 @@ def generate_optimisation_suggestions(
|
||||
if ut and ut.get("empty_rate_pct", 0) > 70:
|
||||
suggestions.append(
|
||||
"- **usage_tips 空置率极高 ({:.0f}%)**: 这是预期行为。HF 模型卡通常不包含 LoRA "
|
||||
"强度/CLIP skip 等结构化参数。当前提取策略已合理。若需要可用数据," "可以考虑使用模型类型的通用默认值。".format(
|
||||
"强度/CLIP skip 等结构化参数。当前提取策略已合理。若需要可用数据,"
|
||||
"可以考虑使用模型类型的通用默认值。".format(
|
||||
ut.get("empty_rate_pct", 0)
|
||||
)
|
||||
)
|
||||
@@ -164,8 +145,20 @@ def generate_markdown_report(
|
||||
scores: List[ScoreRecord],
|
||||
output_dir: str,
|
||||
duration_summary: Dict[str, Any] | None = None,
|
||||
*,
|
||||
audit_summary: Dict[str, Any] | None = None,
|
||||
config_warnings: List[str] | None = None,
|
||||
) -> str:
|
||||
"""Write ``report.md`` and return its content."""
|
||||
"""Write ``report.md`` and return its content.
|
||||
|
||||
Args:
|
||||
agg: Aggregate evaluation scores.
|
||||
scores: Per-model evaluation records.
|
||||
output_dir: Output directory for the report file.
|
||||
duration_summary: Optional timing statistics.
|
||||
audit_summary: Optional preprocessing audit summary (Phase 1.5).
|
||||
config_warnings: Optional LLM config consistency warnings.
|
||||
"""
|
||||
lines: List[str] = []
|
||||
def wl(text: str = "") -> None:
|
||||
lines.append(text)
|
||||
@@ -178,6 +171,60 @@ def generate_markdown_report(
|
||||
wl(f"Failures: **{agg.get('fail_count', 0)}**")
|
||||
wl()
|
||||
|
||||
# ---- Preprocessing Audit Section ----
|
||||
if audit_summary and audit_summary.get("model_count", 0) > 0:
|
||||
wl("## Preprocessing Audit")
|
||||
wl()
|
||||
wl(f"| Metric | Value |")
|
||||
wl(f"|--------|-------|")
|
||||
wl(f"| Models audited | {audit_summary.get('model_count', 0)} |")
|
||||
wl(f"| README fetch failed | {audit_summary.get('fetch_failed_count', 0)} |")
|
||||
wl(f"| Section extraction activated | {_fmt_pct(audit_summary.get('section_extraction_pct', 0))} |")
|
||||
wl(f"| Basename found in section | {_fmt_pct(audit_summary.get('basename_in_section_pct', 0))} |")
|
||||
wl(f"| Has YAML frontmatter | {_fmt_pct(audit_summary.get('with_yaml_frontmatter_pct', 0))} |")
|
||||
wl(f"| Has YAML widget section | {_fmt_pct(audit_summary.get('with_widget_section', 0))} |")
|
||||
wl(f"| Avg README compression | {audit_summary.get('avg_compression_pct', 0)}% |")
|
||||
wl(f"| Avg cleaned length | {audit_summary.get('avg_cleaned_length', 0)} chars |")
|
||||
wl()
|
||||
|
||||
if audit_summary.get("top_flags"):
|
||||
wl("### Audit Flags (most frequent)")
|
||||
wl()
|
||||
for flag, count in audit_summary["top_flags"]:
|
||||
wl(f"- **{flag}**: {count}x")
|
||||
wl()
|
||||
|
||||
wl("**Interpretation:**")
|
||||
wl()
|
||||
act_pct = audit_summary.get("section_extraction_pct", 0)
|
||||
if act_pct < 50:
|
||||
wl(
|
||||
"- ⚠️ Section extraction activated for fewer than 50% of repos. "
|
||||
"This may indicate the basename doesn't match README content, or the "
|
||||
"repos are mostly single-model (where full README is expected)."
|
||||
)
|
||||
else:
|
||||
wl(
|
||||
"- ✅ Section extraction is working for most repos — the LLM is "
|
||||
"receiving focused README sections."
|
||||
)
|
||||
|
||||
if audit_summary.get("basename_in_section_pct", 100) < 80:
|
||||
wl(
|
||||
"- ⚠️ The safetensors basename was NOT found in the extracted section "
|
||||
"for many repos. This could mean the section extraction matched the wrong "
|
||||
"section, or the README doesn't explicitly reference the filename."
|
||||
)
|
||||
wl()
|
||||
|
||||
# ---- Config warnings ----
|
||||
if config_warnings:
|
||||
wl("## ⚠️ Configuration Warnings")
|
||||
wl()
|
||||
for w in config_warnings:
|
||||
wl(f"- {w}")
|
||||
wl()
|
||||
|
||||
# ---- Duration ----
|
||||
if duration_summary:
|
||||
wl("## Timing")
|
||||
@@ -307,8 +354,21 @@ def save_json_report(
|
||||
enrichment_results: List[Dict[str, Any]],
|
||||
output_dir: str,
|
||||
duration_summary: Dict[str, Any] | None = None,
|
||||
*,
|
||||
audit_summary: Dict[str, Any] | None = None,
|
||||
config_warnings: List[str] | None = None,
|
||||
) -> str:
|
||||
"""Write ``report.json`` and return the path."""
|
||||
"""Write ``report.json`` and return the path.
|
||||
|
||||
Args:
|
||||
agg: Aggregate evaluation scores.
|
||||
scores: Per-model evaluation records.
|
||||
enrichment_results: Raw enrichment phase results.
|
||||
output_dir: Output directory.
|
||||
duration_summary: Optional timing statistics.
|
||||
audit_summary: Optional preprocessing audit summary.
|
||||
config_warnings: Optional LLM config consistency warnings.
|
||||
"""
|
||||
report: Dict[str, Any] = {
|
||||
"metadata": {
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
@@ -319,6 +379,11 @@ def save_json_report(
|
||||
"per_model_scores": scores,
|
||||
"enrichment_results": enrichment_results,
|
||||
}
|
||||
if audit_summary:
|
||||
report["preprocessing_audit"] = audit_summary
|
||||
if config_warnings:
|
||||
report["config_warnings"] = config_warnings
|
||||
|
||||
path = os.path.join(output_dir, "report.json")
|
||||
with open(path, "w", encoding="utf-8") as fh:
|
||||
json.dump(report, fh, indent=2, ensure_ascii=False)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
Usage::
|
||||
|
||||
# Full run (100 models, serial, ~1-2 h)
|
||||
# Full run (44 models, serial, ~1-2 h)
|
||||
python -m tests.enrich_hf_validation.run_validation \\
|
||||
--output /tmp/hf_enrich_validation
|
||||
|
||||
@@ -13,6 +13,9 @@ Usage::
|
||||
# Resume from a previous partial run
|
||||
python -m tests.enrich_hf_validation.run_validation --resume
|
||||
|
||||
# Audit preprocessing only (no LLM calls, fast)
|
||||
python -m tests.enrich_hf_validation.run_validation --audit-only
|
||||
|
||||
# Custom settings file
|
||||
python -m tests.enrich_hf_validation.run_validation \\
|
||||
--settings /custom/path/settings.json
|
||||
@@ -27,7 +30,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
# Ensure the project root is on sys.path so that ``from py import ...`` works.
|
||||
_PROJECT_ROOT = os.path.normpath(
|
||||
@@ -36,8 +39,18 @@ _PROJECT_ROOT = os.path.normpath(
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
from tests.enrich_hf_validation.config import load_settings
|
||||
# Add ComfyUI root to sys.path so ``folder_paths`` can be imported.
|
||||
# Project layout: ComfyUI/custom_nodes/ComfyUI-Lora-Manager/
|
||||
_COMFYUI_ROOT = os.path.normpath(os.path.join(_PROJECT_ROOT, "..", ".."))
|
||||
if _COMFYUI_ROOT not in sys.path:
|
||||
sys.path.insert(0, _COMFYUI_ROOT)
|
||||
|
||||
from tests.enrich_hf_validation.config import (
|
||||
init_supported_base_models,
|
||||
load_settings,
|
||||
)
|
||||
from tests.enrich_hf_validation.metadata_constructor import (
|
||||
RepoEntry,
|
||||
create_all_initial_metadata,
|
||||
load_repo_ids,
|
||||
)
|
||||
@@ -46,6 +59,10 @@ from tests.enrich_hf_validation.evaluation_engine import (
|
||||
aggregate_scores,
|
||||
evaluate_batch,
|
||||
)
|
||||
from tests.enrich_hf_validation.preprocessing_auditor import (
|
||||
audit_records_to_serializable,
|
||||
run_audit,
|
||||
)
|
||||
from tests.enrich_hf_validation.report_generator import (
|
||||
generate_markdown_report,
|
||||
save_json_report,
|
||||
@@ -70,8 +87,8 @@ def _parse_args(argv: List[str]) -> argparse.Namespace:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
default="~/Documents/hf_lora_models.txt",
|
||||
help="Path to the HF repo ID list (one per line)",
|
||||
default="~/Documents/hf_lora_models_with_safetensors.txt",
|
||||
help="Path to the HF repo entries file (format: repo_id, model_name.safetensors per line)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--settings",
|
||||
@@ -99,6 +116,11 @@ def _parse_args(argv: List[str]) -> argparse.Namespace:
|
||||
action="store_true",
|
||||
help="Skip enrichment phase (evaluate existing metadata only)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--audit-only",
|
||||
action="store_true",
|
||||
help="Run preprocessing audit only (no enrichment, no evaluation)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
@@ -125,6 +147,117 @@ def _phase_header(label: str) -> None:
|
||||
print(sep, file=sys.stderr)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Read back LLM config after enrichment (for consistency reporting)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_actual_llm_config() -> Dict[str, str]:
|
||||
"""Read what LLMService is actually using, if initialized.
|
||||
|
||||
Only meaningful when called AFTER enrichment has started (i.e. after
|
||||
``AgentService.get_instance()`` has been called).
|
||||
"""
|
||||
try:
|
||||
from py.services.llm_service import LLMService
|
||||
|
||||
instance = LLMService._instance
|
||||
if instance is None:
|
||||
return {"status": "not initialized"}
|
||||
cfg = instance._get_config()
|
||||
return {
|
||||
"provider": cfg.get("provider", ""),
|
||||
"model": cfg.get("model", ""),
|
||||
"api_base": cfg.get("api_base", ""),
|
||||
}
|
||||
except Exception as exc:
|
||||
return {"status": f"error: {exc}"}
|
||||
|
||||
|
||||
def _compare_llm_config(
|
||||
pipeline_cfg: Dict[str, Any],
|
||||
actual_cfg: Dict[str, str],
|
||||
) -> List[str]:
|
||||
"""Compare pipeline-loaded vs LLMService-used config.
|
||||
|
||||
Returns warning messages if they differ.
|
||||
"""
|
||||
warnings: List[str] = []
|
||||
if not actual_cfg or actual_cfg.get("status", "") == "not initialized":
|
||||
warnings.append(
|
||||
"LLMService was not initialized during this run — cannot verify "
|
||||
"config consistency."
|
||||
)
|
||||
return warnings
|
||||
|
||||
field_map = [
|
||||
("llm_provider", "provider"),
|
||||
("llm_model", "model"),
|
||||
("llm_api_base", "api_base"),
|
||||
]
|
||||
for pipeline_key, llm_key in field_map:
|
||||
pv = (pipeline_cfg.get(pipeline_key) or "").strip()
|
||||
lv = (actual_cfg.get(llm_key) or "").strip()
|
||||
if pv and lv and pv != lv:
|
||||
warnings.append(
|
||||
f"LLM config mismatch: --settings has '{pv}' for {pipeline_key}, "
|
||||
f"but LLMService uses '{lv}'. "
|
||||
f"The pipeline's --settings path ({pipeline_cfg.get('settings_path', '?')}) "
|
||||
"may differ from where SettingsManager reads."
|
||||
)
|
||||
if not warnings and actual_cfg:
|
||||
warnings.append(
|
||||
"✅ LLM config matches between pipeline --settings and LLMService."
|
||||
)
|
||||
return warnings
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 1.5: preprocessing audit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _run_preprocessing_audit(
|
||||
entries: List[RepoEntry],
|
||||
output_dir: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute the preprocessing audit and save results."""
|
||||
_phase_header("Preprocessing audit")
|
||||
print(f" Auditing {len(entries)} repos ...", file=sys.stderr)
|
||||
|
||||
readmes_dir = os.path.join(output_dir, "readmes")
|
||||
t0 = time.perf_counter()
|
||||
records, summary = await run_audit(entries, readmes_dir=readmes_dir)
|
||||
elapsed = time.perf_counter() - t0
|
||||
|
||||
# Save audit data
|
||||
audit_path = os.path.join(output_dir, "preprocessing_audit.json")
|
||||
with open(audit_path, "w", encoding="utf-8") as fh:
|
||||
json.dump(
|
||||
{
|
||||
"summary": summary,
|
||||
"records": audit_records_to_serializable(records),
|
||||
},
|
||||
fh,
|
||||
indent=2,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
print(f" Audit complete: {len(records)} repos in {elapsed:.0f}s", file=sys.stderr)
|
||||
print(f" Section extraction activated: {summary.get('section_extraction_pct', 0)}%", file=sys.stderr)
|
||||
print(f" Basename in extracted section: {summary.get('basename_in_section_pct', 0)}%", file=sys.stderr)
|
||||
print(f" Avg compression: {summary.get('avg_compression_pct', 0)}%", file=sys.stderr)
|
||||
print(f" Avg cleaned length: {summary.get('avg_cleaned_length', 0)} chars", file=sys.stderr)
|
||||
print(f" Audit data: {audit_path}", file=sys.stderr)
|
||||
|
||||
if summary.get("top_flags"):
|
||||
print(" Top flags:", file=sys.stderr)
|
||||
for flag, count in summary["top_flags"][:5]:
|
||||
print(f" - {flag}: {count}x", file=sys.stderr)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
async def _run_enrichment(
|
||||
model_paths: List[str],
|
||||
repos: List[str],
|
||||
@@ -170,7 +303,7 @@ def _collect_enriched_metadata(
|
||||
errors, metadata.
|
||||
"""
|
||||
enriched: List[Dict[str, Any]] = []
|
||||
# Build a lookup from repo_id → enrichment result
|
||||
# Build a lookup from repo_id to enrichment result
|
||||
result_lookup: Dict[str, Dict[str, Any]] = {}
|
||||
for r in results:
|
||||
result_lookup[r["repo_id"]] = r
|
||||
@@ -214,29 +347,43 @@ async def main(argv: List[str]) -> int:
|
||||
output_dir = os.path.abspath(os.path.expanduser(args.output))
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# ---- Phase 0: Initialise shared state ----
|
||||
_phase_header("Initialise")
|
||||
settings = load_settings(args.settings)
|
||||
logger.info(
|
||||
"LLM config: provider=%s model=%s api_base=%s",
|
||||
"LLM config from --settings: provider=%s model=%s api_base=%s",
|
||||
settings["llm_provider"],
|
||||
settings["llm_model"],
|
||||
settings["llm_api_base"],
|
||||
)
|
||||
# Load the production base model list (replaces the old hardcoded list)
|
||||
await init_supported_base_models()
|
||||
|
||||
# ---- Phase 1: Load repo IDs & construct initial metadata ----
|
||||
_phase_header("Load repo IDs & construct initial metadata")
|
||||
repos = load_repo_ids(args.models, max_models=args.sample if args.sample > 0 else None)
|
||||
model_paths = create_all_initial_metadata(
|
||||
repos, output_dir, skip_existing=True,
|
||||
# ---- Load entries ----
|
||||
_phase_header("Load repo entries & construct initial metadata")
|
||||
entries = load_repo_ids(args.models, max_models=args.sample if args.sample > 0 else None)
|
||||
model_paths, repo_ids = create_all_initial_metadata(
|
||||
entries, output_dir, skip_existing=True,
|
||||
)
|
||||
print(f" {len(model_paths)} repos ready", file=sys.stderr)
|
||||
|
||||
# ---- Phase 1.5: Preprocessing audit ----
|
||||
audit_summary: Dict[str, Any] = {}
|
||||
t_start = time.perf_counter()
|
||||
audit_summary = await _run_preprocessing_audit(entries, output_dir)
|
||||
|
||||
if args.audit_only:
|
||||
total_wall = time.perf_counter() - t_start
|
||||
print(f"\n Audit-only done in {total_wall:.0f}s", file=sys.stderr)
|
||||
print(f" Audit data: {output_dir}/preprocessing_audit.json", file=sys.stderr)
|
||||
return 0
|
||||
|
||||
# ---- Phase 2: Enrichment ----
|
||||
enrichment_results: List[Dict[str, Any]] = []
|
||||
t_start = time.perf_counter()
|
||||
if not args.no_enrich:
|
||||
_phase_header("Enrich metadata via LLM")
|
||||
enrichment_out = await _run_enrichment(
|
||||
model_paths, repos, output_dir, args.timeout, args.verbose,
|
||||
model_paths, repo_ids, output_dir, args.timeout, args.verbose,
|
||||
)
|
||||
enrichment_results = enrichment_out["results"]
|
||||
else:
|
||||
@@ -246,7 +393,7 @@ async def main(argv: List[str]) -> int:
|
||||
|
||||
# ---- Phase 3: Evaluation ----
|
||||
_phase_header("Evaluate enriched metadata")
|
||||
enriched = _collect_enriched_metadata(model_paths, repos, enrichment_results)
|
||||
enriched = _collect_enriched_metadata(model_paths, repo_ids, enrichment_results)
|
||||
scores = evaluate_batch(enriched)
|
||||
agg = aggregate_scores(scores)
|
||||
print(
|
||||
@@ -274,8 +421,18 @@ async def main(argv: List[str]) -> int:
|
||||
"max_s": round(max(durations), 1),
|
||||
}
|
||||
|
||||
save_json_report(agg, scores, enrichment_results, output_dir, duration_summary)
|
||||
generate_markdown_report(agg, scores, output_dir, duration_summary)
|
||||
# Check LLM config consistency after enrichment (LLMService is now initialized)
|
||||
actual_llm_cfg = _get_actual_llm_config()
|
||||
config_warnings = _compare_llm_config(settings, actual_llm_cfg)
|
||||
|
||||
save_json_report(
|
||||
agg, scores, enrichment_results, output_dir, duration_summary,
|
||||
audit_summary=audit_summary, config_warnings=config_warnings,
|
||||
)
|
||||
generate_markdown_report(
|
||||
agg, scores, output_dir, duration_summary,
|
||||
audit_summary=audit_summary, config_warnings=config_warnings,
|
||||
)
|
||||
|
||||
# ---- Final summary ----
|
||||
total_wall = time.perf_counter() - t_start
|
||||
|
||||
Reference in New Issue
Block a user