mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 06:32:12 -03:00
refactor(routes): limit update endpoints to essentials
This commit is contained in:
404
py/services/model_update_service.py
Normal file
404
py/services/model_update_service.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""Service for tracking remote model version updates."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Iterable, List, Mapping, Optional, Sequence
|
||||
|
||||
from .errors import RateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelUpdateRecord:
|
||||
"""Representation of a persisted update record."""
|
||||
|
||||
model_type: str
|
||||
model_id: int
|
||||
largest_version_id: Optional[int]
|
||||
version_ids: List[int]
|
||||
in_library_version_ids: List[int]
|
||||
last_checked_at: Optional[float]
|
||||
should_ignore: bool
|
||||
|
||||
def has_update(self) -> bool:
|
||||
"""Return True when remote versions exceed the local library."""
|
||||
|
||||
if self.should_ignore or not self.version_ids:
|
||||
return False
|
||||
local_versions = set(self.in_library_version_ids)
|
||||
return any(version_id not in local_versions for version_id in self.version_ids)
|
||||
|
||||
|
||||
class ModelUpdateService:
|
||||
"""Persist and query remote model version metadata."""
|
||||
|
||||
_SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS model_update_status (
|
||||
model_type TEXT NOT NULL,
|
||||
model_id INTEGER NOT NULL,
|
||||
largest_version_id INTEGER,
|
||||
version_ids TEXT,
|
||||
in_library_version_ids TEXT,
|
||||
last_checked_at REAL,
|
||||
should_ignore INTEGER DEFAULT 0,
|
||||
PRIMARY KEY (model_type, model_id)
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str, *, ttl_seconds: int = 24 * 60 * 60) -> None:
|
||||
self._db_path = db_path
|
||||
self._ttl_seconds = ttl_seconds
|
||||
self._lock = asyncio.Lock()
|
||||
self._schema_initialized = False
|
||||
self._ensure_directory()
|
||||
self._initialize_schema()
|
||||
|
||||
def _ensure_directory(self) -> None:
|
||||
directory = os.path.dirname(self._db_path)
|
||||
if directory:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self._db_path, check_same_thread=False)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
def _initialize_schema(self) -> None:
|
||||
if self._schema_initialized:
|
||||
return
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
conn.executescript(self._SCHEMA)
|
||||
self._schema_initialized = True
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.error("Failed to initialize update schema: %s", exc, exc_info=True)
|
||||
raise
|
||||
|
||||
async def refresh_for_model_type(
|
||||
self,
|
||||
model_type: str,
|
||||
scanner,
|
||||
metadata_provider,
|
||||
*,
|
||||
force_refresh: bool = False,
|
||||
) -> Dict[int, ModelUpdateRecord]:
|
||||
"""Refresh update information for every model present in the cache."""
|
||||
|
||||
local_versions = await self._collect_local_versions(scanner)
|
||||
results: Dict[int, ModelUpdateRecord] = {}
|
||||
for model_id, version_ids in local_versions.items():
|
||||
record = await self._refresh_single_model(
|
||||
model_type,
|
||||
model_id,
|
||||
version_ids,
|
||||
metadata_provider,
|
||||
force_refresh=force_refresh,
|
||||
)
|
||||
if record:
|
||||
results[model_id] = record
|
||||
return results
|
||||
|
||||
async def refresh_single_model(
|
||||
self,
|
||||
model_type: str,
|
||||
model_id: int,
|
||||
scanner,
|
||||
metadata_provider,
|
||||
*,
|
||||
force_refresh: bool = False,
|
||||
) -> Optional[ModelUpdateRecord]:
|
||||
"""Refresh update information for a specific model id."""
|
||||
|
||||
local_versions = await self._collect_local_versions(scanner)
|
||||
version_ids = local_versions.get(model_id, [])
|
||||
return await self._refresh_single_model(
|
||||
model_type,
|
||||
model_id,
|
||||
version_ids,
|
||||
metadata_provider,
|
||||
force_refresh=force_refresh,
|
||||
)
|
||||
|
||||
async def update_in_library_versions(
|
||||
self,
|
||||
model_type: str,
|
||||
model_id: int,
|
||||
version_ids: Sequence[int],
|
||||
) -> ModelUpdateRecord:
|
||||
"""Persist a new set of in-library version identifiers."""
|
||||
|
||||
normalized_versions = self._normalize_sequence(version_ids)
|
||||
async with self._lock:
|
||||
existing = self._get_record(model_type, model_id)
|
||||
record = ModelUpdateRecord(
|
||||
model_type=model_type,
|
||||
model_id=model_id,
|
||||
largest_version_id=existing.largest_version_id if existing else None,
|
||||
version_ids=list(existing.version_ids) if existing else [],
|
||||
in_library_version_ids=normalized_versions,
|
||||
last_checked_at=existing.last_checked_at if existing else None,
|
||||
should_ignore=existing.should_ignore if existing else False,
|
||||
)
|
||||
self._upsert_record(record)
|
||||
return record
|
||||
|
||||
async def set_should_ignore(
|
||||
self, model_type: str, model_id: int, should_ignore: bool
|
||||
) -> ModelUpdateRecord:
|
||||
"""Toggle the ignore flag for a model."""
|
||||
|
||||
async with self._lock:
|
||||
existing = self._get_record(model_type, model_id)
|
||||
if existing:
|
||||
record = ModelUpdateRecord(
|
||||
model_type=model_type,
|
||||
model_id=model_id,
|
||||
largest_version_id=existing.largest_version_id,
|
||||
version_ids=list(existing.version_ids),
|
||||
in_library_version_ids=list(existing.in_library_version_ids),
|
||||
last_checked_at=existing.last_checked_at,
|
||||
should_ignore=should_ignore,
|
||||
)
|
||||
else:
|
||||
record = ModelUpdateRecord(
|
||||
model_type=model_type,
|
||||
model_id=model_id,
|
||||
largest_version_id=None,
|
||||
version_ids=[],
|
||||
in_library_version_ids=[],
|
||||
last_checked_at=None,
|
||||
should_ignore=should_ignore,
|
||||
)
|
||||
self._upsert_record(record)
|
||||
return record
|
||||
|
||||
async def get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]:
|
||||
"""Return a cached record without triggering remote fetches."""
|
||||
|
||||
async with self._lock:
|
||||
return self._get_record(model_type, model_id)
|
||||
|
||||
async def has_update(self, model_type: str, model_id: int) -> bool:
|
||||
"""Determine if a model has updates pending."""
|
||||
|
||||
record = await self.get_record(model_type, model_id)
|
||||
return record.has_update() if record else False
|
||||
|
||||
async def _refresh_single_model(
|
||||
self,
|
||||
model_type: str,
|
||||
model_id: int,
|
||||
local_versions: Sequence[int],
|
||||
metadata_provider,
|
||||
*,
|
||||
force_refresh: bool = False,
|
||||
) -> Optional[ModelUpdateRecord]:
|
||||
normalized_local = self._normalize_sequence(local_versions)
|
||||
now = time.time()
|
||||
async with self._lock:
|
||||
existing = self._get_record(model_type, model_id)
|
||||
if existing and existing.should_ignore and not force_refresh:
|
||||
record = ModelUpdateRecord(
|
||||
model_type=model_type,
|
||||
model_id=model_id,
|
||||
largest_version_id=existing.largest_version_id,
|
||||
version_ids=list(existing.version_ids),
|
||||
in_library_version_ids=normalized_local,
|
||||
last_checked_at=existing.last_checked_at,
|
||||
should_ignore=True,
|
||||
)
|
||||
self._upsert_record(record)
|
||||
return record
|
||||
|
||||
should_fetch = force_refresh or not existing or self._is_stale(existing, now)
|
||||
# release lock during network request
|
||||
fetched_versions: List[int] | None = None
|
||||
if metadata_provider and should_fetch:
|
||||
try:
|
||||
response = await metadata_provider.get_model_versions(model_id)
|
||||
except RateLimitError:
|
||||
raise
|
||||
except Exception as exc: # pragma: no cover - defensive log
|
||||
logger.error(
|
||||
"Failed to fetch versions for model %s (%s): %s",
|
||||
model_id,
|
||||
model_type,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
fetched_versions = self._extract_version_ids(response)
|
||||
|
||||
async with self._lock:
|
||||
existing = self._get_record(model_type, model_id)
|
||||
if existing and existing.should_ignore and not force_refresh:
|
||||
# Ignore state could have flipped while awaiting provider
|
||||
record = ModelUpdateRecord(
|
||||
model_type=model_type,
|
||||
model_id=model_id,
|
||||
largest_version_id=existing.largest_version_id,
|
||||
version_ids=list(existing.version_ids),
|
||||
in_library_version_ids=normalized_local,
|
||||
last_checked_at=existing.last_checked_at,
|
||||
should_ignore=True,
|
||||
)
|
||||
self._upsert_record(record)
|
||||
return record
|
||||
|
||||
version_ids = (
|
||||
fetched_versions
|
||||
if fetched_versions is not None
|
||||
else (list(existing.version_ids) if existing else [])
|
||||
)
|
||||
largest = max(version_ids) if version_ids else None
|
||||
last_checked = now if fetched_versions is not None else (
|
||||
existing.last_checked_at if existing else None
|
||||
)
|
||||
record = ModelUpdateRecord(
|
||||
model_type=model_type,
|
||||
model_id=model_id,
|
||||
largest_version_id=largest,
|
||||
version_ids=version_ids,
|
||||
in_library_version_ids=normalized_local,
|
||||
last_checked_at=last_checked,
|
||||
should_ignore=existing.should_ignore if existing else False,
|
||||
)
|
||||
self._upsert_record(record)
|
||||
return record
|
||||
|
||||
async def _collect_local_versions(self, scanner) -> Dict[int, List[int]]:
|
||||
cache = await scanner.get_cached_data()
|
||||
mapping: Dict[int, set[int]] = {}
|
||||
if not cache or not getattr(cache, "raw_data", None):
|
||||
return {}
|
||||
|
||||
for item in cache.raw_data:
|
||||
civitai = item.get("civitai") if isinstance(item, dict) else None
|
||||
if not isinstance(civitai, dict):
|
||||
continue
|
||||
model_id = self._normalize_int(civitai.get("modelId"))
|
||||
version_id = self._normalize_int(civitai.get("id"))
|
||||
if model_id is None or version_id is None:
|
||||
continue
|
||||
mapping.setdefault(model_id, set()).add(version_id)
|
||||
|
||||
return {model_id: sorted(ids) for model_id, ids in mapping.items()}
|
||||
|
||||
def _is_stale(self, record: ModelUpdateRecord, now: float) -> bool:
|
||||
if record.last_checked_at is None:
|
||||
return True
|
||||
return (now - record.last_checked_at) >= self._ttl_seconds
|
||||
|
||||
@staticmethod
|
||||
def _normalize_int(value) -> Optional[int]:
|
||||
try:
|
||||
if value is None:
|
||||
return None
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
def _normalize_sequence(self, values: Sequence[int]) -> List[int]:
|
||||
normalized = [
|
||||
item
|
||||
for item in (self._normalize_int(value) for value in values)
|
||||
if item is not None
|
||||
]
|
||||
return sorted(dict.fromkeys(normalized))
|
||||
|
||||
def _extract_version_ids(self, response) -> List[int]:
|
||||
if not isinstance(response, Mapping):
|
||||
return []
|
||||
versions = response.get("modelVersions")
|
||||
if not isinstance(versions, Iterable):
|
||||
return []
|
||||
normalized = []
|
||||
for entry in versions:
|
||||
if isinstance(entry, Mapping):
|
||||
normalized_id = self._normalize_int(entry.get("id"))
|
||||
else:
|
||||
normalized_id = self._normalize_int(entry)
|
||||
if normalized_id is not None:
|
||||
normalized.append(normalized_id)
|
||||
return sorted(dict.fromkeys(normalized))
|
||||
|
||||
def _get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT model_type, model_id, largest_version_id, version_ids,
|
||||
in_library_version_ids, last_checked_at, should_ignore
|
||||
FROM model_update_status
|
||||
WHERE model_type = ? AND model_id = ?
|
||||
""",
|
||||
(model_type, model_id),
|
||||
).fetchone()
|
||||
if not row:
|
||||
return None
|
||||
return ModelUpdateRecord(
|
||||
model_type=row["model_type"],
|
||||
model_id=int(row["model_id"]),
|
||||
largest_version_id=self._normalize_int(row["largest_version_id"]),
|
||||
version_ids=self._deserialize_json_array(row["version_ids"]),
|
||||
in_library_version_ids=self._deserialize_json_array(
|
||||
row["in_library_version_ids"]
|
||||
),
|
||||
last_checked_at=row["last_checked_at"],
|
||||
should_ignore=bool(row["should_ignore"]),
|
||||
)
|
||||
|
||||
def _upsert_record(self, record: ModelUpdateRecord) -> None:
|
||||
payload = (
|
||||
record.model_type,
|
||||
record.model_id,
|
||||
record.largest_version_id,
|
||||
json.dumps(record.version_ids),
|
||||
json.dumps(record.in_library_version_ids),
|
||||
record.last_checked_at,
|
||||
1 if record.should_ignore else 0,
|
||||
)
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO model_update_status (
|
||||
model_type, model_id, largest_version_id, version_ids,
|
||||
in_library_version_ids, last_checked_at, should_ignore
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(model_type, model_id) DO UPDATE SET
|
||||
largest_version_id = excluded.largest_version_id,
|
||||
version_ids = excluded.version_ids,
|
||||
in_library_version_ids = excluded.in_library_version_ids,
|
||||
last_checked_at = excluded.last_checked_at,
|
||||
should_ignore = excluded.should_ignore
|
||||
""",
|
||||
payload,
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_json_array(value) -> List[int]:
|
||||
if not value:
|
||||
return []
|
||||
try:
|
||||
data = json.loads(value)
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
return []
|
||||
if isinstance(data, list):
|
||||
normalized = []
|
||||
for entry in data:
|
||||
try:
|
||||
normalized.append(int(entry))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
return sorted(dict.fromkeys(normalized))
|
||||
return []
|
||||
|
||||
@@ -81,6 +81,11 @@ class PersistentModelCache:
|
||||
def is_enabled(self) -> bool:
|
||||
return os.environ.get("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "0") != "1"
|
||||
|
||||
def get_database_path(self) -> str:
|
||||
"""Expose the resolved SQLite database path."""
|
||||
|
||||
return self._db_path
|
||||
|
||||
def load_cache(self, model_type: str) -> Optional[PersistedCacheData]:
|
||||
if not self.is_enabled():
|
||||
return None
|
||||
|
||||
@@ -128,23 +128,45 @@ class ServiceRegistry:
|
||||
async def get_civitai_client(cls):
|
||||
"""Get or create CivitAI client instance"""
|
||||
service_name = "civitai_client"
|
||||
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .civitai_client import CivitaiClient
|
||||
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
cls._services[service_name] = client
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
async def get_model_update_service(cls):
|
||||
"""Get or create the model update tracking service."""
|
||||
|
||||
service_name = "model_update_service"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
from .model_update_service import ModelUpdateService
|
||||
from .persistent_model_cache import get_persistent_cache
|
||||
|
||||
cache = get_persistent_cache()
|
||||
service = ModelUpdateService(cache.get_database_path())
|
||||
cls._services[service_name] = service
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return service
|
||||
|
||||
@classmethod
|
||||
async def get_civarchive_client(cls):
|
||||
"""Get or create CivArchive client instance"""
|
||||
|
||||
Reference in New Issue
Block a user