mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 13:42:12 -03:00
Add file_path as a tie-breaker for all sort modes in ModelCache, BaseModelService, LoraService, and RecipeCache to ensure deterministic ordering when primary keys are identical. Resolves issue #859.
327 lines
11 KiB
Python
327 lines
11 KiB
Python
import asyncio
|
|
import time
|
|
import logging
|
|
|
|
# Module-level logger; ModelCache uses it to report slow sort operations at DEBUG level.
logger = logging.getLogger(__name__)
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
from dataclasses import dataclass, field
|
|
from natsort import natsorted
|
|
|
|
# Every (sort_key, order) pair the cache understands, where order is 'asc'
# (ascending) or 'desc' (descending). Pairs are grouped by sort key, with the
# ascending variant listed before the descending one.
# NOTE(review): this constant is not referenced elsewhere in this module;
# confirm whether other modules import it before removing it.
SUPPORTED_SORT_MODES = [
    (sort_key, order)
    for sort_key in ('name', 'date', 'size', 'usage')
    for order in ('asc', 'desc')
]

# Accepted values for ModelCache.name_display_mode, i.e. which metadata field
# supplies the name used for name-based sorting. Any other value is normalized
# back to "model_name".
DISPLAY_NAME_MODES = {"model_name", "file_name"}
|
|
|
|
|
|
@dataclass
class ModelCache:
    """Cache structure for model data with extensible sorting.

    Holds the raw model entries (dicts), the derived folder list, and two
    lookup indexes built from each entry's ``civitai`` metadata:

    * ``version_index`` maps a Civitai version id to the cache entry.
    * ``model_id_index`` maps a Civitai model id to lightweight version
      descriptors (``versionId``, ``name``, ``fileName``).

    The most recent sorted view is memoized per ``(sort_key, order)``. Call
    :meth:`resort` after mutating ``raw_data`` so the sorted view, the folder
    list and the indexes are refreshed.
    """

    raw_data: List[Dict]
    folders: List[str]
    version_index: Dict[int, Dict] = field(default_factory=dict)
    model_id_index: Dict[int, List[Dict[str, Any]]] = field(default_factory=dict)
    name_display_mode: str = "model_name"

    def __post_init__(self):
        self._lock = asyncio.Lock()
        # Memo of the last sort request: (sort_key, order) -> sorted list.
        # (None, None) means "no sort requested yet".
        self._last_sort: Tuple[Optional[str], Optional[str]] = (None, None)
        self._last_sorted_data: List[Dict] = []
        self._normalize_raw_data()
        self.name_display_mode = self._normalize_display_mode(self.name_display_mode)
        self.rebuild_version_index()

        # Derive the folder list from raw_data. No sort happens here because no
        # sort mode has been requested yet. Inside a running event loop this is
        # scheduled via resort(), and the task reference is kept: asyncio only
        # holds weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes. Without a running loop,
        # create_task() would raise RuntimeError, so refresh synchronously.
        self._init_task: Optional[asyncio.Task] = None
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            self.folders = self._collect_folders()
        else:
            self._init_task = loop.create_task(self.resort())

    @staticmethod
    def _normalize_display_mode(value: Optional[str]) -> str:
        """Return ``value`` if it is a known display mode, else ``"model_name"``."""
        if isinstance(value, str) and value in DISPLAY_NAME_MODES:
            return value
        return "model_name"

    @staticmethod
    def _ensure_string(value: Any) -> str:
        """Return a safe string representation for metadata fields.

        ``None`` becomes ``""``; other non-strings are passed through ``str()``.
        """
        if isinstance(value, str):
            return value
        if value is None:
            return ""
        return str(value)

    def _normalize_item(self, item: Dict) -> None:
        """Coerce the core metadata fields of ``item`` to strings, in place.

        Only fields already present are touched; non-dict items are ignored.
        """
        if not isinstance(item, dict):
            return
        # Named ``key`` (not ``field``) to avoid shadowing dataclasses.field.
        for key in ("model_name", "file_name", "folder"):
            if key in item:
                item[key] = self._ensure_string(item.get(key))

    def _normalize_raw_data(self) -> None:
        """Normalize every cached entry before it is consumed."""
        for item in self.raw_data:
            self._normalize_item(item)

    def _get_display_name(self, item: Dict) -> str:
        """Return the value used for name-based sorting based on display settings.

        Uses ``file_name`` or ``model_name`` depending on ``name_display_mode``,
        falling back to the other field when the preferred one is empty.
        """
        if self.name_display_mode == "file_name":
            primary = self._ensure_string(item.get("file_name"))
            fallback = self._ensure_string(item.get("model_name"))
        else:
            primary = self._ensure_string(item.get("model_name"))
            fallback = self._ensure_string(item.get("file_name"))
        return primary or fallback

    @staticmethod
    def _normalize_version_id(value: Any) -> Optional[int]:
        """Normalize a potential version (or model) identifier into an integer.

        Accepts ints and integer strings; returns ``None`` for anything else.
        Booleans are rejected even though ``bool`` subclasses ``int``, so a
        stray ``True``/``False`` is never indexed as id 1/0.
        """
        if isinstance(value, bool):
            return None
        if isinstance(value, int):
            return value
        if isinstance(value, str):
            try:
                return int(value)
            except ValueError:
                return None
        return None

    def rebuild_version_index(self) -> None:
        """Rebuild the version and model indexes from the current raw data."""
        self.version_index = {}
        self.model_id_index = {}
        for item in self.raw_data:
            self.add_to_version_index(item)

    def add_to_version_index(self, item: Dict) -> None:
        """Register a cache item in the version/model indexes if possible.

        Items without a dict ``civitai`` entry or a usable ``civitai['id']`` are
        skipped. The item is added to ``model_id_index`` only when
        ``civitai['modelId']`` is also usable. An existing descriptor with the
        same version id is replaced rather than duplicated.
        """
        civitai_data = item.get('civitai') if isinstance(item, dict) else None
        if not isinstance(civitai_data, dict):
            return

        version_id = self._normalize_version_id(civitai_data.get('id'))
        if version_id is None:
            return

        self.version_index[version_id] = item

        model_id = self._normalize_version_id(civitai_data.get('modelId'))
        if model_id is None:
            return

        descriptor = self._build_version_descriptor(item, civitai_data, version_id)
        if descriptor is None:
            return

        versions = self.model_id_index.setdefault(model_id, [])
        for index, existing in enumerate(versions):
            if existing.get('versionId') == descriptor['versionId']:
                versions[index] = descriptor
                break
        else:
            versions.append(descriptor)

    def remove_from_version_index(self, item: Dict) -> None:
        """Remove a cache item from the version/model indexes if present.

        The version entry is removed only when it refers to ``item``, matched by
        identity or by equal ``file_path``. When a different cached file still
        owns the same version id, both indexes are left untouched. This keeps
        ``model_id_index`` consistent with ``version_index``.
        """
        civitai_data = item.get('civitai') if isinstance(item, dict) else None
        if not isinstance(civitai_data, dict):
            return

        version_id = self._normalize_version_id(civitai_data.get('id'))
        if version_id is None:
            return

        existing = self.version_index.get(version_id)
        owns_entry = existing is item or (
            isinstance(existing, dict)
            and existing.get('file_path') == item.get('file_path')
        )
        if owns_entry:
            self.version_index.pop(version_id, None)
        elif existing is not None:
            # Another file still carries this version id; its descriptor must stay.
            return

        model_id = self._normalize_version_id(civitai_data.get('modelId'))
        if model_id is None:
            return

        versions = self.model_id_index.get(model_id)
        if not versions:
            return

        filtered = [v for v in versions if v.get('versionId') != version_id]
        if filtered:
            self.model_id_index[model_id] = filtered
        else:
            self.model_id_index.pop(model_id, None)

    def _build_version_descriptor(
        self,
        item: Dict,
        civitai_data: Dict[str, Any],
        version_id: int,
    ) -> Optional[Dict[str, Any]]:
        """Create a lightweight descriptor for a version entry.

        ``name`` comes from ``civitai_data['name']`` and ``fileName`` from the
        item's ``file_name``. Both are coerced to strings.
        """
        return {
            'versionId': version_id,
            'name': self._ensure_string(civitai_data.get('name')),
            'fileName': self._ensure_string(item.get('file_name')),
        }

    def get_versions_by_model_id(self, model_id: Any) -> List[Dict[str, Any]]:
        """Return shallow copies of the cached version descriptors for a model ID."""
        normalized_id = self._normalize_version_id(model_id)
        if normalized_id is None:
            return []
        return [dict(version) for version in self.model_id_index.get(normalized_id, [])]

    def _collect_folders(self) -> List[str]:
        """Return the distinct ``folder`` values of raw_data, sorted case-insensitively."""
        unique = {
            self._ensure_string(item.get('folder'))
            for item in self.raw_data
            if isinstance(item, dict)
        }
        return sorted(unique, key=str.lower)

    async def resort(self):
        """Refresh derived state after raw_data has changed.

        Re-applies the last requested sort mode (if any), then recomputes the
        folder list and rebuilds the version/model indexes.
        """
        async with self._lock:
            if self._last_sort != (None, None):
                sort_key, order = self._last_sort
                self._last_sorted_data = self._sort_data(self.raw_data, sort_key, order)
            self.folders = self._collect_folders()
            self.rebuild_version_index()

    def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
        """Return a new list of ``data`` sorted by ``sort_key`` and ``order``.

        'name' uses a case-insensitive natural sort on the display name.
        'date', 'size' and 'usage' sort on ``modified``, ``size`` and
        ``usage_count``; missing or ``None`` values count as 0. Ties are broken
        by display name, then by ``file_path``, so the ordering is
        deterministic. ``order == 'desc'`` reverses the whole key, tie-breakers
        included. Unknown sort keys return an unsorted copy.
        """
        start_time = time.perf_counter()
        reverse = order == 'desc'

        def name_and_path(item: Dict) -> Tuple[str, str]:
            # _ensure_string guards against a None/non-str file_path.
            return (
                self._get_display_name(item).lower(),
                self._ensure_string(item.get('file_path')).lower(),
            )

        # sort_key -> (item field, value used when the field is missing or None)
        numeric_fields = {
            'date': ('modified', 0.0),
            'size': ('size', 0),
            'usage': ('usage_count', 0),
        }

        if sort_key == 'name':
            result = natsorted(data, key=name_and_path, reverse=reverse)
        elif sort_key in numeric_fields:
            field_name, default = numeric_fields[sort_key]

            def numeric_key(item: Dict) -> tuple:
                value = item.get(field_name)
                if value is None:
                    value = default
                return (value,) + name_and_path(item)

            result = sorted(data, key=numeric_key, reverse=reverse)
        else:
            # Fallback: no sort
            result = list(data)

        duration = time.perf_counter() - start_time
        if duration > 0.05:
            logger.debug("ModelCache._sort_data(%s, %s) for %d items took %.3fs", sort_key, order, len(data), duration)
        return result

    async def get_sorted_data(self, sort_key: str = 'name', order: str = 'asc') -> List[Dict]:
        """Get sorted data by sort_key and order, using the memoized view if possible.

        The returned list is the cached object itself; callers should treat it
        as read-only.
        """
        async with self._lock:
            if (sort_key, order) == self._last_sort:
                return self._last_sorted_data

            start_time = time.perf_counter()
            sorted_data = self._sort_data(self.raw_data, sort_key, order)
            self._last_sort = (sort_key, order)
            self._last_sorted_data = sorted_data

            duration = time.perf_counter() - start_time
            if duration > 0.1:
                logger.debug("ModelCache.get_sorted_data(%s, %s) took %.3fs", sort_key, order, duration)

            return sorted_data

    async def update_name_display_mode(self, display_mode: str) -> None:
        """Update the display mode used for name sorting and refresh cached results.

        Every sort mode uses the display name, either as the primary key
        ('name') or as a tie-breaker (all other modes). The memoized view is
        therefore re-sorted whenever a sort mode is active.
        """
        normalized = self._normalize_display_mode(display_mode)
        async with self._lock:
            if self.name_display_mode == normalized:
                return

            self.name_display_mode = normalized

            if self._last_sort != (None, None):
                sort_key, order = self._last_sort
                self._last_sorted_data = self._sort_data(self.raw_data, sort_key, order)

    async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool:
        """Update preview_url for a specific model in all cached data.

        The sorted view holds references to the same dicts as raw_data, so
        updating the raw_data entry is reflected there as well.

        Args:
            file_path: The file path of the model to update
            preview_url: The new preview URL
            preview_nsfw_level: The NSFW level of the preview

        Returns:
            bool: True if the update was successful, False if the model wasn't found
        """
        async with self._lock:
            for item in self.raw_data:
                # .get(): entries without a file_path simply don't match.
                if item.get('file_path') == file_path:
                    item['preview_url'] = preview_url
                    item['preview_nsfw_level'] = preview_nsfw_level
                    return True
            return False  # Model not found