Merge branch 'main' into fix-symlink

pixelpaws
2026-01-26 17:29:31 +08:00
committed by GitHub
17 changed files with 224282 additions and 100 deletions

View File

@@ -9,6 +9,7 @@ import json
import urllib.parse
import time
from .utils.cache_paths import CacheType, get_cache_file_path, get_legacy_cache_paths
from .utils.settings_paths import ensure_settings_file, get_settings_dir, load_settings_template
# Use an environment variable to control standalone mode
@@ -241,9 +242,8 @@ class Config:
return os.path.normpath(path).replace(os.sep, '/')
def _get_symlink_cache_path(self) -> Path:
cache_dir = Path(get_settings_dir(create=True)) / "cache"
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir / "symlink_map.json"
canonical_path = get_cache_file_path(CacheType.SYMLINK, create_dir=True)
return Path(canonical_path)
def _symlink_roots(self) -> List[str]:
roots: List[str] = []
@@ -322,14 +322,28 @@ class Config:
def _load_persisted_cache_into_mappings(self) -> bool:
"""Load the symlink cache and store its fingerprint for comparison."""
cache_path = self._get_symlink_cache_path()
if not cache_path.exists():
return False
try:
with cache_path.open("r", encoding="utf-8") as handle:
payload = json.load(handle)
except Exception as exc:
logger.info("Failed to load symlink cache %s: %s", cache_path, exc)
# Check canonical path first, then legacy paths for migration
paths_to_check = [cache_path]
legacy_paths = get_legacy_cache_paths(CacheType.SYMLINK)
paths_to_check.extend(Path(p) for p in legacy_paths if p != str(cache_path))
loaded_path = None
payload = None
for check_path in paths_to_check:
if not check_path.exists():
continue
try:
with check_path.open("r", encoding="utf-8") as handle:
payload = json.load(handle)
loaded_path = check_path
break
except Exception as exc:
logger.info("Failed to load symlink cache %s: %s", check_path, exc)
continue
if payload is None:
return False
if not isinstance(payload, dict):
@@ -349,7 +363,37 @@ class Config:
normalized_mappings[self._normalize_path(target)] = self._normalize_path(link)
self._path_mappings = normalized_mappings
logger.info("Symlink cache loaded with %d mappings", len(self._path_mappings))
# If loaded from a legacy path, log the migration and clean up the old file
if loaded_path is not None and loaded_path != cache_path:
logger.info(
"Symlink cache migrated from %s (will save to %s)",
loaded_path,
cache_path,
)
try:
if loaded_path.exists():
loaded_path.unlink()
logger.info("Cleaned up legacy symlink cache: %s", loaded_path)
try:
parent_dir = loaded_path.parent
if parent_dir.name == "cache" and not any(parent_dir.iterdir()):
parent_dir.rmdir()
logger.info("Removed empty legacy cache directory: %s", parent_dir)
except Exception:
pass
except Exception as exc:
logger.warning(
"Failed to cleanup legacy symlink cache %s: %s",
loaded_path,
exc,
)
else:
logger.info("Symlink cache loaded with %d mappings", len(self._path_mappings))
return True
def _validate_cached_mappings(self) -> bool:
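For orientation, the load order above reduces to: canonical location first, then any legacy copies reported by the new cache-path helpers. A minimal standalone sketch of that pattern follows; the absolute import and the function name are illustrative, and the real code uses the relative import shown at the top of the hunk.

import json
from pathlib import Path
from typing import Dict

# Illustrative absolute import; the diff itself uses `from .utils.cache_paths import ...`
from py.utils.cache_paths import CacheType, get_cache_file_path, get_legacy_cache_paths

def load_symlink_map_sketch() -> Dict[str, str]:
    """Try the canonical symlink cache first, then any legacy copies."""
    canonical = Path(get_cache_file_path(CacheType.SYMLINK, create_dir=True))
    candidates = [canonical] + [
        Path(p) for p in get_legacy_cache_paths(CacheType.SYMLINK) if p != str(canonical)
    ]
    for candidate in candidates:
        if not candidate.exists():
            continue
        try:
            with candidate.open("r", encoding="utf-8") as handle:
                return json.load(handle)
        except Exception:
            continue  # skip unreadable candidates, as the loader above does
    return {}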

View File

@@ -17,7 +17,7 @@ class PromptLM:
"text": (
"AUTOCOMPLETE_TEXT_PROMPT",
{
"placeholder": "Enter prompt...",
"placeholder": "Enter prompt... /char, /artist for quick tag search",
"tooltip": "The text to be encoded.",
},
),

View File

@@ -1231,12 +1231,31 @@ class CustomWordsHandler:
return web.json_response({"error": str(exc)}, status=500)
async def search_custom_words(self, request: web.Request) -> web.Response:
"""Search custom words with autocomplete."""
"""Search custom words with autocomplete.
Query parameters:
search: The search term to match against.
limit: Maximum number of results to return (default: 20).
category: Optional category filter. Can be:
- A category name (e.g., "character", "artist", "general")
- Comma-separated category IDs (e.g., "4,11" for character)
enriched: If "true", return enriched results with category and post_count
even without category filtering.
"""
try:
search_term = request.query.get("search", "")
limit = int(request.query.get("limit", "20"))
category_param = request.query.get("category", "")
enriched_param = request.query.get("enriched", "").lower() == "true"
results = self._service.search_words(search_term, limit)
# Parse category parameter
categories = None
if category_param:
categories = self._parse_category_param(category_param)
results = self._service.search_words(
search_term, limit, categories=categories, enriched=enriched_param
)
return web.json_response({
"success": True,
@@ -1246,6 +1265,37 @@ class CustomWordsHandler:
logger.error("Error searching custom words: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
def _parse_category_param(self, param: str) -> list[int] | None:
"""Parse category parameter into list of category IDs.
Args:
param: Category parameter value (name or comma-separated IDs).
Returns:
List of category IDs, or None if parsing fails.
"""
from ...services.tag_fts_index import CATEGORY_NAME_TO_IDS
param = param.strip().lower()
if not param:
return None
# Try to parse as category name first
if param in CATEGORY_NAME_TO_IDS:
return CATEGORY_NAME_TO_IDS[param]
# Try to parse as comma-separated integers
try:
category_ids = []
for part in param.split(","):
part = part.strip()
if part:
category_ids.append(int(part))
return category_ids if category_ids else None
except ValueError:
logger.debug("Invalid category parameter: %s", param)
return None
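To make the accepted category values concrete, here is a self-contained mirror of the parsing rules above; the mapping table is copied from tag_fts_index.py later in this commit, and the standalone function exists only for illustration.

from typing import List, Optional

CATEGORY_NAME_TO_IDS = {
    "general": [0, 7], "artist": [1, 8], "copyright": [3, 10],
    "character": [4, 11], "meta": [5, 14], "species": [12], "lore": [15],
}

def parse_category_param_sketch(param: str) -> Optional[List[int]]:
    """Mirror of _parse_category_param: accept a name or comma-separated IDs."""
    param = param.strip().lower()
    if not param:
        return None
    if param in CATEGORY_NAME_TO_IDS:
        return CATEGORY_NAME_TO_IDS[param]
    try:
        ids = [int(part) for part in param.split(",") if part.strip()]
        return ids or None
    except ValueError:
        return None

print(parse_category_param_sketch("character"))  # [4, 11] (Danbooru 4, e621 11)
print(parse_category_param_sketch("4,11"))       # [4, 11]
print(parse_category_param_sketch("bogus"))      # None -> no category filter applied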
class NodeRegistryHandler:
def __init__(

View File

@@ -2,6 +2,9 @@
This service provides functionality to parse CSV-formatted custom words,
search them with priority-based ranking, and manage storage.
It also integrates with TagFTSIndex to search the Danbooru/e621 tag database
for comprehensive autocomplete suggestions with category filtering.
"""
from __future__ import annotations
@@ -10,7 +13,7 @@ import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any, Optional, Union
logger = logging.getLogger(__name__)
@@ -35,6 +38,7 @@ class CustomWordsService:
- Parses CSV format: word[,priority] or word[,alias][,priority]
- Searches words with priority-based ranking
- Caches parsed words for performance
- Integrates with TagFTSIndex for Danbooru/e621 tag search
"""
_instance: Optional[CustomWordsService] = None
@@ -51,6 +55,7 @@ class CustomWordsService:
self._words_cache: Dict[str, WordEntry] = {}
self._file_path: Optional[Path] = None
self._tag_index: Optional[Any] = None # Lazy-loaded TagFTSIndex
self._initialized = True
self._determine_file_path()
@@ -98,6 +103,17 @@ class CustomWordsService:
"""Get the current file path for custom words."""
return self._file_path
def _get_tag_index(self):
"""Get or create the TagFTSIndex instance (lazy initialization)."""
if self._tag_index is None:
try:
from .tag_fts_index import get_tag_fts_index
self._tag_index = get_tag_fts_index()
except Exception as e:
logger.warning(f"Failed to initialize TagFTSIndex: {e}")
self._tag_index = None
return self._tag_index
def load_words(self) -> Dict[str, WordEntry]:
"""Load and parse words from the custom words file.
@@ -160,10 +176,20 @@ class CustomWordsService:
return words
def search_words(self, search_term: str, limit: int = 20) -> List[str]:
def search_words(
self,
search_term: str,
limit: int = 20,
categories: Optional[List[int]] = None,
enriched: bool = False
) -> Union[List[str], List[Dict[str, Any]]]:
"""Search custom words with priority-based ranking.
Matching priority:
When categories are provided or enriched is True, uses TagFTSIndex to search
the Danbooru/e621 tag database and returns enriched results with category
and post_count.
Matching priority (for custom words):
1. Words with priority (sorted by priority descending)
2. Prefix matches (word starts with search term)
3. Include matches (word contains search term)
@@ -171,10 +197,29 @@ class CustomWordsService:
Args:
search_term: The search term to match against.
limit: Maximum number of results to return.
categories: Optional list of category IDs to filter by.
When provided, searches TagFTSIndex instead of custom words.
enriched: If True, return enriched results even without category filtering.
Returns:
List of matching word texts.
List of matching word texts (when categories is None and enriched is False), or
List of dicts with tag_name, category, post_count (when categories is provided
or enriched is True).
"""
# Prefer the TagFTSIndex whenever it is available; categories/enriched only change the result shape
tag_index = self._get_tag_index()
if tag_index is not None:
# Search the tag database
results = tag_index.search(search_term, categories=categories, limit=limit)
if results:
# If categories were specified or enriched requested, return enriched results
if categories is not None or enriched:
return results
# Otherwise, convert to simple string list for backward compatibility
return [r["tag_name"] for r in results]
# Fall through to custom words if no tag results
# Fall back to custom words search
words = self._words_cache if self._words_cache else self.load_words()
if not search_term:
@@ -212,14 +257,18 @@ class CustomWordsService:
# Combine results: 20% top priority + all prefix matches + rest of priority + all include
top_priority_count = max(1, limit // 5)
results = (
text_results = (
[entry.text for entry, _ in priority_matches[:top_priority_count]]
+ [entry.text for entry, _ in prefix_matches]
+ [entry.text for entry, _ in priority_matches[top_priority_count:]]
+ [entry.text for entry, _ in include_matches]
)
return results[:limit]
# Categories were requested but the tag index was unavailable or empty: wrap custom-word matches in the enriched shape with placeholder metadata
if categories is not None:
return [{"tag_name": t, "category": 0, "post_count": 0} for t in text_results[:limit]]
return text_results[:limit]
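A short usage sketch of the widened search_words signature; how the singleton is obtained and the example terms are illustrative, and the results depend on the installed tag database.

service = CustomWordsService()  # singleton: repeated construction returns shared state

# Backward-compatible call: a plain list of tag/word strings
names = service.search_words("miku", limit=10)

# Category-filtered call: enriched dicts sourced from the TagFTSIndex
characters = service.search_words("miku", limit=10, categories=[4, 11])
for entry in characters:
    print(entry["tag_name"], entry["category"], entry["post_count"])

# enriched=True requests the dict shape even without a category filter
enriched = service.search_words("miku", limit=10, enriched=True)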
def save_words(self, content: str) -> bool:
"""Save custom words content to file.

View File

@@ -1,13 +1,12 @@
import json
import logging
import os
import re
import sqlite3
import threading
from dataclasses import dataclass
from typing import Dict, List, Mapping, Optional, Sequence, Tuple
from ..utils.settings_paths import get_project_root, get_settings_dir
from ..utils.cache_paths import CacheType, resolve_cache_path_with_migration
logger = logging.getLogger(__name__)
@@ -404,20 +403,12 @@ class PersistentModelCache:
# Internal helpers -------------------------------------------------
def _resolve_default_path(self, library_name: str) -> str:
override = os.environ.get("LORA_MANAGER_CACHE_DB")
if override:
return override
try:
settings_dir = get_settings_dir(create=True)
except Exception as exc: # pragma: no cover - defensive guard
logger.warning("Falling back to project directory for cache: %s", exc)
settings_dir = get_project_root()
safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", library_name or "default")
if safe_name.lower() in ("default", ""):
legacy_path = os.path.join(settings_dir, self._DEFAULT_FILENAME)
if os.path.exists(legacy_path):
return legacy_path
return os.path.join(settings_dir, "model_cache", f"{safe_name}.sqlite")
env_override = os.environ.get("LORA_MANAGER_CACHE_DB")
return resolve_cache_path_with_migration(
CacheType.MODEL,
library_name=library_name,
env_override=env_override,
)
def _initialize_schema(self) -> None:
with self._db_lock:

View File

@@ -10,13 +10,12 @@ from __future__ import annotations
import json
import logging
import os
import re
import sqlite3
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple
from ..utils.settings_paths import get_project_root, get_settings_dir
from ..utils.cache_paths import CacheType, resolve_cache_path_with_migration
logger = logging.getLogger(__name__)
@@ -312,20 +311,12 @@ class PersistentRecipeCache:
# Internal helpers
def _resolve_default_path(self, library_name: str) -> str:
override = os.environ.get("LORA_MANAGER_RECIPE_CACHE_DB")
if override:
return override
try:
settings_dir = get_settings_dir(create=True)
except Exception as exc:
logger.warning("Falling back to project directory for recipe cache: %s", exc)
settings_dir = get_project_root()
safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", library_name or "default")
if safe_name.lower() in ("default", ""):
legacy_path = os.path.join(settings_dir, self._DEFAULT_FILENAME)
if os.path.exists(legacy_path):
return legacy_path
return os.path.join(settings_dir, "recipe_cache", f"{safe_name}.sqlite")
env_override = os.environ.get("LORA_MANAGER_RECIPE_CACHE_DB")
return resolve_cache_path_with_migration(
CacheType.RECIPE,
library_name=library_name,
env_override=env_override,
)
def _initialize_schema(self) -> None:
with self._db_lock:

View File

@@ -15,7 +15,7 @@ import threading
import time
from typing import Any, Dict, List, Optional, Set
from ..utils.settings_paths import get_settings_dir
from ..utils.cache_paths import CacheType, resolve_cache_path_with_migration
logger = logging.getLogger(__name__)
@@ -67,17 +67,11 @@ class RecipeFTSIndex:
def _resolve_default_path(self) -> str:
"""Resolve the default database path."""
override = os.environ.get("LORA_MANAGER_RECIPE_FTS_DB")
if override:
return override
try:
settings_dir = get_settings_dir(create=True)
except Exception as exc:
logger.warning("Falling back to current directory for FTS index: %s", exc)
settings_dir = "."
return os.path.join(settings_dir, self._DEFAULT_FILENAME)
env_override = os.environ.get("LORA_MANAGER_RECIPE_FTS_DB")
return resolve_cache_path_with_migration(
CacheType.RECIPE_FTS,
env_override=env_override,
)
def get_database_path(self) -> str:
"""Return the resolved database path."""

View File

@@ -0,0 +1,498 @@
"""SQLite FTS5-based full-text search index for tags.
This module provides fast tag search using SQLite's FTS5 extension,
enabling sub-100ms search times for 221k+ Danbooru/e621 tags.
"""
from __future__ import annotations
import csv
import logging
import os
import re
import sqlite3
import threading
import time
from pathlib import Path
from typing import Dict, List, Optional, Set
from ..utils.cache_paths import CacheType, resolve_cache_path_with_migration
logger = logging.getLogger(__name__)
# Category definitions for Danbooru and e621
CATEGORY_NAMES = {
# Danbooru categories
0: "general",
1: "artist",
3: "copyright",
4: "character",
5: "meta",
# e621 categories
7: "general",
8: "artist",
10: "copyright",
11: "character",
12: "species",
14: "meta",
15: "lore",
}
# Map category names to their IDs (for filtering)
CATEGORY_NAME_TO_IDS = {
"general": [0, 7],
"artist": [1, 8],
"copyright": [3, 10],
"character": [4, 11],
"meta": [5, 14],
"species": [12],
"lore": [15],
}
class TagFTSIndex:
"""SQLite FTS5-based full-text search index for tags.
Provides fast prefix-based search across the Danbooru/e621 tag database.
Supports category-based filtering and returns enriched results with
post counts and category information.
"""
_DEFAULT_FILENAME = "tag_fts.sqlite"
_CSV_FILENAME = "danbooru_e621_merged.csv"
def __init__(self, db_path: Optional[str] = None, csv_path: Optional[str] = None) -> None:
"""Initialize the FTS index.
Args:
db_path: Optional path to the SQLite database file.
If not provided, uses the default location in settings directory.
csv_path: Optional path to the CSV file containing tag data.
If not provided, looks in the refs/ directory.
"""
self._db_path = db_path or self._resolve_default_db_path()
self._csv_path = csv_path or self._resolve_default_csv_path()
self._lock = threading.Lock()
self._ready = threading.Event()
self._indexing_in_progress = False
self._schema_initialized = False
self._warned_not_ready = False
# Ensure directory exists
try:
directory = os.path.dirname(self._db_path)
if directory:
os.makedirs(directory, exist_ok=True)
except Exception as exc:
logger.warning("Could not create FTS index directory %s: %s", directory, exc)
def _resolve_default_db_path(self) -> str:
"""Resolve the default database path."""
env_override = os.environ.get("LORA_MANAGER_TAG_FTS_DB")
return resolve_cache_path_with_migration(
CacheType.TAG_FTS,
env_override=env_override,
)
def _resolve_default_csv_path(self) -> str:
"""Resolve the default CSV file path."""
# Look for the CSV in the refs/ directory relative to the package
package_dir = Path(__file__).parent.parent.parent
csv_path = package_dir / "refs" / self._CSV_FILENAME
return str(csv_path)
def get_database_path(self) -> str:
"""Return the resolved database path."""
return self._db_path
def get_csv_path(self) -> str:
"""Return the resolved CSV path."""
return self._csv_path
def is_ready(self) -> bool:
"""Check if the FTS index is ready for queries."""
return self._ready.is_set()
def is_indexing(self) -> bool:
"""Check if indexing is currently in progress."""
return self._indexing_in_progress
def initialize(self) -> None:
"""Initialize the database schema."""
if self._schema_initialized:
return
with self._lock:
if self._schema_initialized:
return
try:
conn = self._connect()
try:
conn.execute("PRAGMA journal_mode=WAL")
conn.executescript("""
-- FTS5 virtual table for full-text search
CREATE VIRTUAL TABLE IF NOT EXISTS tag_fts USING fts5(
tag_name,
tokenize='unicode61 remove_diacritics 2'
);
-- Tags table with metadata
CREATE TABLE IF NOT EXISTS tags (
rowid INTEGER PRIMARY KEY,
tag_name TEXT UNIQUE NOT NULL,
category INTEGER NOT NULL DEFAULT 0,
post_count INTEGER NOT NULL DEFAULT 0
);
-- Indexes for efficient filtering
CREATE INDEX IF NOT EXISTS idx_tags_category ON tags(category);
CREATE INDEX IF NOT EXISTS idx_tags_post_count ON tags(post_count DESC);
-- Index version tracking
CREATE TABLE IF NOT EXISTS fts_metadata (
key TEXT PRIMARY KEY,
value TEXT
);
""")
conn.commit()
self._schema_initialized = True
logger.debug("Tag FTS index schema initialized at %s", self._db_path)
finally:
conn.close()
except Exception as exc:
logger.error("Failed to initialize tag FTS schema: %s", exc)
def build_index(self) -> None:
"""Build the FTS index from the CSV file.
This method parses the danbooru_e621_merged.csv file and creates
the FTS index for fast searching.
"""
if self._indexing_in_progress:
logger.warning("Tag FTS indexing already in progress, skipping")
return
if not os.path.exists(self._csv_path):
logger.warning("CSV file not found at %s, cannot build tag index", self._csv_path)
return
self._indexing_in_progress = True
self._ready.clear()
start_time = time.time()
try:
self.initialize()
if not self._schema_initialized:
logger.error("Cannot build tag FTS index: schema not initialized")
return
with self._lock:
conn = self._connect()
try:
conn.execute("BEGIN")
# Clear existing data
conn.execute("DELETE FROM tag_fts")
conn.execute("DELETE FROM tags")
# Parse CSV and insert in batches
batch_size = 500
rows = []
total_inserted = 0
with open(self._csv_path, "r", encoding="utf-8") as f:
reader = csv.reader(f)
for row in reader:
if len(row) < 3:
continue
tag_name = row[0].strip()
if not tag_name:
continue
try:
category = int(row[1])
except (ValueError, IndexError):
category = 0
try:
post_count = int(row[2])
except (ValueError, IndexError):
post_count = 0
rows.append((tag_name, category, post_count))
if len(rows) >= batch_size:
self._insert_batch(conn, rows)
total_inserted += len(rows)
rows = []
# Insert remaining rows
if rows:
self._insert_batch(conn, rows)
total_inserted += len(rows)
# Update metadata
conn.execute(
"INSERT OR REPLACE INTO fts_metadata (key, value) VALUES (?, ?)",
("last_build_time", str(time.time()))
)
conn.execute(
"INSERT OR REPLACE INTO fts_metadata (key, value) VALUES (?, ?)",
("tag_count", str(total_inserted))
)
conn.commit()
elapsed = time.time() - start_time
logger.info("Tag FTS index built: %d tags indexed in %.2fs", total_inserted, elapsed)
finally:
conn.close()
self._ready.set()
except Exception as exc:
logger.error("Failed to build tag FTS index: %s", exc, exc_info=True)
finally:
self._indexing_in_progress = False
def _insert_batch(self, conn: sqlite3.Connection, rows: List[tuple]) -> None:
"""Insert a batch of rows into the database."""
# Insert into tags table
conn.executemany(
"INSERT OR IGNORE INTO tags (tag_name, category, post_count) VALUES (?, ?, ?)",
rows
)
# Get rowids and insert into FTS table
tag_names = [row[0] for row in rows]
placeholders = ",".join("?" * len(tag_names))
cursor = conn.execute(
f"SELECT rowid, tag_name FROM tags WHERE tag_name IN ({placeholders})",
tag_names
)
fts_rows = [(tag_name,) for rowid, tag_name in cursor.fetchall()]
if fts_rows:
conn.executemany("INSERT INTO tag_fts (tag_name) VALUES (?)", fts_rows)
def ensure_ready(self) -> bool:
"""Ensure the index is ready, building if necessary.
Returns:
True if the index is ready, False otherwise.
"""
if self.is_ready():
return True
# Check if index already exists and has data
self.initialize()
if self._schema_initialized:
count = self.get_indexed_count()
if count > 0:
self._ready.set()
logger.debug("Tag FTS index already populated with %d tags", count)
return True
# Build the index
self.build_index()
return self.is_ready()
def search(
self,
query: str,
categories: Optional[List[int]] = None,
limit: int = 20
) -> List[Dict]:
"""Search tags using FTS5 with prefix matching.
Args:
query: The search query string.
categories: Optional list of category IDs to filter by.
limit: Maximum number of results to return.
Returns:
List of dictionaries with tag_name, category, and post_count.
"""
# Ensure index is ready (lazy initialization)
if not self.ensure_ready():
if not self._warned_not_ready:
logger.debug("Tag FTS index not ready, returning empty results")
self._warned_not_ready = True
return []
if not query or not query.strip():
return []
fts_query = self._build_fts_query(query)
if not fts_query:
return []
try:
with self._lock:
conn = self._connect(readonly=True)
try:
# Build the SQL query
if categories:
placeholders = ",".join("?" * len(categories))
sql = f"""
SELECT t.tag_name, t.category, t.post_count
FROM tags t
WHERE t.tag_name IN (
SELECT tag_name FROM tag_fts WHERE tag_fts MATCH ?
)
AND t.category IN ({placeholders})
ORDER BY t.post_count DESC
LIMIT ?
"""
params = [fts_query] + categories + [limit]
else:
sql = """
SELECT t.tag_name, t.category, t.post_count
FROM tags t
WHERE t.tag_name IN (
SELECT tag_name FROM tag_fts WHERE tag_fts MATCH ?
)
ORDER BY t.post_count DESC
LIMIT ?
"""
params = [fts_query, limit]
cursor = conn.execute(sql, params)
results = []
for row in cursor.fetchall():
results.append({
"tag_name": row[0],
"category": row[1],
"post_count": row[2],
})
return results
finally:
conn.close()
except Exception as exc:
logger.debug("Tag FTS search error for query '%s': %s", query, exc)
return []
def get_indexed_count(self) -> int:
"""Return the number of tags currently indexed."""
if not self._schema_initialized:
return 0
try:
with self._lock:
conn = self._connect(readonly=True)
try:
cursor = conn.execute("SELECT COUNT(*) FROM tags")
result = cursor.fetchone()
return result[0] if result else 0
finally:
conn.close()
except Exception:
return 0
def clear(self) -> bool:
"""Clear all data from the FTS index.
Returns:
True if successful, False otherwise.
"""
try:
with self._lock:
conn = self._connect()
try:
conn.execute("DELETE FROM tag_fts")
conn.execute("DELETE FROM tags")
conn.commit()
self._ready.clear()
return True
finally:
conn.close()
except Exception as exc:
logger.error("Failed to clear tag FTS index: %s", exc)
return False
# Internal helpers
def _connect(self, readonly: bool = False) -> sqlite3.Connection:
"""Create a database connection."""
uri = False
path = self._db_path
if readonly:
if not os.path.exists(path):
raise FileNotFoundError(path)
path = f"file:{path}?mode=ro"
uri = True
conn = sqlite3.connect(path, check_same_thread=False, uri=uri)
conn.row_factory = sqlite3.Row
return conn
def _build_fts_query(self, query: str) -> str:
"""Build an FTS5 query string with prefix matching.
Args:
query: The user's search query.
Returns:
FTS5 query string.
"""
# Split query into words and clean them
words = query.lower().split()
if not words:
return ""
# Escape and add prefix wildcard to each word
prefix_terms = []
for word in words:
escaped = self._escape_fts_query(word)
if escaped:
# Add prefix wildcard for substring-like matching
prefix_terms.append(f"{escaped}*")
if not prefix_terms:
return ""
# Combine terms with implicit AND (all words must match)
return " ".join(prefix_terms)
def _escape_fts_query(self, text: str) -> str:
"""Escape special FTS5 characters.
FTS5 special characters: " ( ) * : ^ -
All of them, including *, are replaced here; the prefix * used for matching is appended afterwards in _build_fts_query.
"""
if not text:
return ""
# Replace FTS5 special characters with space
special = ['"', "(", ")", "*", ":", "^", "-", "{", "}", "[", "]"]
result = text
for char in special:
result = result.replace(char, " ")
# Collapse multiple spaces and strip
result = re.sub(r"\s+", " ", result).strip()
return result
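To see what _build_fts_query actually produces, a couple of worked inputs; the in-memory database path is only there so the constructor does not touch disk, and calling the private helpers directly is purely for illustration.

idx = TagFTSIndex(db_path=":memory:", csv_path="unused.csv")  # placeholder paths

print(idx._build_fts_query("blue (hair)"))   # 'blue* hair*'  -> parentheses stripped, prefix wildcards added
print(idx._build_fts_query("Hatsune Miku"))  # 'hatsune* miku*'  -> lowercased, implicit AND across terms
print(idx._build_fts_query("blue_hair"))     # 'blue_hair*'  -> underscores are not FTS5-special, kept as-is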
# Singleton instance
_tag_fts_index: Optional[TagFTSIndex] = None
_tag_fts_lock = threading.Lock()
def get_tag_fts_index() -> TagFTSIndex:
"""Get the singleton TagFTSIndex instance."""
global _tag_fts_index
if _tag_fts_index is None:
with _tag_fts_lock:
if _tag_fts_index is None:
_tag_fts_index = TagFTSIndex()
return _tag_fts_index
__all__ = [
"TagFTSIndex",
"get_tag_fts_index",
"CATEGORY_NAMES",
"CATEGORY_NAME_TO_IDS",
]
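End to end, the index is obtained through the singleton accessor and queried with optional category filtering; the first search lazily builds the index from refs/danbooru_e621_merged.csv if it has not been built yet. The query terms and expected hits below are illustrative.

index = get_tag_fts_index()

# Character tags only (Danbooru 4, e621 11), ordered by post_count descending
for hit in index.search("miku", categories=[4, 11], limit=5):
    print(hit["tag_name"], CATEGORY_NAMES.get(hit["category"]), hit["post_count"])

# No category filter: same enriched dicts, just unfiltered
all_hits = index.search("miku", limit=5)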

py/utils/cache_paths.py Normal file
View File

@@ -0,0 +1,421 @@
"""Centralized cache path resolution with automatic migration support.
This module provides a unified interface for resolving cache file paths,
with automatic migration from legacy locations to the new organized
cache directory structure.
Target structure:
{settings_dir}/
└── cache/
├── symlink/
│ └── symlink_map.json
├── model/
│ └── {library_name}.sqlite
├── recipe/
│ └── {library_name}.sqlite
└── fts/
├── recipe_fts.sqlite
└── tag_fts.sqlite
"""
from __future__ import annotations
import logging
import os
import re
import shutil
from enum import Enum
from typing import List, Optional
from .settings_paths import get_project_root, get_settings_dir
logger = logging.getLogger(__name__)
class CacheType(Enum):
"""Types of cache files managed by the cache path resolver."""
MODEL = "model"
RECIPE = "recipe"
RECIPE_FTS = "recipe_fts"
TAG_FTS = "tag_fts"
SYMLINK = "symlink"
# Subdirectory structure for each cache type
_CACHE_SUBDIRS = {
CacheType.MODEL: "model",
CacheType.RECIPE: "recipe",
CacheType.RECIPE_FTS: "fts",
CacheType.TAG_FTS: "fts",
CacheType.SYMLINK: "symlink",
}
# Filename patterns for each cache type
_CACHE_FILENAMES = {
CacheType.MODEL: "{library_name}.sqlite",
CacheType.RECIPE: "{library_name}.sqlite",
CacheType.RECIPE_FTS: "recipe_fts.sqlite",
CacheType.TAG_FTS: "tag_fts.sqlite",
CacheType.SYMLINK: "symlink_map.json",
}
def get_cache_base_dir(create: bool = True) -> str:
"""Return the base cache directory path.
Args:
create: Whether to create the directory if it does not exist.
Returns:
The absolute path to the cache base directory ({settings_dir}/cache/).
"""
settings_dir = get_settings_dir(create=create)
cache_dir = os.path.join(settings_dir, "cache")
if create:
os.makedirs(cache_dir, exist_ok=True)
return cache_dir
def _sanitize_library_name(library_name: Optional[str]) -> str:
"""Sanitize a library name for use in filenames.
Args:
library_name: The library name to sanitize.
Returns:
A sanitized version safe for use in filenames.
"""
name = library_name or "default"
return re.sub(r"[^A-Za-z0-9_.-]", "_", name)
def get_cache_file_path(
cache_type: CacheType,
library_name: Optional[str] = None,
create_dir: bool = True,
) -> str:
"""Get the canonical path for a cache file.
Args:
cache_type: The type of cache file.
library_name: The library name (only used for MODEL and RECIPE types).
create_dir: Whether to create the parent directory if it does not exist.
Returns:
The absolute path to the cache file in its canonical location.
"""
cache_base = get_cache_base_dir(create=create_dir)
subdir = _CACHE_SUBDIRS[cache_type]
cache_dir = os.path.join(cache_base, subdir)
if create_dir:
os.makedirs(cache_dir, exist_ok=True)
filename_template = _CACHE_FILENAMES[cache_type]
safe_name = _sanitize_library_name(library_name)
filename = filename_template.format(library_name=safe_name)
return os.path.join(cache_dir, filename)
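Concretely, each cache type maps to a subdirectory and filename under {settings_dir}/cache/; a quick sketch, where {settings_dir} stands for whatever get_settings_dir() resolves to.

print(get_cache_file_path(CacheType.SYMLINK, create_dir=False))
# -> {settings_dir}/cache/symlink/symlink_map.json
print(get_cache_file_path(CacheType.MODEL, library_name="my loras!", create_dir=False))
# -> {settings_dir}/cache/model/my_loras_.sqlite   (library name sanitized)
print(get_cache_file_path(CacheType.TAG_FTS, create_dir=False))
# -> {settings_dir}/cache/fts/tag_fts.sqlite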
def get_legacy_cache_paths(
cache_type: CacheType,
library_name: Optional[str] = None,
) -> List[str]:
"""Get a list of legacy cache file paths to check for migration.
The paths are returned in order of priority (most recent first).
Args:
cache_type: The type of cache file.
library_name: The library name (only used for MODEL and RECIPE types).
Returns:
A list of potential legacy paths to check, in order of preference.
"""
try:
settings_dir = get_settings_dir(create=False)
except Exception:
settings_dir = get_project_root()
safe_name = _sanitize_library_name(library_name)
legacy_paths: List[str] = []
if cache_type == CacheType.MODEL:
# Legacy per-library path: {settings_dir}/model_cache/{library}.sqlite
legacy_paths.append(
os.path.join(settings_dir, "model_cache", f"{safe_name}.sqlite")
)
# Legacy root-level single cache (for "default" library only)
if safe_name.lower() in ("default", ""):
legacy_paths.append(os.path.join(settings_dir, "model_cache.sqlite"))
elif cache_type == CacheType.RECIPE:
# Legacy per-library path: {settings_dir}/recipe_cache/{library}.sqlite
legacy_paths.append(
os.path.join(settings_dir, "recipe_cache", f"{safe_name}.sqlite")
)
# Legacy root-level single cache (for "default" library only)
if safe_name.lower() in ("default", ""):
legacy_paths.append(os.path.join(settings_dir, "recipe_cache.sqlite"))
elif cache_type == CacheType.RECIPE_FTS:
# Legacy root-level path
legacy_paths.append(os.path.join(settings_dir, "recipe_fts.sqlite"))
elif cache_type == CacheType.TAG_FTS:
# Legacy root-level path
legacy_paths.append(os.path.join(settings_dir, "tag_fts.sqlite"))
elif cache_type == CacheType.SYMLINK:
# Legacy location: directly under cache/ without the symlink/ subdirectory
legacy_paths.append(
os.path.join(settings_dir, "cache", "symlink_map.json")
)
return legacy_paths
def _cleanup_legacy_file_after_migration(
legacy_path: str,
canonical_path: str,
) -> bool:
"""Safely remove a legacy file after successful migration.
Args:
legacy_path: The legacy file path to remove.
canonical_path: The canonical path where the file was copied to.
Returns:
True if cleanup succeeded, False otherwise.
"""
try:
if not os.path.exists(canonical_path):
logger.warning(
"Skipping cleanup of %s: canonical file not found at %s",
legacy_path,
canonical_path,
)
return False
legacy_size = os.path.getsize(legacy_path)
canonical_size = os.path.getsize(canonical_path)
if legacy_size != canonical_size:
logger.warning(
"Skipping cleanup of %s: file size mismatch (legacy=%d, canonical=%d)",
legacy_path,
legacy_size,
canonical_size,
)
return False
os.remove(legacy_path)
logger.info("Cleaned up legacy cache file: %s", legacy_path)
_cleanup_empty_legacy_directories(legacy_path)
return True
except Exception as exc:
logger.warning(
"Failed to cleanup legacy cache file %s: %s",
legacy_path,
exc,
)
return False
def _cleanup_empty_legacy_directories(legacy_path: str) -> None:
"""Remove empty parent directories of a legacy file.
This function only removes directories if they are empty,
using os.rmdir() which fails on non-empty directories.
Args:
legacy_path: The legacy file path whose parent directories should be cleaned.
"""
try:
parent_dir = os.path.dirname(legacy_path)
legacy_dir_names = ("model_cache", "recipe_cache")
current = parent_dir
while current:
base_name = os.path.basename(current)
if base_name in legacy_dir_names:
if os.path.isdir(current) and not os.listdir(current):
try:
os.rmdir(current)
logger.info("Removed empty legacy directory: %s", current)
except Exception:
pass
parent = os.path.dirname(current)
if parent == current:
break
current = parent
except Exception as exc:
logger.debug("Failed to cleanup empty legacy directories: %s", exc)
def resolve_cache_path_with_migration(
cache_type: CacheType,
library_name: Optional[str] = None,
env_override: Optional[str] = None,
) -> str:
"""Resolve the cache file path, migrating from legacy locations if needed.
This function performs lazy migration: on first access, it checks if the
file exists at the canonical location. If not, it looks for legacy files
and copies them to the new location. After successful migration, the
legacy file is automatically removed.
Args:
cache_type: The type of cache file.
library_name: The library name (only used for MODEL and RECIPE types).
env_override: Optional environment variable value that overrides all
path resolution. When set, returns this path directly without
any migration.
Returns:
The resolved path to use for the cache file.
"""
# Environment override bypasses all migration logic
if env_override:
return env_override
canonical_path = get_cache_file_path(cache_type, library_name, create_dir=True)
# If file already exists at canonical location, use it
if os.path.exists(canonical_path):
return canonical_path
# Check legacy paths for migration
legacy_paths = get_legacy_cache_paths(cache_type, library_name)
for legacy_path in legacy_paths:
if os.path.exists(legacy_path):
try:
shutil.copy2(legacy_path, canonical_path)
logger.info(
"Migrated %s cache from %s to %s",
cache_type.value,
legacy_path,
canonical_path,
)
_cleanup_legacy_file_after_migration(legacy_path, canonical_path)
return canonical_path
except Exception as exc:
logger.warning(
"Failed to migrate %s cache from %s: %s",
cache_type.value,
legacy_path,
exc,
)
# No legacy file found; return canonical path (will be created fresh)
return canonical_path
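The callers updated earlier in this commit (model cache, recipe cache, both FTS indexes) all reduce to the same pattern; a sketch of the precedence, reusing the environment variable name from the PersistentModelCache change above.

import os

# An environment override wins and bypasses migration entirely
db_path = resolve_cache_path_with_migration(
    CacheType.MODEL,
    library_name="default",
    env_override=os.environ.get("LORA_MANAGER_CACHE_DB"),
)
# Otherwise: return the canonical path if it already exists, else copy the first
# matching legacy file (e.g. {settings_dir}/model_cache/default.sqlite or the old
# root-level model_cache.sqlite) into place, clean up the legacy copy, and return
# the canonical path.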
def get_legacy_cache_files_for_cleanup() -> List[str]:
"""Get a list of legacy cache files that can be removed after migration.
This function returns files that exist in legacy locations and have
corresponding files in the new canonical locations.
Returns:
A list of legacy file paths that are safe to remove.
"""
files_to_remove: List[str] = []
try:
settings_dir = get_settings_dir(create=False)
except Exception:
return files_to_remove
# Check each cache type for migrated legacy files
for cache_type in CacheType:
# For MODEL and RECIPE, we need to check each library
if cache_type in (CacheType.MODEL, CacheType.RECIPE):
# Check default library
_check_legacy_for_cleanup(cache_type, "default", files_to_remove)
# Check for any per-library caches in legacy directories
legacy_dir_name = "model_cache" if cache_type == CacheType.MODEL else "recipe_cache"
legacy_dir = os.path.join(settings_dir, legacy_dir_name)
if os.path.isdir(legacy_dir):
try:
for filename in os.listdir(legacy_dir):
if filename.endswith(".sqlite"):
library_name = filename[:-7] # Remove .sqlite
_check_legacy_for_cleanup(cache_type, library_name, files_to_remove)
except Exception:
pass
else:
_check_legacy_for_cleanup(cache_type, None, files_to_remove)
return files_to_remove
def _check_legacy_for_cleanup(
cache_type: CacheType,
library_name: Optional[str],
files_to_remove: List[str],
) -> None:
"""Check if a legacy cache file can be removed after migration.
Args:
cache_type: The type of cache file.
library_name: The library name (only used for MODEL and RECIPE types).
files_to_remove: List to append removable files to.
"""
canonical_path = get_cache_file_path(cache_type, library_name, create_dir=False)
if not os.path.exists(canonical_path):
return
legacy_paths = get_legacy_cache_paths(cache_type, library_name)
for legacy_path in legacy_paths:
if os.path.exists(legacy_path) and legacy_path not in files_to_remove:
files_to_remove.append(legacy_path)
def cleanup_legacy_cache_files(dry_run: bool = True) -> List[str]:
"""Remove legacy cache files that have been migrated.
Args:
dry_run: If True, only return the list of files that would be removed
without actually removing them.
Returns:
A list of files that were (or would be) removed.
"""
files = get_legacy_cache_files_for_cleanup()
if dry_run or not files:
return files
removed: List[str] = []
for file_path in files:
try:
os.remove(file_path)
removed.append(file_path)
logger.info("Removed legacy cache file: %s", file_path)
except Exception as exc:
logger.warning("Failed to remove legacy cache file %s: %s", file_path, exc)
# Try to remove empty legacy directories
try:
settings_dir = get_settings_dir(create=False)
for legacy_dir_name in ("model_cache", "recipe_cache"):
legacy_dir = os.path.join(settings_dir, legacy_dir_name)
if os.path.isdir(legacy_dir) and not os.listdir(legacy_dir):
os.rmdir(legacy_dir)
logger.info("Removed empty legacy directory: %s", legacy_dir)
except Exception:
pass
return removed
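Because migration already deletes legacy files as they are copied, this helper is mainly a safety net; the dry-run form only reports what would be removed. A minimal sketch:

leftovers = cleanup_legacy_cache_files(dry_run=True)     # report only, nothing is deleted
if leftovers:
    removed = cleanup_legacy_cache_files(dry_run=False)  # actually remove migrated legacy files
    print(f"Removed {len(removed)} legacy cache file(s)")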