mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-26 07:35:44 -03:00
feat: implement same_base update strategy for model annotations
Add support for configurable update flag strategy with new "same_base" mode that considers base model versions when determining update availability. The strategy is controlled by the "update_flag_strategy" setting. When strategy is set to "same_base": - Uses get_records_bulk instead of has_updates_bulk - Compares model versions against highest local versions per base model - Provides more granular update detection based on base model relationships Fallback to existing bulk or individual update checks when: - Strategy is not "same_base" - Bulk operations fail - Records are unavailable This enables more precise update flagging for models sharing common bases.
This commit is contained in:
@@ -267,20 +267,49 @@ class BaseModelService(ABC):
|
|||||||
if not ordered_ids:
|
if not ordered_ids:
|
||||||
return annotated
|
return annotated
|
||||||
|
|
||||||
|
strategy_value = self.settings.get("update_flag_strategy")
|
||||||
|
if isinstance(strategy_value, str) and strategy_value.strip():
|
||||||
|
strategy = strategy_value.strip().lower()
|
||||||
|
else:
|
||||||
|
strategy = "any"
|
||||||
|
same_base_mode = strategy == "same_base"
|
||||||
|
|
||||||
|
records = None
|
||||||
resolved: Optional[Dict[int, bool]] = None
|
resolved: Optional[Dict[int, bool]] = None
|
||||||
bulk_method = getattr(self.update_service, "has_updates_bulk", None)
|
if same_base_mode:
|
||||||
if callable(bulk_method):
|
record_method = getattr(self.update_service, "get_records_bulk", None)
|
||||||
try:
|
if callable(record_method):
|
||||||
resolved = await bulk_method(self.model_type, ordered_ids)
|
try:
|
||||||
except Exception as exc:
|
records = await record_method(self.model_type, ordered_ids)
|
||||||
logger.error(
|
resolved = {
|
||||||
"Failed to resolve update status in bulk for %s models (%s): %s",
|
model_id: record.has_update()
|
||||||
self.model_type,
|
for model_id, record in records.items()
|
||||||
ordered_ids,
|
}
|
||||||
exc,
|
except Exception as exc:
|
||||||
exc_info=True,
|
logger.error(
|
||||||
)
|
"Failed to resolve update records in bulk for %s models (%s): %s",
|
||||||
resolved = None
|
self.model_type,
|
||||||
|
ordered_ids,
|
||||||
|
exc,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
records = None
|
||||||
|
resolved = None
|
||||||
|
|
||||||
|
if resolved is None:
|
||||||
|
bulk_method = getattr(self.update_service, "has_updates_bulk", None)
|
||||||
|
if callable(bulk_method):
|
||||||
|
try:
|
||||||
|
resolved = await bulk_method(self.model_type, ordered_ids)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"Failed to resolve update status in bulk for %s models (%s): %s",
|
||||||
|
self.model_type,
|
||||||
|
ordered_ids,
|
||||||
|
exc,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
resolved = None
|
||||||
|
|
||||||
if resolved is None:
|
if resolved is None:
|
||||||
tasks = [
|
tasks = [
|
||||||
@@ -301,8 +330,24 @@ class BaseModelService(ABC):
|
|||||||
resolved[model_id] = bool(result)
|
resolved[model_id] = bool(result)
|
||||||
|
|
||||||
for model_id, items_for_id in id_to_items.items():
|
for model_id, items_for_id in id_to_items.items():
|
||||||
flag = bool(resolved.get(model_id, False))
|
default_flag = bool(resolved.get(model_id, False)) if resolved else False
|
||||||
|
record = records.get(model_id) if records else None
|
||||||
|
base_highest_versions = (
|
||||||
|
self._build_highest_local_versions_by_base(record) if same_base_mode and record else {}
|
||||||
|
)
|
||||||
for item in items_for_id:
|
for item in items_for_id:
|
||||||
|
if same_base_mode and record is not None:
|
||||||
|
base_model = self._extract_base_model(item)
|
||||||
|
normalized_base = self._normalize_base_model_name(base_model)
|
||||||
|
threshold_version = base_highest_versions.get(normalized_base) if normalized_base else None
|
||||||
|
if threshold_version is None:
|
||||||
|
threshold_version = self._extract_version_id(item)
|
||||||
|
flag = record.has_update_for_base(
|
||||||
|
threshold_version,
|
||||||
|
base_model,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
flag = default_flag
|
||||||
item['update_available'] = flag
|
item['update_available'] = flag
|
||||||
|
|
||||||
return annotated
|
return annotated
|
||||||
@@ -320,6 +365,70 @@ class BaseModelService(ABC):
|
|||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_version_id(item: Dict) -> Optional[int]:
|
||||||
|
civitai = item.get('civitai') if isinstance(item, dict) else None
|
||||||
|
if not isinstance(civitai, dict):
|
||||||
|
return None
|
||||||
|
value = civitai.get('id')
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_base_model(item: Dict) -> Optional[str]:
|
||||||
|
value = item.get('base_model')
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, str):
|
||||||
|
candidate = value.strip()
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
candidate = str(value).strip()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
return candidate if candidate else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_base_model_name(value: Optional[str]) -> Optional[str]:
|
||||||
|
"""Return a lowercased, trimmed base model name for comparison."""
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, str):
|
||||||
|
candidate = value.strip()
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
candidate = str(value).strip()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
return candidate.lower() if candidate else None
|
||||||
|
|
||||||
|
def _build_highest_local_versions_by_base(self, record) -> Dict[str, int]:
|
||||||
|
"""Return the highest local version id known for each normalized base model."""
|
||||||
|
|
||||||
|
if record is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
highest_by_base: Dict[str, int] = {}
|
||||||
|
for version in getattr(record, "versions", []):
|
||||||
|
if not getattr(version, "is_in_library", False):
|
||||||
|
continue
|
||||||
|
normalized_base = self._normalize_base_model_name(getattr(version, "base_model", None))
|
||||||
|
if normalized_base is None:
|
||||||
|
continue
|
||||||
|
version_id = getattr(version, "version_id", None)
|
||||||
|
if version_id is None:
|
||||||
|
continue
|
||||||
|
current_max = highest_by_base.get(normalized_base)
|
||||||
|
if current_max is None or version_id > current_max:
|
||||||
|
highest_by_base[normalized_base] = version_id
|
||||||
|
|
||||||
|
return highest_by_base
|
||||||
|
|
||||||
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
|
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
|
||||||
"""Apply pagination to filtered data"""
|
"""Apply pagination to filtered data"""
|
||||||
total_items = len(data)
|
total_items = len(data)
|
||||||
|
|||||||
@@ -17,6 +17,41 @@ from ..utils.preview_selection import select_preview_media
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_int(value) -> Optional[int]:
|
||||||
|
"""Safely convert a value to an integer."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_string(value) -> Optional[str]:
|
||||||
|
"""Return a stripped string or None if the value is empty."""
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, str):
|
||||||
|
stripped = value.strip()
|
||||||
|
return stripped or None
|
||||||
|
try:
|
||||||
|
normalized = str(value).strip()
|
||||||
|
return normalized or None
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_base_model(value) -> Optional[str]:
|
||||||
|
"""Normalize base-model names for case-insensitive comparison."""
|
||||||
|
|
||||||
|
normalized = _normalize_string(value)
|
||||||
|
if normalized is None:
|
||||||
|
return None
|
||||||
|
return normalized.lower()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelVersionRecord:
|
class ModelVersionRecord:
|
||||||
"""Persisted metadata for a single model version."""
|
"""Persisted metadata for a single model version."""
|
||||||
@@ -85,6 +120,47 @@ class ModelUpdateRecord:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def has_update_for_base(
|
||||||
|
self,
|
||||||
|
local_version_id: Optional[int],
|
||||||
|
local_base_model: Optional[str],
|
||||||
|
) -> bool:
|
||||||
|
"""Return True when a newer remote version with the same base model exists."""
|
||||||
|
|
||||||
|
if self.should_ignore_model:
|
||||||
|
return False
|
||||||
|
|
||||||
|
normalized_base = _normalize_base_model(local_base_model)
|
||||||
|
if normalized_base is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
threshold = _normalize_int(local_version_id)
|
||||||
|
if threshold is None:
|
||||||
|
highest_local = None
|
||||||
|
for version in self.versions:
|
||||||
|
if not version.is_in_library:
|
||||||
|
continue
|
||||||
|
version_base = _normalize_base_model(version.base_model)
|
||||||
|
if version_base != normalized_base:
|
||||||
|
continue
|
||||||
|
if highest_local is None or version.version_id > highest_local:
|
||||||
|
highest_local = version.version_id
|
||||||
|
threshold = highest_local
|
||||||
|
|
||||||
|
if threshold is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for version in self.versions:
|
||||||
|
if version.is_in_library or version.should_ignore:
|
||||||
|
continue
|
||||||
|
version_base = _normalize_base_model(version.base_model)
|
||||||
|
if version_base != normalized_base:
|
||||||
|
continue
|
||||||
|
if version.version_id > threshold:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class ModelUpdateService:
|
class ModelUpdateService:
|
||||||
"""Persist and query remote model version metadata."""
|
"""Persist and query remote model version metadata."""
|
||||||
@@ -628,6 +704,20 @@ class ModelUpdateService:
|
|||||||
for model_id in normalized_ids
|
for model_id in normalized_ids
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def get_records_bulk(
|
||||||
|
self,
|
||||||
|
model_type: str,
|
||||||
|
model_ids: Sequence[int],
|
||||||
|
) -> Dict[int, ModelUpdateRecord]:
|
||||||
|
"""Return cached update records for the requested models."""
|
||||||
|
|
||||||
|
normalized_ids = self._normalize_sequence(model_ids)
|
||||||
|
if not normalized_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
return self._get_records_bulk(model_type, normalized_ids)
|
||||||
|
|
||||||
async def _refresh_single_model(
|
async def _refresh_single_model(
|
||||||
self,
|
self,
|
||||||
model_type: str,
|
model_type: str,
|
||||||
@@ -799,7 +889,7 @@ class ModelUpdateService:
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
for key, value in response.items():
|
for key, value in response.items():
|
||||||
normalized_key = self._normalize_int(key)
|
normalized_key = _normalize_int(key)
|
||||||
if normalized_key is None:
|
if normalized_key is None:
|
||||||
continue
|
continue
|
||||||
if isinstance(value, Mapping):
|
if isinstance(value, Mapping):
|
||||||
@@ -832,8 +922,8 @@ class ModelUpdateService:
|
|||||||
civitai = item.get("civitai") if isinstance(item, dict) else None
|
civitai = item.get("civitai") if isinstance(item, dict) else None
|
||||||
if not isinstance(civitai, dict):
|
if not isinstance(civitai, dict):
|
||||||
continue
|
continue
|
||||||
model_id = self._normalize_int(civitai.get("modelId"))
|
model_id = _normalize_int(civitai.get("modelId"))
|
||||||
version_id = self._normalize_int(civitai.get("id"))
|
version_id = _normalize_int(civitai.get("id"))
|
||||||
if model_id is None or version_id is None:
|
if model_id is None or version_id is None:
|
||||||
continue
|
continue
|
||||||
if target_set is not None and model_id not in target_set:
|
if target_set is not None and model_id not in target_set:
|
||||||
@@ -973,35 +1063,14 @@ class ModelUpdateService:
|
|||||||
return True
|
return True
|
||||||
return (now - record.last_checked_at) >= self._ttl_seconds
|
return (now - record.last_checked_at) >= self._ttl_seconds
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _normalize_int(value) -> Optional[int]:
|
|
||||||
try:
|
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
return int(value)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _normalize_sequence(self, values: Sequence[int]) -> List[int]:
|
def _normalize_sequence(self, values: Sequence[int]) -> List[int]:
|
||||||
normalized = [
|
normalized = [
|
||||||
item
|
item
|
||||||
for item in (self._normalize_int(value) for value in values)
|
for item in (_normalize_int(value) for value in values)
|
||||||
if item is not None
|
if item is not None
|
||||||
]
|
]
|
||||||
return sorted(dict.fromkeys(normalized))
|
return sorted(dict.fromkeys(normalized))
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _normalize_string(value) -> Optional[str]:
|
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
if isinstance(value, str):
|
|
||||||
stripped = value.strip()
|
|
||||||
return stripped or None
|
|
||||||
try:
|
|
||||||
return str(value)
|
|
||||||
except Exception: # pragma: no cover - defensive conversion
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _extract_versions(self, response) -> Optional[List[ModelVersionRecord]]:
|
def _extract_versions(self, response) -> Optional[List[ModelVersionRecord]]:
|
||||||
if not isinstance(response, Mapping):
|
if not isinstance(response, Mapping):
|
||||||
return None
|
return None
|
||||||
@@ -1014,12 +1083,12 @@ class ModelUpdateService:
|
|||||||
for index, entry in enumerate(versions):
|
for index, entry in enumerate(versions):
|
||||||
if not isinstance(entry, Mapping):
|
if not isinstance(entry, Mapping):
|
||||||
continue
|
continue
|
||||||
version_id = self._normalize_int(entry.get("id"))
|
version_id = _normalize_int(entry.get("id"))
|
||||||
if version_id is None:
|
if version_id is None:
|
||||||
continue
|
continue
|
||||||
name = self._normalize_string(entry.get("name"))
|
name = _normalize_string(entry.get("name"))
|
||||||
base_model = self._normalize_string(entry.get("baseModel"))
|
base_model = _normalize_string(entry.get("baseModel"))
|
||||||
released_at = self._normalize_string(entry.get("publishedAt") or entry.get("createdAt"))
|
released_at = _normalize_string(entry.get("publishedAt") or entry.get("createdAt"))
|
||||||
size_bytes = self._extract_size_bytes(entry.get("files"))
|
size_bytes = self._extract_size_bytes(entry.get("files"))
|
||||||
preview_url = self._extract_preview_url(entry.get("images"))
|
preview_url = self._extract_preview_url(entry.get("images"))
|
||||||
extracted.append(
|
extracted.append(
|
||||||
@@ -1152,11 +1221,11 @@ class ModelUpdateService:
|
|||||||
name=row["name"],
|
name=row["name"],
|
||||||
base_model=row["base_model"],
|
base_model=row["base_model"],
|
||||||
released_at=row["released_at"],
|
released_at=row["released_at"],
|
||||||
size_bytes=self._normalize_int(row["size_bytes"]),
|
size_bytes=_normalize_int(row["size_bytes"]),
|
||||||
preview_url=row["preview_url"],
|
preview_url=row["preview_url"],
|
||||||
is_in_library=bool(row["is_in_library"]),
|
is_in_library=bool(row["is_in_library"]),
|
||||||
should_ignore=bool(row["should_ignore"]),
|
should_ignore=bool(row["should_ignore"]),
|
||||||
sort_index=self._normalize_int(row["sort_index"]) or 0,
|
sort_index=_normalize_int(row["sort_index"]) or 0,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ DEFAULT_SETTINGS: Dict[str, Any] = {
|
|||||||
"priority_tags": DEFAULT_PRIORITY_TAG_CONFIG.copy(),
|
"priority_tags": DEFAULT_PRIORITY_TAG_CONFIG.copy(),
|
||||||
"model_name_display": "model_name",
|
"model_name_display": "model_name",
|
||||||
"model_card_footer_action": "example_images",
|
"model_card_footer_action": "example_images",
|
||||||
|
"update_flag_strategy": "any",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from py.services.model_query import (
|
|||||||
SearchStrategy,
|
SearchStrategy,
|
||||||
SortParams,
|
SortParams,
|
||||||
)
|
)
|
||||||
|
from py.services.model_update_service import ModelUpdateRecord, ModelVersionRecord
|
||||||
from py.utils.models import BaseModelMetadata
|
from py.utils.models import BaseModelMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -98,6 +99,25 @@ class StubUpdateService:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class StubUpdateServiceWithRecords(StubUpdateService):
|
||||||
|
def __init__(self, records, *, bulk_error: bool = False):
|
||||||
|
decisions = {
|
||||||
|
model_id: record.has_update()
|
||||||
|
for model_id, record in records.items()
|
||||||
|
}
|
||||||
|
super().__init__(decisions, bulk_error=bulk_error)
|
||||||
|
self.records = dict(records)
|
||||||
|
self.records_bulk_calls = []
|
||||||
|
|
||||||
|
async def get_records_bulk(self, model_type, model_ids):
|
||||||
|
self.records_bulk_calls.append((model_type, list(model_ids)))
|
||||||
|
return {
|
||||||
|
model_id: self.records[model_id]
|
||||||
|
for model_id in model_ids
|
||||||
|
if model_id in self.records
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_paginated_data_uses_injected_collaborators():
|
async def test_get_paginated_data_uses_injected_collaborators():
|
||||||
data = [
|
data = [
|
||||||
@@ -461,6 +481,198 @@ async def test_get_paginated_data_annotates_update_flags_with_bulk_dedup():
|
|||||||
assert response["total_pages"] == 1
|
assert response["total_pages"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_flag_strategy_same_base_prefers_matching_base():
|
||||||
|
items = [
|
||||||
|
{
|
||||||
|
"model_name": "Pony Version",
|
||||||
|
"civitai": {"modelId": 1, "id": 10, "baseModel": "Pony"},
|
||||||
|
"base_model": "Pony",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "Flux Version",
|
||||||
|
"civitai": {"modelId": 1, "id": 20, "baseModel": "Flux 1.D"},
|
||||||
|
"base_model": "Flux 1.D",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
repository = StubRepository(items)
|
||||||
|
filter_set = PassThroughFilterSet()
|
||||||
|
search_strategy = NoSearchStrategy()
|
||||||
|
record = ModelUpdateRecord(
|
||||||
|
model_type="stub",
|
||||||
|
model_id=1,
|
||||||
|
versions=[
|
||||||
|
ModelVersionRecord(
|
||||||
|
version_id=10,
|
||||||
|
name="Pony Local",
|
||||||
|
base_model="Pony",
|
||||||
|
released_at=None,
|
||||||
|
size_bytes=None,
|
||||||
|
preview_url=None,
|
||||||
|
is_in_library=True,
|
||||||
|
should_ignore=False,
|
||||||
|
sort_index=0,
|
||||||
|
),
|
||||||
|
ModelVersionRecord(
|
||||||
|
version_id=20,
|
||||||
|
name="Flux Local",
|
||||||
|
base_model="Flux 1.D",
|
||||||
|
released_at=None,
|
||||||
|
size_bytes=None,
|
||||||
|
preview_url=None,
|
||||||
|
is_in_library=True,
|
||||||
|
should_ignore=False,
|
||||||
|
sort_index=1,
|
||||||
|
),
|
||||||
|
ModelVersionRecord(
|
||||||
|
version_id=30,
|
||||||
|
name="Pony Remote",
|
||||||
|
base_model="Pony",
|
||||||
|
released_at=None,
|
||||||
|
size_bytes=None,
|
||||||
|
preview_url=None,
|
||||||
|
is_in_library=False,
|
||||||
|
should_ignore=False,
|
||||||
|
sort_index=2,
|
||||||
|
),
|
||||||
|
ModelVersionRecord(
|
||||||
|
version_id=40,
|
||||||
|
name="SDXL Remote",
|
||||||
|
base_model="SDXL",
|
||||||
|
released_at=None,
|
||||||
|
size_bytes=None,
|
||||||
|
preview_url=None,
|
||||||
|
is_in_library=False,
|
||||||
|
should_ignore=False,
|
||||||
|
sort_index=3,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
last_checked_at=None,
|
||||||
|
should_ignore_model=False,
|
||||||
|
)
|
||||||
|
update_service = StubUpdateServiceWithRecords({1: record})
|
||||||
|
settings = StubSettings({"update_flag_strategy": "same_base"})
|
||||||
|
|
||||||
|
service = DummyService(
|
||||||
|
model_type="stub",
|
||||||
|
scanner=object(),
|
||||||
|
metadata_class=BaseModelMetadata,
|
||||||
|
cache_repository=repository,
|
||||||
|
filter_set=filter_set,
|
||||||
|
search_strategy=search_strategy,
|
||||||
|
settings_provider=settings,
|
||||||
|
update_service=update_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await service.get_paginated_data(
|
||||||
|
page=1,
|
||||||
|
page_size=10,
|
||||||
|
sort_by="name:asc",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert update_service.records_bulk_calls == [("stub", [1])]
|
||||||
|
assert update_service.bulk_calls == []
|
||||||
|
assert len(response["items"]) == 2
|
||||||
|
flags = {item["model_name"]: item["update_available"] for item in response["items"]}
|
||||||
|
assert flags["Pony Version"] is True
|
||||||
|
assert flags["Flux Version"] is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_flag_strategy_same_base_honors_latest_local_version():
|
||||||
|
items = [
|
||||||
|
{
|
||||||
|
"model_name": "Pony v0.1",
|
||||||
|
"civitai": {"modelId": 1, "id": 101, "baseModel": "Pony"},
|
||||||
|
"base_model": "Pony",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "Pony v0.3",
|
||||||
|
"civitai": {"modelId": 1, "id": 103, "baseModel": "Pony"},
|
||||||
|
"base_model": "Pony",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
repository = StubRepository(items)
|
||||||
|
filter_set = PassThroughFilterSet()
|
||||||
|
search_strategy = NoSearchStrategy()
|
||||||
|
record = ModelUpdateRecord(
|
||||||
|
model_type="stub",
|
||||||
|
model_id=1,
|
||||||
|
versions=[
|
||||||
|
ModelVersionRecord(
|
||||||
|
version_id=101,
|
||||||
|
name="Old Pony",
|
||||||
|
base_model="Pony",
|
||||||
|
released_at=None,
|
||||||
|
size_bytes=None,
|
||||||
|
preview_url=None,
|
||||||
|
is_in_library=True,
|
||||||
|
should_ignore=False,
|
||||||
|
sort_index=0,
|
||||||
|
),
|
||||||
|
ModelVersionRecord(
|
||||||
|
version_id=102,
|
||||||
|
name="Pony Remote",
|
||||||
|
base_model="Pony",
|
||||||
|
released_at=None,
|
||||||
|
size_bytes=None,
|
||||||
|
preview_url=None,
|
||||||
|
is_in_library=False,
|
||||||
|
should_ignore=False,
|
||||||
|
sort_index=1,
|
||||||
|
),
|
||||||
|
ModelVersionRecord(
|
||||||
|
version_id=103,
|
||||||
|
name="Middle Pony",
|
||||||
|
base_model="Pony",
|
||||||
|
released_at=None,
|
||||||
|
size_bytes=None,
|
||||||
|
preview_url=None,
|
||||||
|
is_in_library=True,
|
||||||
|
should_ignore=False,
|
||||||
|
sort_index=2,
|
||||||
|
),
|
||||||
|
ModelVersionRecord(
|
||||||
|
version_id=104,
|
||||||
|
name="Latest Pony",
|
||||||
|
base_model="Pony",
|
||||||
|
released_at=None,
|
||||||
|
size_bytes=None,
|
||||||
|
preview_url=None,
|
||||||
|
is_in_library=True,
|
||||||
|
should_ignore=False,
|
||||||
|
sort_index=3,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
last_checked_at=None,
|
||||||
|
should_ignore_model=False,
|
||||||
|
)
|
||||||
|
update_service = StubUpdateServiceWithRecords({1: record})
|
||||||
|
settings = StubSettings({"update_flag_strategy": "same_base"})
|
||||||
|
|
||||||
|
service = DummyService(
|
||||||
|
model_type="stub",
|
||||||
|
scanner=object(),
|
||||||
|
metadata_class=BaseModelMetadata,
|
||||||
|
cache_repository=repository,
|
||||||
|
filter_set=filter_set,
|
||||||
|
search_strategy=search_strategy,
|
||||||
|
settings_provider=settings,
|
||||||
|
update_service=update_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await service.get_paginated_data(
|
||||||
|
page=1,
|
||||||
|
page_size=10,
|
||||||
|
sort_by="name:asc",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert update_service.records_bulk_calls == [("stub", [1])]
|
||||||
|
flags = {item["model_name"]: item["update_available"] for item in response["items"]}
|
||||||
|
assert flags["Pony v0.1"] is False
|
||||||
|
assert flags["Pony v0.3"] is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_paginated_data_filters_update_available_only():
|
async def test_get_paginated_data_filters_update_available_only():
|
||||||
items = [
|
items = [
|
||||||
|
|||||||
@@ -52,11 +52,11 @@ class NotFoundProvider:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def make_version(version_id, *, in_library, should_ignore=False):
|
def make_version(version_id, *, in_library, base_model=None, should_ignore=False):
|
||||||
return ModelVersionRecord(
|
return ModelVersionRecord(
|
||||||
version_id=version_id,
|
version_id=version_id,
|
||||||
name=None,
|
name=None,
|
||||||
base_model=None,
|
base_model=base_model,
|
||||||
released_at=None,
|
released_at=None,
|
||||||
size_bytes=None,
|
size_bytes=None,
|
||||||
preview_url=None,
|
preview_url=None,
|
||||||
@@ -147,6 +147,25 @@ def test_has_update_detects_newer_remote_version():
|
|||||||
assert record.has_update() is True
|
assert record.has_update() is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_update_for_base_matches_same_base_model():
|
||||||
|
record = make_record(
|
||||||
|
make_version(5, in_library=True, base_model="Pony"),
|
||||||
|
make_version(6, in_library=False, base_model="Pony"),
|
||||||
|
make_version(7, in_library=False, base_model="Flux.1"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert record.has_update_for_base(5, "Pony") is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_update_for_base_rejects_other_base_models():
|
||||||
|
record = make_record(
|
||||||
|
make_version(10, in_library=True, base_model="Flux"),
|
||||||
|
make_version(20, in_library=False, base_model="SDXL"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert record.has_update_for_base(10, "Flux") is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_refresh_persists_versions_and_uses_cache(tmp_path):
|
async def test_refresh_persists_versions_and_uses_cache(tmp_path):
|
||||||
db_path = tmp_path / "updates.sqlite"
|
db_path = tmp_path / "updates.sqlite"
|
||||||
|
|||||||
Reference in New Issue
Block a user