mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 14:42:11 -03:00
Merge branch 'main' into fix-symlink
This commit is contained in:
66
py/config.py
66
py/config.py
@@ -9,6 +9,7 @@ import json
|
||||
import urllib.parse
|
||||
import time
|
||||
|
||||
from .utils.cache_paths import CacheType, get_cache_file_path, get_legacy_cache_paths
|
||||
from .utils.settings_paths import ensure_settings_file, get_settings_dir, load_settings_template
|
||||
|
||||
# Use an environment variable to control standalone mode
|
||||
@@ -241,9 +242,8 @@ class Config:
|
||||
return os.path.normpath(path).replace(os.sep, '/')
|
||||
|
||||
def _get_symlink_cache_path(self) -> Path:
|
||||
cache_dir = Path(get_settings_dir(create=True)) / "cache"
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
return cache_dir / "symlink_map.json"
|
||||
canonical_path = get_cache_file_path(CacheType.SYMLINK, create_dir=True)
|
||||
return Path(canonical_path)
|
||||
|
||||
def _symlink_roots(self) -> List[str]:
|
||||
roots: List[str] = []
|
||||
@@ -322,14 +322,28 @@ class Config:
|
||||
def _load_persisted_cache_into_mappings(self) -> bool:
|
||||
"""Load the symlink cache and store its fingerprint for comparison."""
|
||||
cache_path = self._get_symlink_cache_path()
|
||||
if not cache_path.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
with cache_path.open("r", encoding="utf-8") as handle:
|
||||
payload = json.load(handle)
|
||||
except Exception as exc:
|
||||
logger.info("Failed to load symlink cache %s: %s", cache_path, exc)
|
||||
# Check canonical path first, then legacy paths for migration
|
||||
paths_to_check = [cache_path]
|
||||
legacy_paths = get_legacy_cache_paths(CacheType.SYMLINK)
|
||||
paths_to_check.extend(Path(p) for p in legacy_paths if p != str(cache_path))
|
||||
|
||||
loaded_path = None
|
||||
payload = None
|
||||
|
||||
for check_path in paths_to_check:
|
||||
if not check_path.exists():
|
||||
continue
|
||||
try:
|
||||
with check_path.open("r", encoding="utf-8") as handle:
|
||||
payload = json.load(handle)
|
||||
loaded_path = check_path
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.info("Failed to load symlink cache %s: %s", check_path, exc)
|
||||
continue
|
||||
|
||||
if payload is None:
|
||||
return False
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
@@ -349,7 +363,37 @@ class Config:
|
||||
normalized_mappings[self._normalize_path(target)] = self._normalize_path(link)
|
||||
|
||||
self._path_mappings = normalized_mappings
|
||||
logger.info("Symlink cache loaded with %d mappings", len(self._path_mappings))
|
||||
|
||||
# Log migration if loaded from legacy path
|
||||
if loaded_path is not None and loaded_path != cache_path:
|
||||
logger.info(
|
||||
"Symlink cache migrated from %s (will save to %s)",
|
||||
loaded_path,
|
||||
cache_path,
|
||||
)
|
||||
|
||||
try:
|
||||
if loaded_path.exists():
|
||||
loaded_path.unlink()
|
||||
logger.info("Cleaned up legacy symlink cache: %s", loaded_path)
|
||||
|
||||
try:
|
||||
parent_dir = loaded_path.parent
|
||||
if parent_dir.name == "cache" and not any(parent_dir.iterdir()):
|
||||
parent_dir.rmdir()
|
||||
logger.info("Removed empty legacy cache directory: %s", parent_dir)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to cleanup legacy symlink cache %s: %s",
|
||||
loaded_path,
|
||||
exc,
|
||||
)
|
||||
else:
|
||||
logger.info("Symlink cache loaded with %d mappings", len(self._path_mappings))
|
||||
|
||||
return True
|
||||
|
||||
def _validate_cached_mappings(self) -> bool:
|
||||
|
||||
@@ -17,7 +17,7 @@ class PromptLM:
|
||||
"text": (
|
||||
"AUTOCOMPLETE_TEXT_PROMPT",
|
||||
{
|
||||
"placeholder": "Enter prompt...",
|
||||
"placeholder": "Enter prompt... /char, /artist for quick tag search",
|
||||
"tooltip": "The text to be encoded.",
|
||||
},
|
||||
),
|
||||
|
||||
@@ -1231,12 +1231,31 @@ class CustomWordsHandler:
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def search_custom_words(self, request: web.Request) -> web.Response:
|
||||
"""Search custom words with autocomplete."""
|
||||
"""Search custom words with autocomplete.
|
||||
|
||||
Query parameters:
|
||||
search: The search term to match against.
|
||||
limit: Maximum number of results to return (default: 20).
|
||||
category: Optional category filter. Can be:
|
||||
- A category name (e.g., "character", "artist", "general")
|
||||
- Comma-separated category IDs (e.g., "4,11" for character)
|
||||
enriched: If "true", return enriched results with category and post_count
|
||||
even without category filtering.
|
||||
"""
|
||||
try:
|
||||
search_term = request.query.get("search", "")
|
||||
limit = int(request.query.get("limit", "20"))
|
||||
category_param = request.query.get("category", "")
|
||||
enriched_param = request.query.get("enriched", "").lower() == "true"
|
||||
|
||||
results = self._service.search_words(search_term, limit)
|
||||
# Parse category parameter
|
||||
categories = None
|
||||
if category_param:
|
||||
categories = self._parse_category_param(category_param)
|
||||
|
||||
results = self._service.search_words(
|
||||
search_term, limit, categories=categories, enriched=enriched_param
|
||||
)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
@@ -1246,6 +1265,37 @@ class CustomWordsHandler:
|
||||
logger.error("Error searching custom words: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
def _parse_category_param(self, param: str) -> list[int] | None:
    """Translate the ``category`` query value into a list of category IDs.

    Args:
        param: Raw parameter text, either a category name or a
            comma-separated list of integer IDs.

    Returns:
        The category IDs, or None when the value is blank or cannot be
        interpreted as a name or as integers.
    """
    from ...services.tag_fts_index import CATEGORY_NAME_TO_IDS

    normalized = param.strip().lower()
    if not normalized:
        return None

    # A known category name wins over numeric parsing.
    try:
        return CATEGORY_NAME_TO_IDS[normalized]
    except KeyError:
        pass

    # Otherwise expect comma-separated integers; blank pieces are ignored.
    pieces = (piece.strip() for piece in normalized.split(","))
    try:
        parsed_ids = [int(piece) for piece in pieces if piece]
    except ValueError:
        logger.debug("Invalid category parameter: %s", normalized)
        return None

    return parsed_ids or None
|
||||
|
||||
|
||||
class NodeRegistryHandler:
|
||||
def __init__(
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
|
||||
This service provides functionality to parse CSV-formatted custom words,
|
||||
search them with priority-based ranking, and manage storage.
|
||||
|
||||
It also integrates with TagFTSIndex to search the Danbooru/e621 tag database
|
||||
for comprehensive autocomplete suggestions with category filtering.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -10,7 +13,7 @@ import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,6 +38,7 @@ class CustomWordsService:
|
||||
- Parses CSV format: word[,priority] or word[,alias][,priority]
|
||||
- Searches words with priority-based ranking
|
||||
- Caches parsed words for performance
|
||||
- Integrates with TagFTSIndex for Danbooru/e621 tag search
|
||||
"""
|
||||
|
||||
_instance: Optional[CustomWordsService] = None
|
||||
@@ -51,6 +55,7 @@ class CustomWordsService:
|
||||
|
||||
self._words_cache: Dict[str, WordEntry] = {}
|
||||
self._file_path: Optional[Path] = None
|
||||
self._tag_index: Optional[Any] = None # Lazy-loaded TagFTSIndex
|
||||
self._initialized = True
|
||||
|
||||
self._determine_file_path()
|
||||
@@ -98,6 +103,17 @@ class CustomWordsService:
|
||||
"""Get the current file path for custom words."""
|
||||
return self._file_path
|
||||
|
||||
def _get_tag_index(self):
|
||||
"""Get or create the TagFTSIndex instance (lazy initialization)."""
|
||||
if self._tag_index is None:
|
||||
try:
|
||||
from .tag_fts_index import get_tag_fts_index
|
||||
self._tag_index = get_tag_fts_index()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize TagFTSIndex: {e}")
|
||||
self._tag_index = None
|
||||
return self._tag_index
|
||||
|
||||
def load_words(self) -> Dict[str, WordEntry]:
|
||||
"""Load and parse words from the custom words file.
|
||||
|
||||
@@ -160,10 +176,20 @@ class CustomWordsService:
|
||||
|
||||
return words
|
||||
|
||||
def search_words(self, search_term: str, limit: int = 20) -> List[str]:
|
||||
def search_words(
|
||||
self,
|
||||
search_term: str,
|
||||
limit: int = 20,
|
||||
categories: Optional[List[int]] = None,
|
||||
enriched: bool = False
|
||||
) -> Union[List[str], List[Dict[str, Any]]]:
|
||||
"""Search custom words with priority-based ranking.
|
||||
|
||||
Matching priority:
|
||||
When categories are provided or enriched is True, uses TagFTSIndex to search
|
||||
the Danbooru/e621 tag database and returns enriched results with category
|
||||
and post_count.
|
||||
|
||||
Matching priority (for custom words):
|
||||
1. Words with priority (sorted by priority descending)
|
||||
2. Prefix matches (word starts with search term)
|
||||
3. Include matches (word contains search term)
|
||||
@@ -171,10 +197,29 @@ class CustomWordsService:
|
||||
Args:
|
||||
search_term: The search term to match against.
|
||||
limit: Maximum number of results to return.
|
||||
categories: Optional list of category IDs to filter by.
|
||||
When provided, searches TagFTSIndex instead of custom words.
|
||||
enriched: If True, return enriched results even without category filtering.
|
||||
|
||||
Returns:
|
||||
List of matching word texts.
|
||||
List of matching word texts (when categories is None and enriched is False), or
|
||||
List of dicts with tag_name, category, post_count (when categories is provided
|
||||
or enriched is True).
|
||||
"""
|
||||
# Use TagFTSIndex when categories are specified or when explicitly requested
|
||||
tag_index = self._get_tag_index()
|
||||
if tag_index is not None:
|
||||
# Search the tag database
|
||||
results = tag_index.search(search_term, categories=categories, limit=limit)
|
||||
if results:
|
||||
# If categories were specified or enriched requested, return enriched results
|
||||
if categories is not None or enriched:
|
||||
return results
|
||||
# Otherwise, convert to simple string list for backward compatibility
|
||||
return [r["tag_name"] for r in results]
|
||||
# Fall through to custom words if no tag results
|
||||
|
||||
# Fall back to custom words search
|
||||
words = self._words_cache if self._words_cache else self.load_words()
|
||||
|
||||
if not search_term:
|
||||
@@ -212,14 +257,18 @@ class CustomWordsService:
|
||||
# Combine results: 20% top priority + all prefix matches + rest of priority + all include
|
||||
top_priority_count = max(1, limit // 5)
|
||||
|
||||
results = (
|
||||
text_results = (
|
||||
[entry.text for entry, _ in priority_matches[:top_priority_count]]
|
||||
+ [entry.text for entry, _ in prefix_matches]
|
||||
+ [entry.text for entry, _ in priority_matches[top_priority_count:]]
|
||||
+ [entry.text for entry, _ in include_matches]
|
||||
)
|
||||
|
||||
return results[:limit]
|
||||
# If categories were requested but tag index failed, return empty enriched format
|
||||
if categories is not None:
|
||||
return [{"tag_name": t, "category": 0, "post_count": 0} for t in text_results[:limit]]
|
||||
|
||||
return text_results[:limit]
|
||||
|
||||
def save_words(self, content: str) -> bool:
|
||||
"""Save custom words content to file.
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
from ..utils.settings_paths import get_project_root, get_settings_dir
|
||||
from ..utils.cache_paths import CacheType, resolve_cache_path_with_migration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -404,20 +403,12 @@ class PersistentModelCache:
|
||||
# Internal helpers -------------------------------------------------
|
||||
|
||||
def _resolve_default_path(self, library_name: str) -> str:
|
||||
override = os.environ.get("LORA_MANAGER_CACHE_DB")
|
||||
if override:
|
||||
return override
|
||||
try:
|
||||
settings_dir = get_settings_dir(create=True)
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.warning("Falling back to project directory for cache: %s", exc)
|
||||
settings_dir = get_project_root()
|
||||
safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", library_name or "default")
|
||||
if safe_name.lower() in ("default", ""):
|
||||
legacy_path = os.path.join(settings_dir, self._DEFAULT_FILENAME)
|
||||
if os.path.exists(legacy_path):
|
||||
return legacy_path
|
||||
return os.path.join(settings_dir, "model_cache", f"{safe_name}.sqlite")
|
||||
env_override = os.environ.get("LORA_MANAGER_CACHE_DB")
|
||||
return resolve_cache_path_with_migration(
|
||||
CacheType.MODEL,
|
||||
library_name=library_name,
|
||||
env_override=env_override,
|
||||
)
|
||||
|
||||
def _initialize_schema(self) -> None:
|
||||
with self._db_lock:
|
||||
|
||||
@@ -10,13 +10,12 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from ..utils.settings_paths import get_project_root, get_settings_dir
|
||||
from ..utils.cache_paths import CacheType, resolve_cache_path_with_migration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -312,20 +311,12 @@ class PersistentRecipeCache:
|
||||
# Internal helpers
|
||||
|
||||
def _resolve_default_path(self, library_name: str) -> str:
|
||||
override = os.environ.get("LORA_MANAGER_RECIPE_CACHE_DB")
|
||||
if override:
|
||||
return override
|
||||
try:
|
||||
settings_dir = get_settings_dir(create=True)
|
||||
except Exception as exc:
|
||||
logger.warning("Falling back to project directory for recipe cache: %s", exc)
|
||||
settings_dir = get_project_root()
|
||||
safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", library_name or "default")
|
||||
if safe_name.lower() in ("default", ""):
|
||||
legacy_path = os.path.join(settings_dir, self._DEFAULT_FILENAME)
|
||||
if os.path.exists(legacy_path):
|
||||
return legacy_path
|
||||
return os.path.join(settings_dir, "recipe_cache", f"{safe_name}.sqlite")
|
||||
env_override = os.environ.get("LORA_MANAGER_RECIPE_CACHE_DB")
|
||||
return resolve_cache_path_with_migration(
|
||||
CacheType.RECIPE,
|
||||
library_name=library_name,
|
||||
env_override=env_override,
|
||||
)
|
||||
|
||||
def _initialize_schema(self) -> None:
|
||||
with self._db_lock:
|
||||
|
||||
@@ -15,7 +15,7 @@ import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from ..utils.settings_paths import get_settings_dir
|
||||
from ..utils.cache_paths import CacheType, resolve_cache_path_with_migration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -67,17 +67,11 @@ class RecipeFTSIndex:
|
||||
|
||||
def _resolve_default_path(self) -> str:
|
||||
"""Resolve the default database path."""
|
||||
override = os.environ.get("LORA_MANAGER_RECIPE_FTS_DB")
|
||||
if override:
|
||||
return override
|
||||
|
||||
try:
|
||||
settings_dir = get_settings_dir(create=True)
|
||||
except Exception as exc:
|
||||
logger.warning("Falling back to current directory for FTS index: %s", exc)
|
||||
settings_dir = "."
|
||||
|
||||
return os.path.join(settings_dir, self._DEFAULT_FILENAME)
|
||||
env_override = os.environ.get("LORA_MANAGER_RECIPE_FTS_DB")
|
||||
return resolve_cache_path_with_migration(
|
||||
CacheType.RECIPE_FTS,
|
||||
env_override=env_override,
|
||||
)
|
||||
|
||||
def get_database_path(self) -> str:
|
||||
"""Return the resolved database path."""
|
||||
|
||||
498
py/services/tag_fts_index.py
Normal file
498
py/services/tag_fts_index.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""SQLite FTS5-based full-text search index for tags.
|
||||
|
||||
This module provides fast tag search using SQLite's FTS5 extension,
|
||||
enabling sub-100ms search times for 221k+ Danbooru/e621 tags.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
from ..utils.cache_paths import CacheType, resolve_cache_path_with_migration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Category ID -> category name for the merged Danbooru / e621 tag data.
CATEGORY_NAMES = {
    # Danbooru categories
    0: "general",
    1: "artist",
    3: "copyright",
    4: "character",
    5: "meta",
    # e621 categories
    7: "general",
    8: "artist",
    10: "copyright",
    11: "character",
    12: "species",
    14: "meta",
    15: "lore",
}

# Category name -> every ID carrying that name (used for filtering).
# Derived from CATEGORY_NAMES so the two tables cannot drift apart; names
# appear in first-seen order and IDs in table order.
CATEGORY_NAME_TO_IDS = {
    name: [cat_id for cat_id, cat_name in CATEGORY_NAMES.items() if cat_name == name]
    for name in dict.fromkeys(CATEGORY_NAMES.values())
}
|
||||
|
||||
|
||||
class TagFTSIndex:
    """SQLite FTS5-based full-text search index for tags.

    Provides fast prefix-based search across the Danbooru/e621 tag database.
    Supports category-based filtering and returns enriched results with
    post counts and category information.

    Each database access opens a short-lived connection while holding
    ``self._lock``, so database work on one instance is serialized.
    """

    _DEFAULT_FILENAME = "tag_fts.sqlite"
    _CSV_FILENAME = "danbooru_e621_merged.csv"

    # Rows buffered before each flush. Also the number of "?" placeholders
    # used by the duplicate lookup in _insert_batch.
    _BATCH_SIZE = 500

    # Matches every run of characters that may not appear in an FTS5
    # bareword (ASCII letters, digits, underscore and any non-ASCII
    # character are allowed). Such characters would otherwise make the
    # MATCH expression a syntax error.
    _NON_BAREWORD_RE = re.compile(r"[^0-9A-Za-z_\u0080-\U0010FFFF]+")

    def __init__(self, db_path: Optional[str] = None, csv_path: Optional[str] = None) -> None:
        """Initialize the FTS index.

        Args:
            db_path: Optional path to the SQLite database file.
                If not provided, the default cache location is resolved.
            csv_path: Optional path to the CSV file containing tag data.
                If not provided, looks in the refs/ directory.
        """
        self._db_path = db_path or self._resolve_default_db_path()
        self._csv_path = csv_path or self._resolve_default_csv_path()
        self._lock = threading.Lock()
        self._ready = threading.Event()
        self._indexing_in_progress = False
        self._schema_initialized = False
        self._warned_not_ready = False

        # Ensure the directory exists. Computed outside the ``try`` so the
        # warning below can always name it.
        directory = os.path.dirname(self._db_path)
        if directory:
            try:
                os.makedirs(directory, exist_ok=True)
            except Exception as exc:
                logger.warning("Could not create FTS index directory %s: %s", directory, exc)

    def _resolve_default_db_path(self) -> str:
        """Resolve the default database path.

        The ``LORA_MANAGER_TAG_FTS_DB`` environment variable is passed on
        as an override to the cache path resolver.
        """
        env_override = os.environ.get("LORA_MANAGER_TAG_FTS_DB")
        return resolve_cache_path_with_migration(
            CacheType.TAG_FTS,
            env_override=env_override,
        )

    def _resolve_default_csv_path(self) -> str:
        """Resolve the default CSV file path."""
        # Look for the CSV in the refs/ directory, three levels above this file.
        package_dir = Path(__file__).parent.parent.parent
        csv_path = package_dir / "refs" / self._CSV_FILENAME
        return str(csv_path)

    def get_database_path(self) -> str:
        """Return the resolved database path."""
        return self._db_path

    def get_csv_path(self) -> str:
        """Return the resolved CSV path."""
        return self._csv_path

    def is_ready(self) -> bool:
        """Check if the FTS index is ready for queries."""
        return self._ready.is_set()

    def is_indexing(self) -> bool:
        """Check if indexing is currently in progress."""
        return self._indexing_in_progress

    def initialize(self) -> None:
        """Create the database schema if it has not been created yet.

        Failures are logged and leave ``_schema_initialized`` False, so a
        later call will try again.
        """
        if self._schema_initialized:
            return

        with self._lock:
            # Re-check under the lock: another thread may have finished first.
            if self._schema_initialized:
                return

            try:
                conn = self._connect()
                try:
                    conn.execute("PRAGMA journal_mode=WAL")
                    conn.executescript("""
                        -- FTS5 virtual table for full-text search
                        CREATE VIRTUAL TABLE IF NOT EXISTS tag_fts USING fts5(
                            tag_name,
                            tokenize='unicode61 remove_diacritics 2'
                        );

                        -- Tags table with metadata
                        CREATE TABLE IF NOT EXISTS tags (
                            rowid INTEGER PRIMARY KEY,
                            tag_name TEXT UNIQUE NOT NULL,
                            category INTEGER NOT NULL DEFAULT 0,
                            post_count INTEGER NOT NULL DEFAULT 0
                        );

                        -- Indexes for efficient filtering
                        CREATE INDEX IF NOT EXISTS idx_tags_category ON tags(category);
                        CREATE INDEX IF NOT EXISTS idx_tags_post_count ON tags(post_count DESC);

                        -- Index version tracking
                        CREATE TABLE IF NOT EXISTS fts_metadata (
                            key TEXT PRIMARY KEY,
                            value TEXT
                        );
                    """)
                    conn.commit()
                    self._schema_initialized = True
                    logger.debug("Tag FTS index schema initialized at %s", self._db_path)
                finally:
                    conn.close()
            except Exception as exc:
                logger.error("Failed to initialize tag FTS schema: %s", exc)

    def build_index(self) -> None:
        """Build the FTS index from the CSV file.

        Existing rows are deleted and the CSV is re-imported inside one
        transaction. On failure the connection is closed without a commit
        and the error is logged. The ready flag is set only after a
        successful build.
        """
        if self._indexing_in_progress:
            logger.warning("Tag FTS indexing already in progress, skipping")
            return

        if not os.path.exists(self._csv_path):
            logger.warning("CSV file not found at %s, cannot build tag index", self._csv_path)
            return

        self._indexing_in_progress = True
        self._ready.clear()
        start_time = time.time()

        try:
            self.initialize()
            if not self._schema_initialized:
                logger.error("Cannot build tag FTS index: schema not initialized")
                return

            with self._lock:
                conn = self._connect()
                try:
                    conn.execute("BEGIN")

                    # Clear existing data
                    conn.execute("DELETE FROM tag_fts")
                    conn.execute("DELETE FROM tags")

                    # Parse CSV and insert in batches
                    batch: List[tuple] = []
                    for record in self._iter_csv_rows():
                        batch.append(record)
                        if len(batch) >= self._BATCH_SIZE:
                            self._insert_batch(conn, batch)
                            batch = []

                    # Insert remaining rows
                    if batch:
                        self._insert_batch(conn, batch)

                    # Count what was actually stored, so duplicate CSV
                    # lines do not inflate the reported total.
                    total_inserted = conn.execute("SELECT COUNT(*) FROM tags").fetchone()[0]

                    # Update metadata
                    conn.execute(
                        "INSERT OR REPLACE INTO fts_metadata (key, value) VALUES (?, ?)",
                        ("last_build_time", str(time.time())),
                    )
                    conn.execute(
                        "INSERT OR REPLACE INTO fts_metadata (key, value) VALUES (?, ?)",
                        ("tag_count", str(total_inserted)),
                    )

                    conn.commit()
                    elapsed = time.time() - start_time
                    logger.info("Tag FTS index built: %d tags indexed in %.2fs", total_inserted, elapsed)
                finally:
                    conn.close()

            self._ready.set()

        except Exception as exc:
            logger.error("Failed to build tag FTS index: %s", exc, exc_info=True)
        finally:
            self._indexing_in_progress = False

    def _iter_csv_rows(self):
        """Yield ``(tag_name, category, post_count)`` tuples from the CSV.

        Rows with fewer than three columns or a blank tag name are skipped.
        A non-numeric category or post count becomes 0.
        """
        # newline="" is what the csv module requires so that it can handle
        # line endings (including ones inside quoted fields) itself.
        with open(self._csv_path, "r", encoding="utf-8", newline="") as handle:
            for row in csv.reader(handle):
                if len(row) < 3:
                    continue

                tag_name = row[0].strip()
                if not tag_name:
                    continue

                yield tag_name, self._to_int(row[1]), self._to_int(row[2])

    @staticmethod
    def _to_int(value: str) -> int:
        """Return ``int(value)``, or 0 when *value* is not a valid integer."""
        try:
            return int(value)
        except ValueError:
            return 0

    def _insert_batch(self, conn: sqlite3.Connection, rows: List[tuple]) -> None:
        """Insert a batch of ``(tag_name, category, post_count)`` rows.

        Tags already stored, or repeated within *rows*, are skipped; the
        first occurrence wins. The FTS table therefore gets exactly one
        entry per tag name.
        """
        if not rows:
            return

        # Find which of these names are already stored.
        tag_names = [row[0] for row in rows]
        placeholders = ",".join("?" * len(tag_names))
        cursor = conn.execute(
            f"SELECT tag_name FROM tags WHERE tag_name IN ({placeholders})",
            tag_names,
        )
        seen = {record[0] for record in cursor.fetchall()}

        new_rows = []
        for row in rows:
            if row[0] not in seen:
                seen.add(row[0])
                new_rows.append(row)
        if not new_rows:
            return

        # Insert into tags table
        conn.executemany(
            "INSERT OR IGNORE INTO tags (tag_name, category, post_count) VALUES (?, ?, ?)",
            new_rows,
        )

        # Mirror the new tags into the FTS table
        conn.executemany(
            "INSERT INTO tag_fts (tag_name) VALUES (?)",
            [(row[0],) for row in new_rows],
        )

    def ensure_ready(self) -> bool:
        """Ensure the index is ready, building if necessary.

        Returns:
            True if the index is ready, False otherwise.
        """
        if self.is_ready():
            return True

        # Check if index already exists and has data
        self.initialize()
        if self._schema_initialized:
            count = self.get_indexed_count()
            if count > 0:
                self._ready.set()
                logger.debug("Tag FTS index already populated with %d tags", count)
                return True

        # Build the index
        self.build_index()
        return self.is_ready()

    def search(
        self,
        query: str,
        categories: Optional[List[int]] = None,
        limit: int = 20
    ) -> List[Dict]:
        """Search tags using FTS5 with prefix matching.

        Args:
            query: The search query string.
            categories: Optional list of category IDs to filter by.
            limit: Maximum number of results to return.

        Returns:
            List of dictionaries with tag_name, category, and post_count,
            ordered by post_count descending. Empty when the index is not
            ready, the query is blank, or the database query fails.
        """
        # Ensure index is ready (lazy initialization)
        if not self.ensure_ready():
            if not self._warned_not_ready:
                logger.debug("Tag FTS index not ready, returning empty results")
                self._warned_not_ready = True
            return []

        if not query or not query.strip():
            return []

        fts_query = self._build_fts_query(query)
        if not fts_query:
            return []

        try:
            with self._lock:
                conn = self._connect(readonly=True)
                try:
                    # Build the SQL query; the category filter is optional.
                    category_clause = ""
                    params: list = [fts_query]
                    if categories:
                        placeholders = ",".join("?" * len(categories))
                        category_clause = f"AND t.category IN ({placeholders})"
                        params.extend(categories)
                    params.append(limit)

                    sql = f"""
                        SELECT t.tag_name, t.category, t.post_count
                        FROM tags t
                        WHERE t.tag_name IN (
                            SELECT tag_name FROM tag_fts WHERE tag_fts MATCH ?
                        )
                        {category_clause}
                        ORDER BY t.post_count DESC
                        LIMIT ?
                    """

                    cursor = conn.execute(sql, params)
                    return [
                        {
                            "tag_name": row[0],
                            "category": row[1],
                            "post_count": row[2],
                        }
                        for row in cursor.fetchall()
                    ]
                finally:
                    conn.close()
        except Exception as exc:
            logger.debug("Tag FTS search error for query '%s': %s", query, exc)
            return []

    def get_indexed_count(self) -> int:
        """Return the number of tags currently indexed.

        Returns 0 when the schema is not initialized or the count query fails.
        """
        if not self._schema_initialized:
            return 0

        try:
            with self._lock:
                conn = self._connect(readonly=True)
                try:
                    cursor = conn.execute("SELECT COUNT(*) FROM tags")
                    result = cursor.fetchone()
                    return result[0] if result else 0
                finally:
                    conn.close()
        except Exception:
            return 0

    def clear(self) -> bool:
        """Clear all data from the FTS index.

        Returns:
            True if successful, False otherwise.
        """
        try:
            with self._lock:
                conn = self._connect()
                try:
                    conn.execute("DELETE FROM tag_fts")
                    conn.execute("DELETE FROM tags")
                    conn.commit()
                    self._ready.clear()
                    return True
                finally:
                    conn.close()
        except Exception as exc:
            logger.error("Failed to clear tag FTS index: %s", exc)
            return False

    # Internal helpers

    def _connect(self, readonly: bool = False) -> sqlite3.Connection:
        """Create a database connection.

        Args:
            readonly: Open the existing file in read-only mode through a
                ``file:`` URI instead of opening it read-write.

        Raises:
            FileNotFoundError: If *readonly* is set and the file is missing.
        """
        uri = False
        path = str(self._db_path)
        if readonly:
            if not os.path.exists(path):
                raise FileNotFoundError(path)
            # "%", "?" and "#" are significant inside a URI; escape them so
            # the filename is not cut off at a query or fragment marker.
            # "%" must be escaped first.
            escaped = path.replace("%", "%25").replace("?", "%3f").replace("#", "%23")
            path = f"file:{escaped}?mode=ro"
            uri = True
        conn = sqlite3.connect(path, check_same_thread=False, uri=uri)
        conn.row_factory = sqlite3.Row
        return conn

    def _build_fts_query(self, query: str) -> str:
        """Build an FTS5 query string with prefix matching.

        Args:
            query: The user's search query.

        Returns:
            FTS5 query string, or "" when nothing searchable remains.
        """
        # Lower-casing also keeps user input from being read as the
        # upper-case FTS5 keywords AND / OR / NOT.
        words = query.lower().split()
        if not words:
            return ""

        # Escape each word and add a trailing prefix wildcard
        prefix_terms = []
        for word in words:
            escaped = self._escape_fts_query(word)
            if escaped:
                prefix_terms.append(f"{escaped}*")

        if not prefix_terms:
            return ""

        # Combine terms with implicit AND (all words must match)
        return " ".join(prefix_terms)

    def _escape_fts_query(self, text: str) -> str:
        """Replace characters that are not valid in an FTS5 bareword.

        Every run of such characters (quotes, parentheses, ``*``, ``:``,
        ``^``, ``-``, ``!``, ``.``, whitespace, ...) becomes a single space.
        The caller appends the prefix wildcard afterwards.

        Returns:
            The cleaned text with surrounding whitespace stripped.
        """
        if not text:
            return ""

        return self._NON_BAREWORD_RE.sub(" ", text).strip()
|
||||
|
||||
|
||||
# Process-wide TagFTSIndex instance, created lazily by get_tag_fts_index().
_tag_fts_index: Optional[TagFTSIndex] = None
# Guards creation of the instance above.
_tag_fts_lock = threading.Lock()


def get_tag_fts_index() -> TagFTSIndex:
    """Return the shared TagFTSIndex, creating it on the first call."""
    global _tag_fts_index

    # Fast path: the instance already exists, no lock needed.
    if _tag_fts_index is not None:
        return _tag_fts_index

    with _tag_fts_lock:
        # Re-check under the lock in case another caller created it first.
        if _tag_fts_index is None:
            _tag_fts_index = TagFTSIndex()
    return _tag_fts_index
|
||||
|
||||
|
||||
# Names exported by ``from ... import *``: the index class, its singleton
# accessor, and the two category lookup tables.
__all__ = [
    "TagFTSIndex",
    "get_tag_fts_index",
    "CATEGORY_NAMES",
    "CATEGORY_NAME_TO_IDS",
]
|
||||
421
py/utils/cache_paths.py
Normal file
421
py/utils/cache_paths.py
Normal file
@@ -0,0 +1,421 @@
|
||||
"""Centralized cache path resolution with automatic migration support.
|
||||
|
||||
This module provides a unified interface for resolving cache file paths,
|
||||
with automatic migration from legacy locations to the new organized
|
||||
cache directory structure.
|
||||
|
||||
Target structure:
|
||||
{settings_dir}/
|
||||
└── cache/
|
||||
├── symlink/
|
||||
│ └── symlink_map.json
|
||||
├── model/
|
||||
│ └── {library_name}.sqlite
|
||||
├── recipe/
|
||||
│ └── {library_name}.sqlite
|
||||
└── fts/
|
||||
├── recipe_fts.sqlite
|
||||
└── tag_fts.sqlite
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from .settings_paths import get_project_root, get_settings_dir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CacheType(Enum):
    """Types of cache files managed by the cache path resolver.

    Each member is a key into ``_CACHE_SUBDIRS`` and ``_CACHE_FILENAMES``,
    which together determine where the corresponding file lives below the
    cache base directory.
    """

    # Per-library SQLite file stored under cache/model/.
    MODEL = "model"
    # Per-library SQLite file stored under cache/recipe/.
    RECIPE = "recipe"
    # Single file cache/fts/recipe_fts.sqlite.
    RECIPE_FTS = "recipe_fts"
    # Single file cache/fts/tag_fts.sqlite.
    TAG_FTS = "tag_fts"
    # Single file cache/symlink/symlink_map.json.
    SYMLINK = "symlink"
|
||||
|
||||
|
||||
# Subdirectory structure for each cache type
# (relative to the cache base directory; both FTS types share "fts").
_CACHE_SUBDIRS = {
    CacheType.MODEL: "model",
    CacheType.RECIPE: "recipe",
    CacheType.RECIPE_FTS: "fts",
    CacheType.TAG_FTS: "fts",
    CacheType.SYMLINK: "symlink",
}

# Filename patterns for each cache type
# ("{library_name}" is filled in by get_cache_file_path() with the sanitized
# library name; the other entries are fixed names).
_CACHE_FILENAMES = {
    CacheType.MODEL: "{library_name}.sqlite",
    CacheType.RECIPE: "{library_name}.sqlite",
    CacheType.RECIPE_FTS: "recipe_fts.sqlite",
    CacheType.TAG_FTS: "tag_fts.sqlite",
    CacheType.SYMLINK: "symlink_map.json",
}
|
||||
|
||||
|
||||
def get_cache_base_dir(create: bool = True) -> str:
    """Locate the root directory that holds every cache file.

    Args:
        create: When true, make sure the directory exists before returning.
            The flag is also forwarded to the settings-directory lookup.

    Returns:
        The ``cache`` directory inside the settings directory.
    """
    base_dir = os.path.join(get_settings_dir(create=create), "cache")
    if not create:
        return base_dir

    os.makedirs(base_dir, exist_ok=True)
    return base_dir
|
||||
|
||||
|
||||
def _sanitize_library_name(library_name: Optional[str]) -> str:
|
||||
"""Sanitize a library name for use in filenames.
|
||||
|
||||
Args:
|
||||
library_name: The library name to sanitize.
|
||||
|
||||
Returns:
|
||||
A sanitized version safe for use in filenames.
|
||||
"""
|
||||
name = library_name or "default"
|
||||
return re.sub(r"[^A-Za-z0-9_.-]", "_", name)
|
||||
|
||||
|
||||
def get_cache_file_path(
    cache_type: CacheType,
    library_name: Optional[str] = None,
    create_dir: bool = True,
) -> str:
    """Build the canonical location of a cache file.

    Args:
        cache_type: Which cache file is wanted.
        library_name: Library the cache belongs to. Only the MODEL and
            RECIPE filename patterns contain the library name.
        create_dir: When true, create the cache base directory and the
            type-specific subdirectory if they are missing.

    Returns:
        The path of the cache file inside its type-specific subdirectory.
    """
    target_dir = os.path.join(
        get_cache_base_dir(create=create_dir),
        _CACHE_SUBDIRS[cache_type],
    )
    if create_dir:
        os.makedirs(target_dir, exist_ok=True)

    # Patterns without a placeholder simply ignore the keyword argument.
    filename = _CACHE_FILENAMES[cache_type].format(
        library_name=_sanitize_library_name(library_name)
    )
    return os.path.join(target_dir, filename)
|
||||
|
||||
|
||||
def get_legacy_cache_paths(
    cache_type: CacheType,
    library_name: Optional[str] = None,
) -> List[str]:
    """List the pre-migration locations a cache file may still live at.

    Candidates are ordered by priority (most recent layout first).

    Args:
        cache_type: Which cache file is wanted.
        library_name: Library the cache belongs to (only used for the MODEL
            and RECIPE types).

    Returns:
        Candidate legacy paths in order of preference. Nothing is checked
        for existence here.
    """
    try:
        base_dir = get_settings_dir(create=False)
    except Exception:
        # Fall back to the project root when the settings directory cannot
        # be resolved.
        base_dir = get_project_root()

    safe_name = _sanitize_library_name(library_name)

    if cache_type in (CacheType.MODEL, CacheType.RECIPE):
        stem = "model_cache" if cache_type == CacheType.MODEL else "recipe_cache"
        # Per-library layout: {settings_dir}/<stem>/{library}.sqlite
        candidates = [os.path.join(base_dir, stem, f"{safe_name}.sqlite")]
        # The oldest layout kept one root-level file, for "default" only.
        if safe_name.lower() in ("default", ""):
            candidates.append(os.path.join(base_dir, stem + ".sqlite"))
        return candidates

    if cache_type == CacheType.RECIPE_FTS:
        # Legacy root-level path
        return [os.path.join(base_dir, "recipe_fts.sqlite")]

    if cache_type == CacheType.TAG_FTS:
        # Legacy root-level path
        return [os.path.join(base_dir, "tag_fts.sqlite")]

    if cache_type == CacheType.SYMLINK:
        # Already inside cache/, but without the per-type subdirectory.
        return [os.path.join(base_dir, "cache", "symlink_map.json")]

    return []
|
||||
|
||||
|
||||
def _cleanup_legacy_file_after_migration(
    legacy_path: str,
    canonical_path: str,
) -> bool:
    """Delete a legacy cache file once its migrated copy looks intact.

    The legacy file is removed only when the canonical file exists and both
    files have the same size. Any error is logged and reported as failure
    instead of being raised.

    Args:
        legacy_path: The legacy file to remove.
        canonical_path: Where the file was copied to.

    Returns:
        True if the legacy file was removed, False otherwise.
    """
    try:
        if not os.path.exists(canonical_path):
            logger.warning(
                "Skipping cleanup of %s: canonical file not found at %s",
                legacy_path,
                canonical_path,
            )
            return False

        legacy_size = os.stat(legacy_path).st_size
        canonical_size = os.stat(canonical_path).st_size

        if legacy_size == canonical_size:
            os.remove(legacy_path)
            logger.info("Cleaned up legacy cache file: %s", legacy_path)
            # The directory that held the file may now be empty.
            _cleanup_empty_legacy_directories(legacy_path)
            return True

        logger.warning(
            "Skipping cleanup of %s: file size mismatch (legacy=%d, canonical=%d)",
            legacy_path,
            legacy_size,
            canonical_size,
        )
        return False
    except Exception as exc:
        logger.warning(
            "Failed to cleanup legacy cache file %s: %s",
            legacy_path,
            exc,
        )
        return False
|
||||
|
||||
|
||||
def _cleanup_empty_legacy_directories(legacy_path: str) -> None:
|
||||
"""Remove empty parent directories of a legacy file.
|
||||
|
||||
This function only removes directories if they are empty,
|
||||
using os.rmdir() which fails on non-empty directories.
|
||||
|
||||
Args:
|
||||
legacy_path: The legacy file path whose parent directories should be cleaned.
|
||||
"""
|
||||
try:
|
||||
parent_dir = os.path.dirname(legacy_path)
|
||||
|
||||
legacy_dir_names = ("model_cache", "recipe_cache")
|
||||
|
||||
current = parent_dir
|
||||
while current:
|
||||
base_name = os.path.basename(current)
|
||||
|
||||
if base_name in legacy_dir_names:
|
||||
if os.path.isdir(current) and not os.listdir(current):
|
||||
try:
|
||||
os.rmdir(current)
|
||||
logger.info("Removed empty legacy directory: %s", current)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
parent = os.path.dirname(current)
|
||||
if parent == current:
|
||||
break
|
||||
current = parent
|
||||
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to cleanup empty legacy directories: %s", exc)
|
||||
|
||||
|
||||
def resolve_cache_path_with_migration(
    cache_type: CacheType,
    library_name: Optional[str] = None,
    env_override: Optional[str] = None,
) -> str:
    """Resolve the cache file path, migrating from legacy locations if needed.

    Migration is lazy: when nothing exists at the canonical location yet,
    the first legacy file found is copied there and the legacy file is then
    cleaned up. The copy goes to a staging file next to the destination and
    is moved into place with ``os.replace()``, so a copy that fails midway
    never leaves a truncated file at the canonical path where a later call
    would mistake it for a valid cache.

    Args:
        cache_type: The type of cache file.
        library_name: The library name (only used for MODEL and RECIPE types).
        env_override: Optional environment variable value that overrides all
            path resolution. When set, it is returned as-is without any
            migration.

    Returns:
        The resolved path to use for the cache file. The file may not exist
        yet if there was nothing to migrate or every migration attempt failed.
    """
    # Environment override bypasses all migration logic
    if env_override:
        return env_override

    canonical_path = get_cache_file_path(cache_type, library_name, create_dir=True)

    # If file already exists at canonical location, use it
    if os.path.exists(canonical_path):
        return canonical_path

    # Same directory as the destination so os.replace() stays on one
    # filesystem; the pid keeps concurrent processes from sharing a file.
    staging_path = f"{canonical_path}.migrating-{os.getpid()}"

    for legacy_path in get_legacy_cache_paths(cache_type, library_name):
        if not os.path.exists(legacy_path):
            continue

        try:
            shutil.copy2(legacy_path, staging_path)
            os.replace(staging_path, canonical_path)
        except Exception as exc:
            logger.warning(
                "Failed to migrate %s cache from %s: %s",
                cache_type.value,
                legacy_path,
                exc,
            )
            # Drop whatever part of the copy made it to disk, then try the
            # next legacy candidate.
            try:
                if os.path.exists(staging_path):
                    os.remove(staging_path)
            except OSError as cleanup_exc:
                logger.debug(
                    "Could not remove staging file %s: %s",
                    staging_path,
                    cleanup_exc,
                )
        else:
            logger.info(
                "Migrated %s cache from %s to %s",
                cache_type.value,
                legacy_path,
                canonical_path,
            )
            _cleanup_legacy_file_after_migration(legacy_path, canonical_path)
            return canonical_path

    # No legacy file migrated; return canonical path (will be created fresh)
    return canonical_path
|
||||
|
||||
|
||||
def get_legacy_cache_files_for_cleanup() -> List[str]:
    """Collect legacy cache files whose canonical counterpart already exists.

    Returns:
        Legacy file paths that are safe to remove. Empty when the settings
        directory cannot be resolved.
    """
    removable: List[str] = []

    try:
        settings_dir = get_settings_dir(create=False)
    except Exception:
        return removable

    # Cache types that used to keep one file per library in a directory.
    per_library_dirs = {
        CacheType.MODEL: "model_cache",
        CacheType.RECIPE: "recipe_cache",
    }

    for cache_type in CacheType:
        legacy_dir_name = per_library_dirs.get(cache_type)
        if legacy_dir_name is None:
            _check_legacy_for_cleanup(cache_type, None, removable)
            continue

        # The default library is always checked, then every library that
        # still has a file in the legacy directory.
        _check_legacy_for_cleanup(cache_type, "default", removable)

        legacy_dir = os.path.join(settings_dir, legacy_dir_name)
        if not os.path.isdir(legacy_dir):
            continue

        try:
            for entry in os.listdir(legacy_dir):
                if entry.endswith(".sqlite"):
                    # Strip the ".sqlite" suffix to get the library name.
                    _check_legacy_for_cleanup(
                        cache_type, entry[: -len(".sqlite")], removable
                    )
        except Exception:
            pass

    return removable
|
||||
|
||||
|
||||
def _check_legacy_for_cleanup(
    cache_type: CacheType,
    library_name: Optional[str],
    files_to_remove: List[str],
) -> None:
    """Queue existing legacy files of a cache that is already migrated.

    Nothing is queued unless the canonical file exists.

    Args:
        cache_type: The type of cache file.
        library_name: The library name (only used for MODEL and RECIPE types).
        files_to_remove: List extended in place with removable legacy paths;
            paths already present are not added again.
    """
    canonical_path = get_cache_file_path(cache_type, library_name, create_dir=False)
    if os.path.exists(canonical_path):
        for candidate in get_legacy_cache_paths(cache_type, library_name):
            if not os.path.exists(candidate) or candidate in files_to_remove:
                continue
            files_to_remove.append(candidate)
|
||||
|
||||
|
||||
def cleanup_legacy_cache_files(dry_run: bool = True) -> List[str]:
    """Delete legacy cache files that already have a canonical counterpart.

    Args:
        dry_run: If True, nothing is deleted and the list of candidate
            files is returned.

    Returns:
        The files that were removed, or that would be removed in a dry run.
    """
    candidates = get_legacy_cache_files_for_cleanup()
    if dry_run or not candidates:
        return candidates

    removed: List[str] = []
    for path in candidates:
        try:
            os.remove(path)
        except Exception as exc:
            logger.warning("Failed to remove legacy cache file %s: %s", path, exc)
        else:
            removed.append(path)
            logger.info("Removed legacy cache file: %s", path)

    # Best effort: drop the legacy directories once they are empty.
    try:
        settings_dir = get_settings_dir(create=False)
        legacy_dirs = [
            os.path.join(settings_dir, name)
            for name in ("model_cache", "recipe_cache")
        ]
        for legacy_dir in legacy_dirs:
            if not os.path.isdir(legacy_dir) or os.listdir(legacy_dir):
                continue
            os.rmdir(legacy_dir)
            logger.info("Removed empty legacy directory: %s", legacy_dir)
    except Exception:
        pass

    return removed
|
||||
Reference in New Issue
Block a user