feat(usage_count): sorting by usage_count + usage_count on ModelCard

This commit is contained in:
stone9k
2025-12-12 16:39:24 +01:00
parent 817de3a0ae
commit 56143eb170
18 changed files with 110 additions and 13 deletions

View File

@@ -7,6 +7,7 @@ import os
from ..utils.constants import VALID_LORA_TYPES
from ..utils.models import BaseModelMetadata
from ..utils.metadata_manager import MetadataManager
from ..utils.usage_stats import UsageStats
from .model_query import (
FilterCriteria,
ModelCacheRepository,
@@ -81,7 +82,10 @@ class BaseModelService(ABC):
"""Get paginated and filtered model data"""
sort_params = self.cache_repository.parse_sort(sort_by)
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
if sort_params.key == 'usage':
sorted_data = await self._fetch_with_usage_sort(sort_params)
else:
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
if hash_filters:
filtered_data = await self._apply_hash_filters(sorted_data, hash_filters)
@@ -132,6 +136,37 @@ class BaseModelService(ABC):
)
return paginated
async def _fetch_with_usage_sort(self, sort_params):
"""Fetch data sorted by usage count (desc/asc)."""
cache = await self.cache_repository.get_cache()
raw_items = cache.raw_data or []
# Map model type to usage stats bucket
bucket_map = {
'lora': 'loras',
'checkpoint': 'checkpoints',
# 'embedding': 'embeddings', # TODO: Enable when embedding usage tracking is implemented
}
bucket_key = bucket_map.get(self.model_type, '')
usage_stats = UsageStats()
stats = await usage_stats.get_stats()
usage_bucket = stats.get(bucket_key, {}) if bucket_key else {}
annotated = []
for item in raw_items:
sha = (item.get('sha256') or '').lower()
usage_info = usage_bucket.get(sha, {}) if isinstance(usage_bucket, dict) else {}
usage_count = usage_info.get('total', 0) if isinstance(usage_info, dict) else 0
annotated.append({**item, 'usage_count': usage_count})
reverse = sort_params.order == 'desc'
annotated.sort(
key=lambda x: (x.get('usage_count', 0), x.get('model_name', '').lower()),
reverse=reverse
)
return annotated
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
"""Apply hash-based filtering"""