mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 06:32:12 -03:00
feat: add update service dependency and has_update filter
- Pass ModelUpdateService to CheckpointService, EmbeddingService, and LoraService constructors - Add has_update query parameter filter to model listing handler - Update BaseModelService to accept optional update_service parameter These changes enable model update functionality across different model types and provide filtering capability for models with available updates.
This commit is contained in:
@@ -20,8 +20,8 @@ class CheckpointRoutes(BaseModelRoutes):
|
|||||||
async def initialize_services(self):
|
async def initialize_services(self):
|
||||||
"""Initialize services from ServiceRegistry"""
|
"""Initialize services from ServiceRegistry"""
|
||||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
self.service = CheckpointService(checkpoint_scanner)
|
|
||||||
update_service = await ServiceRegistry.get_model_update_service()
|
update_service = await ServiceRegistry.get_model_update_service()
|
||||||
|
self.service = CheckpointService(checkpoint_scanner, update_service=update_service)
|
||||||
self.set_model_update_service(update_service)
|
self.set_model_update_service(update_service)
|
||||||
|
|
||||||
# Attach service dependencies
|
# Attach service dependencies
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ class EmbeddingRoutes(BaseModelRoutes):
|
|||||||
async def initialize_services(self):
|
async def initialize_services(self):
|
||||||
"""Initialize services from ServiceRegistry"""
|
"""Initialize services from ServiceRegistry"""
|
||||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||||
self.service = EmbeddingService(embedding_scanner)
|
|
||||||
update_service = await ServiceRegistry.get_model_update_service()
|
update_service = await ServiceRegistry.get_model_update_service()
|
||||||
|
self.service = EmbeddingService(embedding_scanner, update_service=update_service)
|
||||||
self.set_model_update_service(update_service)
|
self.set_model_update_service(update_service)
|
||||||
|
|
||||||
# Attach service dependencies
|
# Attach service dependencies
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Awaitable, Callable, Dict, Iterable, Mapping, Optional
|
from typing import Awaitable, Callable, Dict, Iterable, List, Mapping, Optional
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import jinja2
|
import jinja2
|
||||||
@@ -166,6 +166,11 @@ class ModelListingHandler:
|
|||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
has_update = request.query.get("has_update", "false")
|
||||||
|
has_update_filter = (
|
||||||
|
has_update.lower() in {"1", "true", "yes"} if isinstance(has_update, str) else False
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"page": page,
|
"page": page,
|
||||||
"page_size": page_size,
|
"page_size": page_size,
|
||||||
@@ -178,6 +183,7 @@ class ModelListingHandler:
|
|||||||
"search_options": search_options,
|
"search_options": search_options,
|
||||||
"hash_filters": hash_filters,
|
"hash_filters": hash_filters,
|
||||||
"favorites_only": favorites_only,
|
"favorites_only": favorites_only,
|
||||||
|
"has_update": has_update_filter,
|
||||||
**self._parse_specific_params(request),
|
**self._parse_specific_params(request),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,8 +23,8 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
async def initialize_services(self):
|
async def initialize_services(self):
|
||||||
"""Initialize services from ServiceRegistry"""
|
"""Initialize services from ServiceRegistry"""
|
||||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
self.service = LoraService(lora_scanner)
|
|
||||||
update_service = await ServiceRegistry.get_model_update_service()
|
update_service = await ServiceRegistry.get_model_update_service()
|
||||||
|
self.service = LoraService(lora_scanner, update_service=update_service)
|
||||||
self.set_model_update_service(update_service)
|
self.set_model_update_service(update_service)
|
||||||
|
|
||||||
# Attach service dependencies
|
# Attach service dependencies
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, List, Optional, Type
|
import asyncio
|
||||||
|
from typing import Dict, List, Optional, Type, TYPE_CHECKING
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@@ -10,6 +11,9 @@ from .settings_manager import get_settings_manager
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .model_update_service import ModelUpdateService
|
||||||
|
|
||||||
class BaseModelService(ABC):
|
class BaseModelService(ABC):
|
||||||
"""Base service class for all model types"""
|
"""Base service class for all model types"""
|
||||||
|
|
||||||
@@ -23,6 +27,7 @@ class BaseModelService(ABC):
|
|||||||
filter_set: Optional[ModelFilterSet] = None,
|
filter_set: Optional[ModelFilterSet] = None,
|
||||||
search_strategy: Optional[SearchStrategy] = None,
|
search_strategy: Optional[SearchStrategy] = None,
|
||||||
settings_provider: Optional[SettingsProvider] = None,
|
settings_provider: Optional[SettingsProvider] = None,
|
||||||
|
update_service: Optional["ModelUpdateService"] = None,
|
||||||
):
|
):
|
||||||
"""Initialize the service.
|
"""Initialize the service.
|
||||||
|
|
||||||
@@ -34,6 +39,7 @@ class BaseModelService(ABC):
|
|||||||
filter_set: Filter component controlling folder/tag/favorites logic.
|
filter_set: Filter component controlling folder/tag/favorites logic.
|
||||||
search_strategy: Search component for fuzzy/text matching.
|
search_strategy: Search component for fuzzy/text matching.
|
||||||
settings_provider: Settings object; defaults to the global settings manager.
|
settings_provider: Settings object; defaults to the global settings manager.
|
||||||
|
update_service: Service used to determine whether models have remote updates available.
|
||||||
"""
|
"""
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.scanner = scanner
|
self.scanner = scanner
|
||||||
@@ -42,6 +48,7 @@ class BaseModelService(ABC):
|
|||||||
self.cache_repository = cache_repository or ModelCacheRepository(scanner)
|
self.cache_repository = cache_repository or ModelCacheRepository(scanner)
|
||||||
self.filter_set = filter_set or ModelFilterSet(self.settings)
|
self.filter_set = filter_set or ModelFilterSet(self.settings)
|
||||||
self.search_strategy = search_strategy or SearchStrategy()
|
self.search_strategy = search_strategy or SearchStrategy()
|
||||||
|
self.update_service = update_service
|
||||||
|
|
||||||
async def get_paginated_data(
|
async def get_paginated_data(
|
||||||
self,
|
self,
|
||||||
@@ -56,6 +63,7 @@ class BaseModelService(ABC):
|
|||||||
search_options: dict = None,
|
search_options: dict = None,
|
||||||
hash_filters: dict = None,
|
hash_filters: dict = None,
|
||||||
favorites_only: bool = False,
|
favorites_only: bool = False,
|
||||||
|
has_update: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""Get paginated and filtered model data"""
|
"""Get paginated and filtered model data"""
|
||||||
@@ -85,6 +93,9 @@ class BaseModelService(ABC):
|
|||||||
|
|
||||||
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
||||||
|
|
||||||
|
if has_update:
|
||||||
|
filtered_data = await self._apply_update_filter(filtered_data)
|
||||||
|
|
||||||
return self._paginate(filtered_data, page, page_size)
|
return self._paginate(filtered_data, page, page_size)
|
||||||
|
|
||||||
|
|
||||||
@@ -145,6 +156,59 @@ class BaseModelService(ABC):
|
|||||||
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
async def _apply_update_filter(self, data: List[Dict]) -> List[Dict]:
|
||||||
|
"""Filter models to those with remote updates available when requested."""
|
||||||
|
if not data:
|
||||||
|
return []
|
||||||
|
if self.update_service is None:
|
||||||
|
logger.warning(
|
||||||
|
"Requested has_update filter for %s models but update service is unavailable",
|
||||||
|
self.model_type,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
candidates: List[tuple[Dict, int]] = []
|
||||||
|
for item in data:
|
||||||
|
model_id = self._extract_model_id(item)
|
||||||
|
if model_id is not None:
|
||||||
|
candidates.append((item, model_id))
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
return []
|
||||||
|
|
||||||
|
tasks = [
|
||||||
|
self.update_service.has_update(self.model_type, model_id)
|
||||||
|
for _, model_id in candidates
|
||||||
|
]
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
filtered: List[Dict] = []
|
||||||
|
for (item, model_id), result in zip(candidates, results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.error(
|
||||||
|
"Failed to resolve update status for model %s (%s): %s",
|
||||||
|
model_id,
|
||||||
|
self.model_type,
|
||||||
|
result,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
if result:
|
||||||
|
filtered.append(item)
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_model_id(item: Dict) -> Optional[int]:
|
||||||
|
civitai = item.get('civitai') if isinstance(item, dict) else None
|
||||||
|
if not isinstance(civitai, dict):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
value = civitai.get('modelId')
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
|
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
|
||||||
"""Apply pagination to filtered data"""
|
"""Apply pagination to filtered data"""
|
||||||
total_items = len(data)
|
total_items = len(data)
|
||||||
|
|||||||
@@ -11,13 +11,14 @@ logger = logging.getLogger(__name__)
|
|||||||
class CheckpointService(BaseModelService):
|
class CheckpointService(BaseModelService):
|
||||||
"""Checkpoint-specific service implementation"""
|
"""Checkpoint-specific service implementation"""
|
||||||
|
|
||||||
def __init__(self, scanner):
|
def __init__(self, scanner, update_service=None):
|
||||||
"""Initialize Checkpoint service
|
"""Initialize Checkpoint service
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
scanner: Checkpoint scanner instance
|
scanner: Checkpoint scanner instance
|
||||||
|
update_service: Optional service for remote update tracking.
|
||||||
"""
|
"""
|
||||||
super().__init__("checkpoint", scanner, CheckpointMetadata)
|
super().__init__("checkpoint", scanner, CheckpointMetadata, update_service=update_service)
|
||||||
|
|
||||||
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
||||||
"""Format Checkpoint data for API response"""
|
"""Format Checkpoint data for API response"""
|
||||||
|
|||||||
@@ -11,13 +11,14 @@ logger = logging.getLogger(__name__)
|
|||||||
class EmbeddingService(BaseModelService):
|
class EmbeddingService(BaseModelService):
|
||||||
"""Embedding-specific service implementation"""
|
"""Embedding-specific service implementation"""
|
||||||
|
|
||||||
def __init__(self, scanner):
|
def __init__(self, scanner, update_service=None):
|
||||||
"""Initialize Embedding service
|
"""Initialize Embedding service
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
scanner: Embedding scanner instance
|
scanner: Embedding scanner instance
|
||||||
|
update_service: Optional service for remote update tracking.
|
||||||
"""
|
"""
|
||||||
super().__init__("embedding", scanner, EmbeddingMetadata)
|
super().__init__("embedding", scanner, EmbeddingMetadata, update_service=update_service)
|
||||||
|
|
||||||
async def format_response(self, embedding_data: Dict) -> Dict:
|
async def format_response(self, embedding_data: Dict) -> Dict:
|
||||||
"""Format Embedding data for API response"""
|
"""Format Embedding data for API response"""
|
||||||
|
|||||||
@@ -11,13 +11,14 @@ logger = logging.getLogger(__name__)
|
|||||||
class LoraService(BaseModelService):
|
class LoraService(BaseModelService):
|
||||||
"""LoRA-specific service implementation"""
|
"""LoRA-specific service implementation"""
|
||||||
|
|
||||||
def __init__(self, scanner):
|
def __init__(self, scanner, update_service=None):
|
||||||
"""Initialize LoRA service
|
"""Initialize LoRA service
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
scanner: LoRA scanner instance
|
scanner: LoRA scanner instance
|
||||||
|
update_service: Optional service for remote update tracking.
|
||||||
"""
|
"""
|
||||||
super().__init__("lora", scanner, LoraMetadata)
|
super().__init__("lora", scanner, LoraMetadata, update_service=update_service)
|
||||||
|
|
||||||
async def format_response(self, lora_data: Dict) -> Dict:
|
async def format_response(self, lora_data: Dict) -> Dict:
|
||||||
"""Format LoRA data for API response"""
|
"""Format LoRA data for API response"""
|
||||||
|
|||||||
@@ -67,6 +67,19 @@ class StubSearchStrategy:
|
|||||||
return list(self.search_result)
|
return list(self.search_result)
|
||||||
|
|
||||||
|
|
||||||
|
class StubUpdateService:
|
||||||
|
def __init__(self, decisions):
|
||||||
|
self.decisions = dict(decisions)
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
async def has_update(self, model_type, model_id):
|
||||||
|
self.calls.append((model_type, model_id))
|
||||||
|
result = self.decisions.get(model_id, False)
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
raise result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_paginated_data_uses_injected_collaborators():
|
async def test_get_paginated_data_uses_injected_collaborators():
|
||||||
data = [
|
data = [
|
||||||
@@ -272,3 +285,111 @@ async def test_get_paginated_data_paginates_without_search():
|
|||||||
assert response["page"] == 2
|
assert response["page"] == 2
|
||||||
assert response["page_size"] == 2
|
assert response["page_size"] == 2
|
||||||
assert response["total_pages"] == 3
|
assert response["total_pages"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_paginated_data_filters_by_update_status():
|
||||||
|
items = [
|
||||||
|
{"model_name": "A", "civitai": {"modelId": 1}},
|
||||||
|
{"model_name": "B", "civitai": {"modelId": 2}},
|
||||||
|
{"model_name": "C", "civitai": {"modelId": 3}},
|
||||||
|
]
|
||||||
|
repository = StubRepository(items)
|
||||||
|
filter_set = PassThroughFilterSet()
|
||||||
|
search_strategy = NoSearchStrategy()
|
||||||
|
update_service = StubUpdateService({1: True, 2: False, 3: True})
|
||||||
|
settings = StubSettings({})
|
||||||
|
|
||||||
|
service = DummyService(
|
||||||
|
model_type="stub",
|
||||||
|
scanner=object(),
|
||||||
|
metadata_class=BaseModelMetadata,
|
||||||
|
cache_repository=repository,
|
||||||
|
filter_set=filter_set,
|
||||||
|
search_strategy=search_strategy,
|
||||||
|
settings_provider=settings,
|
||||||
|
update_service=update_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await service.get_paginated_data(
|
||||||
|
page=1,
|
||||||
|
page_size=5,
|
||||||
|
sort_by="name:asc",
|
||||||
|
has_update=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert update_service.calls == [("stub", 1), ("stub", 2), ("stub", 3)]
|
||||||
|
assert response["items"] == [items[0], items[2]]
|
||||||
|
assert response["total"] == 2
|
||||||
|
assert response["page"] == 1
|
||||||
|
assert response["page_size"] == 5
|
||||||
|
assert response["total_pages"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_paginated_data_has_update_without_service_returns_empty():
|
||||||
|
items = [
|
||||||
|
{"model_name": "A", "civitai": {"modelId": 1}},
|
||||||
|
{"model_name": "B", "civitai": {"modelId": 2}},
|
||||||
|
]
|
||||||
|
repository = StubRepository(items)
|
||||||
|
filter_set = PassThroughFilterSet()
|
||||||
|
search_strategy = NoSearchStrategy()
|
||||||
|
settings = StubSettings({})
|
||||||
|
|
||||||
|
service = DummyService(
|
||||||
|
model_type="stub",
|
||||||
|
scanner=object(),
|
||||||
|
metadata_class=BaseModelMetadata,
|
||||||
|
cache_repository=repository,
|
||||||
|
filter_set=filter_set,
|
||||||
|
search_strategy=search_strategy,
|
||||||
|
settings_provider=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await service.get_paginated_data(
|
||||||
|
page=1,
|
||||||
|
page_size=10,
|
||||||
|
sort_by="name:asc",
|
||||||
|
has_update=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response["items"] == []
|
||||||
|
assert response["total"] == 0
|
||||||
|
assert response["total_pages"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_paginated_data_skips_items_when_update_check_fails():
|
||||||
|
items = [
|
||||||
|
{"model_name": "A", "civitai": {"modelId": 1}},
|
||||||
|
{"model_name": "B", "civitai": {"modelId": 2}},
|
||||||
|
]
|
||||||
|
repository = StubRepository(items)
|
||||||
|
filter_set = PassThroughFilterSet()
|
||||||
|
search_strategy = NoSearchStrategy()
|
||||||
|
update_service = StubUpdateService({1: True, 2: RuntimeError("boom")})
|
||||||
|
settings = StubSettings({})
|
||||||
|
|
||||||
|
service = DummyService(
|
||||||
|
model_type="stub",
|
||||||
|
scanner=object(),
|
||||||
|
metadata_class=BaseModelMetadata,
|
||||||
|
cache_repository=repository,
|
||||||
|
filter_set=filter_set,
|
||||||
|
search_strategy=search_strategy,
|
||||||
|
settings_provider=settings,
|
||||||
|
update_service=update_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await service.get_paginated_data(
|
||||||
|
page=1,
|
||||||
|
page_size=10,
|
||||||
|
sort_by="name:asc",
|
||||||
|
has_update=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert update_service.calls == [("stub", 1), ("stub", 2)]
|
||||||
|
assert response["items"] == [items[0]]
|
||||||
|
assert response["total"] == 1
|
||||||
|
assert response["total_pages"] == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user