mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 16:36:45 -03:00
Compare commits
5 Commits
v1.0.5
...
0ced53c059
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0ced53c059 | ||
|
|
67ad68a23f | ||
|
|
d9ec9c512e | ||
|
|
0bcd8e09a9 | ||
|
|
fa049a28c8 |
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "Voreinstellung \"{name}\" existiert bereits. Überschreiben?",
|
||||
"presetNamePlaceholder": "Voreinstellungsname...",
|
||||
"baseModel": "Basis-Modell",
|
||||
"baseModelSearchPlaceholder": "Basismodelle durchsuchen...",
|
||||
"modelTags": "Tags (Top 20)",
|
||||
"modelTypes": "Modelltypen",
|
||||
"license": "Lizenz",
|
||||
"noCreditRequired": "Kein Credit erforderlich",
|
||||
"allowSellingGeneratedContent": "Verkauf erlaubt",
|
||||
"noTags": "Keine Tags",
|
||||
"noBaseModelMatches": "Keine Basismodelle entsprechen der aktuellen Suche.",
|
||||
"clearAll": "Alle Filter löschen",
|
||||
"any": "Beliebig",
|
||||
"all": "Alle",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "Preset \"{name}\" already exists. Overwrite?",
|
||||
"presetNamePlaceholder": "Preset name...",
|
||||
"baseModel": "Base Model",
|
||||
"baseModelSearchPlaceholder": "Search base models...",
|
||||
"modelTags": "Tags (Top 20)",
|
||||
"modelTypes": "Model Types",
|
||||
"license": "License",
|
||||
"noCreditRequired": "No Credit Required",
|
||||
"allowSellingGeneratedContent": "Allow Selling",
|
||||
"noTags": "No tags",
|
||||
"noBaseModelMatches": "No base models match the current search.",
|
||||
"clearAll": "Clear All Filters",
|
||||
"any": "Any",
|
||||
"all": "All",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "El preset \"{name}\" ya existe. ¿Sobrescribir?",
|
||||
"presetNamePlaceholder": "Nombre del preajuste...",
|
||||
"baseModel": "Modelo base",
|
||||
"baseModelSearchPlaceholder": "Buscar modelos base...",
|
||||
"modelTags": "Etiquetas (Top 20)",
|
||||
"modelTypes": "Tipos de modelos",
|
||||
"license": "Licencia",
|
||||
"noCreditRequired": "Sin crédito requerido",
|
||||
"allowSellingGeneratedContent": "Venta permitida",
|
||||
"noTags": "Sin etiquetas",
|
||||
"noBaseModelMatches": "Ningún modelo base coincide con la búsqueda actual.",
|
||||
"clearAll": "Limpiar todos los filtros",
|
||||
"any": "Cualquiera",
|
||||
"all": "Todos",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "Le préréglage \"{name}\" existe déjà. Remplacer?",
|
||||
"presetNamePlaceholder": "Nom du préréglage...",
|
||||
"baseModel": "Modèle de base",
|
||||
"baseModelSearchPlaceholder": "Rechercher des modèles de base...",
|
||||
"modelTags": "Tags (Top 20)",
|
||||
"modelTypes": "Types de modèles",
|
||||
"license": "Licence",
|
||||
"noCreditRequired": "Crédit non requis",
|
||||
"allowSellingGeneratedContent": "Vente autorisée",
|
||||
"noTags": "Aucun tag",
|
||||
"noBaseModelMatches": "Aucun modèle de base ne correspond à la recherche actuelle.",
|
||||
"clearAll": "Effacer tous les filtres",
|
||||
"any": "N'importe quel",
|
||||
"all": "Tous",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "הפריסט \"{name}\" כבר קיים. לדרוס?",
|
||||
"presetNamePlaceholder": "שם קביעה מראש...",
|
||||
"baseModel": "מודל בסיס",
|
||||
"baseModelSearchPlaceholder": "חפש מודלי בסיס...",
|
||||
"modelTags": "תגיות (20 המובילות)",
|
||||
"modelTypes": "סוגי מודלים",
|
||||
"license": "רישיון",
|
||||
"noCreditRequired": "ללא קרדיט נדרש",
|
||||
"allowSellingGeneratedContent": "אפשר מכירה",
|
||||
"noTags": "ללא תגיות",
|
||||
"noBaseModelMatches": "אין מודלי בסיס התואמים לחיפוש הנוכחי.",
|
||||
"clearAll": "נקה את כל המסננים",
|
||||
"any": "כלשהו",
|
||||
"all": "כל התגים",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "プリセット「{name}」は既に存在します。上書きしますか?",
|
||||
"presetNamePlaceholder": "プリセット名...",
|
||||
"baseModel": "ベースモデル",
|
||||
"baseModelSearchPlaceholder": "ベースモデルを検索...",
|
||||
"modelTags": "タグ(上位20)",
|
||||
"modelTypes": "モデルタイプ",
|
||||
"license": "ライセンス",
|
||||
"noCreditRequired": "クレジット不要",
|
||||
"allowSellingGeneratedContent": "販売許可",
|
||||
"noTags": "タグなし",
|
||||
"noBaseModelMatches": "現在の検索に一致するベースモデルはありません。",
|
||||
"clearAll": "すべてのフィルタをクリア",
|
||||
"any": "いずれか",
|
||||
"all": "すべて",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "프리셋 \"{name}\"이(가) 이미 존재합니다. 덮어쓰시겠습니까?",
|
||||
"presetNamePlaceholder": "프리셋 이름...",
|
||||
"baseModel": "베이스 모델",
|
||||
"baseModelSearchPlaceholder": "베이스 모델 검색...",
|
||||
"modelTags": "태그 (상위 20개)",
|
||||
"modelTypes": "모델 유형",
|
||||
"license": "라이선스",
|
||||
"noCreditRequired": "크레딧 표기 없음",
|
||||
"allowSellingGeneratedContent": "판매 허용",
|
||||
"noTags": "태그 없음",
|
||||
"noBaseModelMatches": "현재 검색과 일치하는 베이스 모델이 없습니다.",
|
||||
"clearAll": "모든 필터 지우기",
|
||||
"any": "아무",
|
||||
"all": "모두",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "Пресет \"{name}\" уже существует. Перезаписать?",
|
||||
"presetNamePlaceholder": "Имя пресета...",
|
||||
"baseModel": "Базовая модель",
|
||||
"baseModelSearchPlaceholder": "Поиск базовых моделей...",
|
||||
"modelTags": "Теги (Топ 20)",
|
||||
"modelTypes": "Типы моделей",
|
||||
"license": "Лицензия",
|
||||
"noCreditRequired": "Без указания авторства",
|
||||
"allowSellingGeneratedContent": "Продажа разрешена",
|
||||
"noTags": "Без тегов",
|
||||
"noBaseModelMatches": "Нет базовых моделей, соответствующих текущему поиску.",
|
||||
"clearAll": "Очистить все фильтры",
|
||||
"any": "Любой",
|
||||
"all": "Все",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "预设 \"{name}\" 已存在。是否覆盖?",
|
||||
"presetNamePlaceholder": "预设名称...",
|
||||
"baseModel": "基础模型",
|
||||
"baseModelSearchPlaceholder": "搜索基础模型...",
|
||||
"modelTags": "标签(前20)",
|
||||
"modelTypes": "模型类型",
|
||||
"license": "许可证",
|
||||
"noCreditRequired": "无需署名",
|
||||
"allowSellingGeneratedContent": "允许销售",
|
||||
"noTags": "无标签",
|
||||
"noBaseModelMatches": "没有基础模型符合当前搜索。",
|
||||
"clearAll": "清除所有筛选",
|
||||
"any": "任一",
|
||||
"all": "全部",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "預設 \"{name}\" 已存在。是否覆蓋?",
|
||||
"presetNamePlaceholder": "預設名稱...",
|
||||
"baseModel": "基礎模型",
|
||||
"baseModelSearchPlaceholder": "搜尋基礎模型...",
|
||||
"modelTags": "標籤(前 20)",
|
||||
"modelTypes": "模型類型",
|
||||
"license": "授權",
|
||||
"noCreditRequired": "無需署名",
|
||||
"allowSellingGeneratedContent": "允許銷售",
|
||||
"noTags": "無標籤",
|
||||
"noBaseModelMatches": "沒有基礎模型符合目前的搜尋。",
|
||||
"clearAll": "清除所有篩選",
|
||||
"any": "任一",
|
||||
"all": "全部",
|
||||
|
||||
@@ -910,7 +910,7 @@ class ModelQueryHandler:
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
limit = int(request.query.get("limit", "20"))
|
||||
if limit < 1 or limit > 100:
|
||||
if limit < 0 or limit > 100:
|
||||
limit = 20
|
||||
base_models = await self._service.get_base_models(limit)
|
||||
return web.json_response({"success": True, "base_models": base_models})
|
||||
|
||||
@@ -329,6 +329,7 @@ class RecipeQueryHandler:
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
limit = int(request.query.get("limit", "20"))
|
||||
cache = await recipe_scanner.get_cached_data()
|
||||
|
||||
base_model_counts: Dict[str, int] = {}
|
||||
@@ -344,6 +345,8 @@ class RecipeQueryHandler:
|
||||
for model, count in base_model_counts.items()
|
||||
]
|
||||
sorted_models.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
if limit > 0:
|
||||
sorted_models = sorted_models[:limit]
|
||||
return web.json_response({"success": True, "base_models": sorted_models})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving base models: %s", exc, exc_info=True)
|
||||
|
||||
@@ -138,7 +138,7 @@ class Downloader:
|
||||
self.chunk_size = (
|
||||
16 * 1024 * 1024
|
||||
) # 16MB chunks to balance I/O reduction and memory usage
|
||||
self.max_retries = 5
|
||||
self.max_retries = self._resolve_max_retries()
|
||||
self.base_delay = 2.0 # Base delay for exponential backoff
|
||||
self.session_timeout = 300 # 5 minutes
|
||||
self.stall_timeout = self._resolve_stall_timeout()
|
||||
@@ -192,6 +192,18 @@ class Downloader:
|
||||
|
||||
return max(30.0, timeout_value)
|
||||
|
||||
def _resolve_max_retries(self) -> int:
|
||||
"""Determine max retry count from environment while preserving defaults."""
|
||||
default_retries = 5
|
||||
raw_value = os.environ.get("COMFYUI_DOWNLOAD_MAX_RETRIES")
|
||||
|
||||
try:
|
||||
retries = int(raw_value)
|
||||
except (TypeError, ValueError):
|
||||
retries = default_retries
|
||||
|
||||
return max(0, retries)
|
||||
|
||||
def _should_refresh_session(self) -> bool:
|
||||
"""Check if session should be refreshed"""
|
||||
if self._session is None:
|
||||
@@ -334,6 +346,7 @@ class Downloader:
|
||||
logger.info(f"Resuming download from offset {resume_offset} bytes")
|
||||
|
||||
total_size = 0
|
||||
range_redirect_retry_urls: set[str] = set()
|
||||
|
||||
while retry_count <= self.max_retries:
|
||||
try:
|
||||
@@ -372,6 +385,23 @@ class Downloader:
|
||||
if response.status == 200:
|
||||
# Full content response
|
||||
if resume_offset > 0:
|
||||
redirected_url = str(response.url)
|
||||
if (
|
||||
allow_resume
|
||||
and response.history
|
||||
and redirected_url
|
||||
and redirected_url != url
|
||||
and redirected_url not in range_redirect_retry_urls
|
||||
):
|
||||
range_redirect_retry_urls.add(redirected_url)
|
||||
logger.info(
|
||||
"Range request was not honored after redirect; retrying final URL directly: %s",
|
||||
redirected_url,
|
||||
)
|
||||
url = redirected_url
|
||||
response.release()
|
||||
continue
|
||||
|
||||
# Server doesn't support ranges, restart from beginning
|
||||
logger.warning(
|
||||
"Server doesn't support range requests, restarting download"
|
||||
@@ -571,37 +601,53 @@ class Downloader:
|
||||
expected_size = total_size if total_size > 0 else None
|
||||
|
||||
integrity_error: Optional[str] = None
|
||||
resumable_incomplete = False
|
||||
if final_size <= 0:
|
||||
integrity_error = "Downloaded file is empty"
|
||||
elif expected_size is not None and final_size != expected_size:
|
||||
integrity_error = f"File size mismatch. Expected: {expected_size}, Got: {final_size}"
|
||||
resumable_incomplete = (
|
||||
allow_resume
|
||||
and part_path != save_path
|
||||
and final_size > 0
|
||||
and final_size < expected_size
|
||||
)
|
||||
|
||||
if integrity_error is not None:
|
||||
logger.error(
|
||||
log_fn = logger.warning if resumable_incomplete else logger.error
|
||||
log_fn(
|
||||
"Download integrity check failed for %s: %s",
|
||||
save_path,
|
||||
integrity_error,
|
||||
)
|
||||
|
||||
# Remove the corrupted payload so future attempts start fresh
|
||||
if os.path.exists(part_path):
|
||||
try:
|
||||
os.remove(part_path)
|
||||
except OSError as remove_error:
|
||||
logger.warning(
|
||||
"Failed to delete corrupted download %s: %s",
|
||||
part_path,
|
||||
remove_error,
|
||||
)
|
||||
if part_path != save_path and os.path.exists(save_path):
|
||||
try:
|
||||
os.remove(save_path)
|
||||
except OSError as remove_error:
|
||||
logger.warning(
|
||||
"Failed to delete target file %s after integrity error: %s",
|
||||
save_path,
|
||||
remove_error,
|
||||
)
|
||||
if resumable_incomplete:
|
||||
logger.info(
|
||||
"Preserving incomplete download for resume: %s (%s/%s bytes)",
|
||||
part_path,
|
||||
final_size,
|
||||
expected_size,
|
||||
)
|
||||
else:
|
||||
# Remove corrupted payloads that cannot be safely resumed.
|
||||
if os.path.exists(part_path):
|
||||
try:
|
||||
os.remove(part_path)
|
||||
except OSError as remove_error:
|
||||
logger.warning(
|
||||
"Failed to delete corrupted download %s: %s",
|
||||
part_path,
|
||||
remove_error,
|
||||
)
|
||||
if part_path != save_path and os.path.exists(save_path):
|
||||
try:
|
||||
os.remove(save_path)
|
||||
except OSError as remove_error:
|
||||
logger.warning(
|
||||
"Failed to delete target file %s after integrity error: %s",
|
||||
save_path,
|
||||
remove_error,
|
||||
)
|
||||
|
||||
retry_count += 1
|
||||
if retry_count <= self.max_retries:
|
||||
@@ -611,8 +657,16 @@ class Downloader:
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
resume_offset = 0
|
||||
total_size = 0
|
||||
if resumable_incomplete and os.path.exists(part_path):
|
||||
resume_offset = os.path.getsize(part_path)
|
||||
total_size = expected_size or 0
|
||||
logger.info(
|
||||
"Will resume incomplete download from byte %s",
|
||||
resume_offset,
|
||||
)
|
||||
else:
|
||||
resume_offset = 0
|
||||
total_size = 0
|
||||
await self._create_session()
|
||||
continue
|
||||
|
||||
|
||||
@@ -1535,7 +1535,7 @@ class ModelScanner:
|
||||
return sorted_tags[:limit]
|
||||
|
||||
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||
"""Get base models sorted by frequency"""
|
||||
"""Get base models sorted by count. If limit is 0, return all."""
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
base_model_counts = {}
|
||||
@@ -1546,7 +1546,9 @@ class ModelScanner:
|
||||
|
||||
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
|
||||
sorted_models.sort(key=lambda x: x['count'], reverse=True)
|
||||
|
||||
|
||||
if limit == 0:
|
||||
return sorted_models
|
||||
return sorted_models[:limit]
|
||||
|
||||
async def get_model_info_by_name(self, name):
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
height: 100%;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
/* Responsive header container for larger screens */
|
||||
@@ -65,7 +66,6 @@
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
flex-shrink: 0;
|
||||
margin-right: 1rem;
|
||||
}
|
||||
|
||||
.nav-item {
|
||||
@@ -101,7 +101,6 @@
|
||||
.header-search {
|
||||
flex: 1;
|
||||
max-width: 400px;
|
||||
margin: 0 1rem;
|
||||
transition: opacity 0.2s ease;
|
||||
}
|
||||
|
||||
@@ -288,4 +287,4 @@
|
||||
.header-search {
|
||||
flex: 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,7 +145,7 @@
|
||||
position: fixed;
|
||||
right: 20px;
|
||||
top: 50px; /* Position below header */
|
||||
width: 320px;
|
||||
width: 366px;
|
||||
background-color: var(--card-bg);
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-base);
|
||||
@@ -197,6 +197,31 @@
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.filter-search-input {
|
||||
width: 100%;
|
||||
box-sizing: border-box;
|
||||
margin-bottom: 8px;
|
||||
padding: 8px 10px;
|
||||
border-radius: var(--border-radius-sm);
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
color: var(--text-color);
|
||||
font-size: 13px;
|
||||
}
|
||||
|
||||
.filter-search-input:focus {
|
||||
outline: none;
|
||||
border-color: var(--lora-accent);
|
||||
box-shadow: 0 0 0 2px rgba(var(--lora-accent-rgb, 76, 175, 80), 0.15);
|
||||
}
|
||||
|
||||
.filter-empty-state {
|
||||
margin-top: 8px;
|
||||
font-size: 13px;
|
||||
color: var(--text-color);
|
||||
opacity: 0.7;
|
||||
}
|
||||
|
||||
.filter-section h4 {
|
||||
margin: 0 0 8px 0;
|
||||
font-size: 14px;
|
||||
@@ -733,4 +758,4 @@
|
||||
right: 20px;
|
||||
top: 160px; /* Adjusted for mobile layout */
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,9 +240,7 @@ export class BulkManager {
|
||||
*/
|
||||
handleGlobalKeyboard(e) {
|
||||
// Skip if modal is open (handled by event manager conditions)
|
||||
// Skip if search input is focused
|
||||
const searchInput = document.getElementById('searchInput');
|
||||
if (searchInput && document.activeElement === searchInput) {
|
||||
if (this.isEditingTextInputContext(e.target)) {
|
||||
return false; // Don't handle, allow default behavior
|
||||
}
|
||||
|
||||
@@ -266,6 +264,26 @@ export class BulkManager {
|
||||
return false; // Continue with other handlers
|
||||
}
|
||||
|
||||
isEditingTextInputContext(target) {
|
||||
const activeElement = document.activeElement;
|
||||
const candidate = target instanceof Element ? target : activeElement;
|
||||
if (!candidate) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const tagName = candidate.tagName?.toLowerCase();
|
||||
if (
|
||||
candidate.isContentEditable
|
||||
|| tagName === 'input'
|
||||
|| tagName === 'textarea'
|
||||
|| tagName === 'select'
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return Boolean(candidate.closest?.('#filterPanel'));
|
||||
}
|
||||
|
||||
toggleBulkMode() {
|
||||
state.bulkMode = !state.bulkMode;
|
||||
|
||||
|
||||
@@ -20,6 +20,8 @@ export class FilterManager {
|
||||
this.filterPanel = document.getElementById('filterPanel');
|
||||
this.filterButton = document.getElementById('filterButton');
|
||||
this.activeFiltersCount = document.getElementById('activeFiltersCount');
|
||||
this.baseModelSearchInput = document.getElementById('baseModelSearchInput');
|
||||
this.baseModelOptions = [];
|
||||
this.tagsLoaded = false;
|
||||
|
||||
// Initialize preset manager
|
||||
@@ -49,6 +51,8 @@ export class FilterManager {
|
||||
}
|
||||
|
||||
initialize() {
|
||||
this.initializeFilterSearchInputs();
|
||||
|
||||
// Create base model filter tags if they exist
|
||||
if (document.getElementById('baseModelTags')) {
|
||||
this.createBaseModelTags();
|
||||
@@ -110,6 +114,18 @@ export class FilterManager {
|
||||
this.updateTagLogicToggleUI();
|
||||
}
|
||||
|
||||
initializeFilterSearchInputs() {
|
||||
if (this.baseModelSearchInput) {
|
||||
this.baseModelSearchInput.addEventListener('input', () => {
|
||||
this.renderBaseModelTags();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
getNormalizedSearchQuery(input) {
|
||||
return (input?.value || '').trim().toLowerCase();
|
||||
}
|
||||
|
||||
updateTagLogicToggleUI() {
|
||||
const toggleContainer = document.getElementById('tagLogicToggle');
|
||||
if (!toggleContainer) return;
|
||||
@@ -164,11 +180,6 @@ export class FilterManager {
|
||||
|
||||
tagsContainer.innerHTML = '';
|
||||
|
||||
if (!tags.length) {
|
||||
tagsContainer.innerHTML = `<div class="no-tags">No ${this.currentPage === 'recipes' ? 'recipe ' : ''}tags available</div>`;
|
||||
return;
|
||||
}
|
||||
|
||||
// Collect existing tag names from the API response
|
||||
const existingTagNames = new Set(tags.map(t => t.tag));
|
||||
|
||||
@@ -186,6 +197,11 @@ export class FilterManager {
|
||||
});
|
||||
}
|
||||
|
||||
if (!tags.length) {
|
||||
tagsContainer.innerHTML = `<div class="no-tags">No ${this.currentPage === 'recipes' ? 'recipe ' : ''}tags available</div>`;
|
||||
return;
|
||||
}
|
||||
|
||||
tags.forEach(tag => {
|
||||
const tagEl = document.createElement('div');
|
||||
tagEl.className = 'filter-tag tag-filter';
|
||||
@@ -212,7 +228,6 @@ export class FilterManager {
|
||||
await this.applyFilters(false);
|
||||
});
|
||||
|
||||
this.applyTagElementState(tagEl, (this.filters.tags && this.filters.tags[tagName]) || 'none');
|
||||
tagsContainer.appendChild(tagEl);
|
||||
});
|
||||
|
||||
@@ -235,8 +250,8 @@ export class FilterManager {
|
||||
await this.applyFilters(false);
|
||||
});
|
||||
|
||||
this.applyTagElementState(noTagsEl, (this.filters.tags && this.filters.tags[noTagsKey]) || 'none');
|
||||
tagsContainer.appendChild(noTagsEl);
|
||||
this.updateTagSelections();
|
||||
}
|
||||
|
||||
initializeLicenseFilters() {
|
||||
@@ -323,44 +338,15 @@ export class FilterManager {
|
||||
if (!baseModelTagsContainer) return;
|
||||
|
||||
// Set the API endpoint based on current page
|
||||
const apiEndpoint = `/api/lm/${this.currentPage}/base-models`;
|
||||
const apiEndpoint = `/api/lm/${this.currentPage}/base-models?limit=0`;
|
||||
|
||||
// Fetch base models
|
||||
fetch(apiEndpoint)
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.success && data.base_models) {
|
||||
baseModelTagsContainer.innerHTML = '';
|
||||
|
||||
data.base_models.forEach(model => {
|
||||
const tag = document.createElement('div');
|
||||
tag.className = `filter-tag base-model-tag`;
|
||||
tag.dataset.baseModel = model.name;
|
||||
tag.innerHTML = `${model.name} <span class="tag-count">${model.count}</span>`;
|
||||
|
||||
// Add click handler to toggle selection and automatically apply
|
||||
tag.addEventListener('click', async () => {
|
||||
tag.classList.toggle('active');
|
||||
|
||||
if (tag.classList.contains('active')) {
|
||||
if (!this.filters.baseModel.includes(model.name)) {
|
||||
this.filters.baseModel.push(model.name);
|
||||
}
|
||||
} else {
|
||||
this.filters.baseModel = this.filters.baseModel.filter(m => m !== model.name);
|
||||
}
|
||||
|
||||
this.updateActiveFiltersCount();
|
||||
|
||||
// Auto-apply filter when tag is clicked
|
||||
await this.applyFilters(false);
|
||||
});
|
||||
|
||||
baseModelTagsContainer.appendChild(tag);
|
||||
});
|
||||
|
||||
// Update selections based on stored filters
|
||||
this.updateTagSelections();
|
||||
this.baseModelOptions = data.base_models;
|
||||
this.renderBaseModelTags();
|
||||
}
|
||||
})
|
||||
.catch(error => {
|
||||
@@ -369,6 +355,57 @@ export class FilterManager {
|
||||
});
|
||||
}
|
||||
|
||||
renderBaseModelTags() {
|
||||
const baseModelTagsContainer = document.getElementById('baseModelTags');
|
||||
const emptyState = document.getElementById('baseModelEmptyState');
|
||||
if (!baseModelTagsContainer) return;
|
||||
|
||||
baseModelTagsContainer.innerHTML = '';
|
||||
|
||||
if (!this.baseModelOptions.length) {
|
||||
baseModelTagsContainer.innerHTML = '<div class="no-tags">No base models available</div>';
|
||||
if (emptyState) {
|
||||
emptyState.hidden = true;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const query = this.getNormalizedSearchQuery(this.baseModelSearchInput);
|
||||
const filteredModels = query
|
||||
? this.baseModelOptions.filter(model => model.name.toLowerCase().includes(query))
|
||||
: this.baseModelOptions;
|
||||
|
||||
filteredModels.forEach(model => {
|
||||
const tag = document.createElement('div');
|
||||
tag.className = 'filter-tag base-model-tag';
|
||||
tag.dataset.baseModel = model.name;
|
||||
tag.innerHTML = `${model.name} <span class="tag-count">${model.count}</span>`;
|
||||
|
||||
tag.addEventListener('click', async () => {
|
||||
tag.classList.toggle('active');
|
||||
|
||||
if (tag.classList.contains('active')) {
|
||||
if (!this.filters.baseModel.includes(model.name)) {
|
||||
this.filters.baseModel.push(model.name);
|
||||
}
|
||||
} else {
|
||||
this.filters.baseModel = this.filters.baseModel.filter(m => m !== model.name);
|
||||
}
|
||||
|
||||
this.updateActiveFiltersCount();
|
||||
await this.applyFilters(false);
|
||||
});
|
||||
|
||||
baseModelTagsContainer.appendChild(tag);
|
||||
});
|
||||
|
||||
if (emptyState) {
|
||||
emptyState.hidden = filteredModels.length > 0;
|
||||
}
|
||||
|
||||
this.updateTagSelections();
|
||||
}
|
||||
|
||||
async createModelTypeTags() {
|
||||
const modelTypeContainer = document.getElementById('modelTypeTags');
|
||||
if (!modelTypeContainer) return;
|
||||
@@ -453,6 +490,7 @@ export class FilterManager {
|
||||
|
||||
this.filterPanel.classList.remove('hidden');
|
||||
this.filterButton.classList.add('active');
|
||||
this.baseModelSearchInput?.focus();
|
||||
|
||||
// Load tags if they haven't been loaded yet
|
||||
if (!this.tagsLoaded) {
|
||||
|
||||
@@ -232,7 +232,7 @@ export class FilterPresetManager {
|
||||
|
||||
try {
|
||||
const fetchOptions = signal ? { signal } : {};
|
||||
const response = await fetch(`/api/lm/${this.currentPage}/base-models`, fetchOptions);
|
||||
const response = await fetch(`/api/lm/${this.currentPage}/base-models?limit=0`, fetchOptions);
|
||||
|
||||
if (!response.ok) throw new Error('Failed to fetch base models');
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ export function initializeEventManagement() {
|
||||
setupPageUnloadCleanup();
|
||||
|
||||
// Register global event handlers that need coordination
|
||||
registerGlobalEventHandlers();
|
||||
registerContextMenuEvents();
|
||||
registerGlobalClickHandlers();
|
||||
|
||||
@@ -148,6 +149,10 @@ function registerGlobalClickHandlers() {
|
||||
* Register common application-wide event handlers
|
||||
*/
|
||||
export function registerGlobalEventHandlers() {
|
||||
eventManager.removeHandler('keydown', 'global-escape');
|
||||
eventManager.removeHandler('focusin', 'global-focus');
|
||||
eventManager.removeHandler('click', 'global-analytics');
|
||||
|
||||
// Escape key handler for closing modals/panels
|
||||
eventManager.addHandler('keydown', 'global-escape', (e) => {
|
||||
if (e.key === 'Escape') {
|
||||
@@ -156,6 +161,14 @@ export function registerGlobalEventHandlers() {
|
||||
modalManager.closeCurrentModal();
|
||||
return true; // Stop propagation
|
||||
}
|
||||
|
||||
if (
|
||||
window.filterManager?.filterPanel
|
||||
&& !window.filterManager.filterPanel.classList.contains('hidden')
|
||||
) {
|
||||
window.filterManager.closeFilterPanel();
|
||||
return true; // Stop propagation
|
||||
}
|
||||
|
||||
// Check if node selector is active and close it
|
||||
if (eventManager.getState('nodeSelectorActive')) {
|
||||
|
||||
@@ -145,9 +145,22 @@
|
||||
|
||||
<div class="filter-section">
|
||||
<h4>{{ t('header.filter.baseModel') }}</h4>
|
||||
<input
|
||||
type="text"
|
||||
id="baseModelSearchInput"
|
||||
class="filter-search-input"
|
||||
placeholder="{{ t('header.filter.baseModelSearchPlaceholder') }}"
|
||||
autocomplete="off"
|
||||
autocorrect="off"
|
||||
autocapitalize="none"
|
||||
spellcheck="false"
|
||||
>
|
||||
<div class="filter-tags" id="baseModelTags">
|
||||
<!-- Tags will be dynamically inserted here -->
|
||||
</div>
|
||||
<div id="baseModelEmptyState" class="filter-empty-state" hidden>
|
||||
{{ t('header.filter.noBaseModelMatches') }}
|
||||
</div>
|
||||
</div>
|
||||
<div class="filter-section">
|
||||
<div class="filter-section-header">
|
||||
@@ -188,4 +201,4 @@
|
||||
{{ t('header.filter.clearAll') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -110,7 +110,10 @@ function renderControlsDom(pageKey) {
|
||||
<div class="search-option-tag active" data-option="filename"></div>
|
||||
</div>
|
||||
<div id="filterPanel" class="filter-panel hidden">
|
||||
<input id="baseModelSearchInput" />
|
||||
<div id="baseModelTags" class="filter-tags"></div>
|
||||
<div id="baseModelEmptyState" hidden></div>
|
||||
<div id="filterPresets" class="filter-presets"></div>
|
||||
<div id="modelTagsFilter" class="filter-tags"></div>
|
||||
<button class="clear-filter"></button>
|
||||
</div>
|
||||
@@ -286,6 +289,8 @@ describe('FilterManager tag and base model filters', () => {
|
||||
|
||||
const manager = new FilterManager({ page: pageKey });
|
||||
|
||||
expect(global.fetch).toHaveBeenCalledWith(`/api/lm/${pageKey}/base-models?limit=0`);
|
||||
|
||||
await vi.waitFor(() => {
|
||||
const chip = document.querySelector('[data-base-model="SDXL"]');
|
||||
expect(chip).not.toBeNull();
|
||||
@@ -311,6 +316,259 @@ describe('FilterManager tag and base model filters', () => {
|
||||
expect(getCurrentPageState().filters.baseModel).toEqual([]);
|
||||
expect(baseModelChip.classList.contains('active')).toBe(false);
|
||||
});
|
||||
|
||||
it('filters base model chips locally without changing selected state', async () => {
|
||||
global.fetch = vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
success: true,
|
||||
base_models: [
|
||||
{ name: 'SDXL', count: 2 },
|
||||
{ name: 'LTXV 2.3', count: 1 },
|
||||
],
|
||||
}),
|
||||
});
|
||||
|
||||
renderControlsDom('loras');
|
||||
const stateModule = await import('../../../static/js/state/index.js');
|
||||
stateModule.initPageState('loras');
|
||||
const { getCurrentPageState } = stateModule;
|
||||
const { FilterManager } = await import('../../../static/js/managers/FilterManager.js');
|
||||
|
||||
new FilterManager({ page: 'loras' });
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(document.querySelector('[data-base-model="LTXV 2.3"]')).not.toBeNull();
|
||||
});
|
||||
|
||||
const searchInput = document.getElementById('baseModelSearchInput');
|
||||
const ltxvChip = document.querySelector('[data-base-model="LTXV 2.3"]');
|
||||
ltxvChip.dispatchEvent(new Event('click', { bubbles: true }));
|
||||
await vi.waitFor(() => expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledTimes(1));
|
||||
expect(getCurrentPageState().filters.baseModel).toEqual(['LTXV 2.3']);
|
||||
|
||||
loadMoreWithVirtualScrollMock.mockClear();
|
||||
searchInput.value = 'sdx';
|
||||
searchInput.dispatchEvent(new Event('input', { bubbles: true }));
|
||||
|
||||
expect(document.querySelector('[data-base-model="SDXL"]')).not.toBeNull();
|
||||
expect(document.querySelector('[data-base-model="LTXV 2.3"]')).toBeNull();
|
||||
expect(document.getElementById('baseModelEmptyState').hidden).toBe(true);
|
||||
expect(getCurrentPageState().filters.baseModel).toEqual(['LTXV 2.3']);
|
||||
|
||||
searchInput.value = 'zzz';
|
||||
searchInput.dispatchEvent(new Event('input', { bubbles: true }));
|
||||
expect(document.getElementById('baseModelEmptyState').hidden).toBe(false);
|
||||
|
||||
searchInput.value = 'ltx';
|
||||
searchInput.dispatchEvent(new Event('input', { bubbles: true }));
|
||||
const restoredChip = document.querySelector('[data-base-model="LTXV 2.3"]');
|
||||
expect(restoredChip).not.toBeNull();
|
||||
expect(restoredChip.classList.contains('active')).toBe(true);
|
||||
});
|
||||
|
||||
it('disables browser autocomplete helpers for the base model search input', async () => {
|
||||
renderControlsDom('loras');
|
||||
|
||||
const searchInput = document.getElementById('baseModelSearchInput');
|
||||
|
||||
searchInput.setAttribute('autocomplete', 'off');
|
||||
searchInput.setAttribute('autocorrect', 'off');
|
||||
searchInput.setAttribute('autocapitalize', 'none');
|
||||
searchInput.setAttribute('spellcheck', 'false');
|
||||
|
||||
expect(searchInput.getAttribute('autocomplete')).toBe('off');
|
||||
expect(searchInput.getAttribute('autocorrect')).toBe('off');
|
||||
expect(searchInput.getAttribute('autocapitalize')).toBe('none');
|
||||
expect(searchInput.getAttribute('spellcheck')).toBe('false');
|
||||
});
|
||||
|
||||
it('focuses the base model search input when opening the filter panel', async () => {
|
||||
global.fetch = vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
success: true,
|
||||
base_models: [{ name: 'SDXL', count: 2 }],
|
||||
}),
|
||||
});
|
||||
|
||||
renderControlsDom('loras');
|
||||
const stateModule = await import('../../../static/js/state/index.js');
|
||||
stateModule.initPageState('loras');
|
||||
const { FilterManager } = await import('../../../static/js/managers/FilterManager.js');
|
||||
|
||||
const manager = new FilterManager({ page: 'loras' });
|
||||
const searchInput = document.getElementById('baseModelSearchInput');
|
||||
|
||||
expect(document.activeElement).not.toBe(searchInput);
|
||||
|
||||
manager.toggleFilterPanel();
|
||||
|
||||
expect(document.activeElement).toBe(searchInput);
|
||||
});
|
||||
|
||||
it('does not let base model search trigger bulk shortcuts', async () => {
|
||||
global.fetch = vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
success: true,
|
||||
base_models: [{ name: 'SDXL', count: 2 }],
|
||||
}),
|
||||
});
|
||||
|
||||
renderControlsDom('loras');
|
||||
const stateModule = await import('../../../static/js/state/index.js');
|
||||
stateModule.initPageState('loras');
|
||||
const { BulkManager } = await import('../../../static/js/managers/BulkManager.js');
|
||||
const { FilterManager } = await import('../../../static/js/managers/FilterManager.js');
|
||||
|
||||
const filterManager = new FilterManager({ page: 'loras' });
|
||||
const bulkManager = new BulkManager();
|
||||
const searchInput = document.getElementById('baseModelSearchInput');
|
||||
window.filterManager = filterManager;
|
||||
|
||||
searchInput.focus();
|
||||
|
||||
const bulkEvent = new KeyboardEvent('keydown', {
|
||||
key: 'b',
|
||||
bubbles: true,
|
||||
cancelable: true,
|
||||
});
|
||||
Object.defineProperty(bulkEvent, 'target', { value: searchInput });
|
||||
expect(bulkManager.handleGlobalKeyboard(bulkEvent)).toBe(false);
|
||||
|
||||
const selectAllEvent = new KeyboardEvent('keydown', {
|
||||
key: 'a',
|
||||
ctrlKey: true,
|
||||
bubbles: true,
|
||||
cancelable: true,
|
||||
});
|
||||
Object.defineProperty(selectAllEvent, 'target', { value: searchInput });
|
||||
expect(bulkManager.handleGlobalKeyboard(selectAllEvent)).toBe(false);
|
||||
});
|
||||
|
||||
it('closes the filter panel on Escape', async () => {
|
||||
global.fetch = vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
success: true,
|
||||
base_models: [{ name: 'SDXL', count: 2 }],
|
||||
}),
|
||||
});
|
||||
|
||||
renderControlsDom('loras');
|
||||
const stateModule = await import('../../../static/js/state/index.js');
|
||||
stateModule.initPageState('loras');
|
||||
const { FilterManager } = await import('../../../static/js/managers/FilterManager.js');
|
||||
const { eventManager } = await import('../../../static/js/utils/EventManager.js');
|
||||
const { initializeEventManagement } = await import('../../../static/js/utils/eventManagementInit.js');
|
||||
|
||||
eventManager.cleanup();
|
||||
initializeEventManagement();
|
||||
|
||||
const manager = new FilterManager({ page: 'loras' });
|
||||
window.filterManager = manager;
|
||||
manager.toggleFilterPanel();
|
||||
expect(manager.filterPanel.classList.contains('hidden')).toBe(false);
|
||||
|
||||
document.dispatchEvent(new KeyboardEvent('keydown', { key: 'Escape', bubbles: true }));
|
||||
|
||||
expect(manager.filterPanel.classList.contains('hidden')).toBe(true);
|
||||
eventManager.cleanup();
|
||||
});
|
||||
|
||||
it('applies all base models from a preset using the full base model list', async () => {
|
||||
global.fetch = vi.fn((url) => {
|
||||
if (url.includes('/base-models?limit=0')) {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
success: true,
|
||||
base_models: [
|
||||
{ name: 'SDXL 1.0', count: 5 },
|
||||
{ name: 'SDXL Lightning', count: 3 },
|
||||
{ name: 'SDXL Hyper', count: 2 },
|
||||
],
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
if (url.includes('/base-models')) {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
success: true,
|
||||
base_models: [{ name: 'SDXL 1.0', count: 5 }],
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
if (url.includes('/top-tags')) {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
success: true,
|
||||
tags: [],
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
if (url.includes('/model-types')) {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
success: true,
|
||||
model_types: [],
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({ success: true }),
|
||||
});
|
||||
});
|
||||
|
||||
renderControlsDom('loras');
|
||||
const stateModule = await import('../../../static/js/state/index.js');
|
||||
stateModule.initPageState('loras');
|
||||
stateModule.state.global.settings.filter_presets = {
|
||||
loras: [
|
||||
{
|
||||
name: 'SDXL Family',
|
||||
filters: {
|
||||
baseModel: ['SDXL 1.0', 'SDXL Lightning', 'SDXL Hyper'],
|
||||
tags: {},
|
||||
license: {},
|
||||
modelTypes: [],
|
||||
tagLogic: 'any',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const { getCurrentPageState } = stateModule;
|
||||
const { FilterManager } = await import('../../../static/js/managers/FilterManager.js');
|
||||
|
||||
const manager = new FilterManager({ page: 'loras' });
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(document.querySelector('[data-base-model="SDXL Hyper"]')).not.toBeNull();
|
||||
});
|
||||
|
||||
await manager.presetManager.applyPreset('SDXL Family');
|
||||
|
||||
expect(manager.activePreset).toBe('SDXL Family');
|
||||
expect(manager.filters.baseModel).toEqual(['SDXL 1.0', 'SDXL Lightning', 'SDXL Hyper']);
|
||||
expect(getCurrentPageState().filters.baseModel).toEqual(['SDXL 1.0', 'SDXL Lightning', 'SDXL Hyper']);
|
||||
expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledWith(true, false);
|
||||
expect(showToastMock).toHaveBeenCalledWith(
|
||||
'Preset "SDXL Family" applied',
|
||||
{},
|
||||
'success',
|
||||
);
|
||||
});
|
||||
|
||||
});
|
||||
|
||||
describe('PageControls favorites, sorting, and duplicates scenarios', () => {
|
||||
|
||||
38
tests/routes/test_model_query_handler.py
Normal file
38
tests/routes/test_model_query_handler.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import json
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from py.routes.handlers.model_handlers import ModelQueryHandler
|
||||
|
||||
|
||||
class DummyService:
|
||||
def __init__(self):
|
||||
self.received_limit = None
|
||||
|
||||
async def get_base_models(self, limit):
|
||||
self.received_limit = limit
|
||||
return [{"name": "SDXL", "count": 2}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_query_handler_accepts_limit_zero_for_base_models():
|
||||
service = DummyService()
|
||||
handler = ModelQueryHandler(service=service, logger=logging.getLogger(__name__))
|
||||
|
||||
response = await handler.get_base_models(SimpleNamespace(query={"limit": "0"}))
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is True
|
||||
assert service.received_limit == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_query_handler_rejects_negative_limit_for_base_models():
|
||||
service = DummyService()
|
||||
handler = ModelQueryHandler(service=service, logger=logging.getLogger(__name__))
|
||||
|
||||
await handler.get_base_models(SimpleNamespace(query={"limit": "-1"}))
|
||||
|
||||
assert service.received_limit == 20
|
||||
44
tests/routes/test_recipe_query_handler.py
Normal file
44
tests/routes/test_recipe_query_handler.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import json
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from py.routes.handlers.recipe_handlers import RecipeQueryHandler
|
||||
|
||||
|
||||
async def _noop():
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recipe_query_handler_base_models_limit_zero_returns_all():
|
||||
cache = SimpleNamespace(
|
||||
raw_data=[
|
||||
{"base_model": "SDXL"},
|
||||
{"base_model": "LTXV 2.3"},
|
||||
{"base_model": "SDXL"},
|
||||
]
|
||||
)
|
||||
scanner = SimpleNamespace(get_cached_data=lambda: None)
|
||||
|
||||
async def get_cached_data():
|
||||
return cache
|
||||
|
||||
scanner.get_cached_data = get_cached_data
|
||||
|
||||
handler = RecipeQueryHandler(
|
||||
ensure_dependencies_ready=_noop,
|
||||
recipe_scanner_getter=lambda: scanner,
|
||||
format_recipe_file_url=lambda value: value,
|
||||
logger=logging.getLogger(__name__),
|
||||
)
|
||||
|
||||
response = await handler.get_base_models(SimpleNamespace(query={"limit": "0"}))
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is True
|
||||
assert payload["base_models"] == [
|
||||
{"name": "SDXL", "count": 2},
|
||||
{"name": "LTXV 2.3", "count": 1},
|
||||
]
|
||||
@@ -30,10 +30,21 @@ class FakeStream:
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, status, headers, chunks):
|
||||
def __init__(
|
||||
self,
|
||||
status,
|
||||
headers,
|
||||
chunks,
|
||||
*,
|
||||
url="https://example.com/file",
|
||||
history=None,
|
||||
):
|
||||
self.status = status
|
||||
self.headers = headers
|
||||
self.content = FakeStream(chunks)
|
||||
self.url = url
|
||||
self.history = history or []
|
||||
self.released = False
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
@@ -41,14 +52,25 @@ class FakeResponse:
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def release(self):
|
||||
self.released = True
|
||||
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, responses):
|
||||
self._responses = list(responses)
|
||||
self._get_calls = 0
|
||||
self.requests = []
|
||||
|
||||
def get(self, url, headers=None, allow_redirects=True, proxy=None): # noqa: D401 - signature mirrors aiohttp
|
||||
del url, headers, allow_redirects, proxy
|
||||
self.requests.append(
|
||||
{
|
||||
"url": url,
|
||||
"headers": headers or {},
|
||||
"allow_redirects": allow_redirects,
|
||||
"proxy": proxy,
|
||||
}
|
||||
)
|
||||
response_factory = self._responses[self._get_calls]
|
||||
self._get_calls += 1
|
||||
return response_factory()
|
||||
@@ -75,7 +97,7 @@ def _build_downloader(responses, *, max_retries=0):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_fails_when_size_mismatch(tmp_path):
|
||||
async def test_download_file_preserves_incomplete_part_when_size_mismatch(tmp_path):
|
||||
target_path = tmp_path / "model" / "file.bin"
|
||||
target_path.parent.mkdir()
|
||||
|
||||
@@ -94,7 +116,7 @@ async def test_download_file_fails_when_size_mismatch(tmp_path):
|
||||
assert success is False
|
||||
assert "mismatch" in message.lower()
|
||||
assert not target_path.exists()
|
||||
assert not Path(str(target_path) + ".part").exists()
|
||||
assert Path(str(target_path) + ".part").read_bytes() == b"abc"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -136,7 +158,9 @@ async def test_download_file_succeeds_when_sizes_match(tmp_path):
|
||||
|
||||
downloader = _build_downloader(responses)
|
||||
|
||||
success, result_path = await downloader.download_file("https://example.com/file", str(target_path))
|
||||
success, result_path = await downloader.download_file(
|
||||
"https://example.com/file", str(target_path)
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert Path(result_path).read_bytes() == payload
|
||||
@@ -166,9 +190,77 @@ async def test_download_file_recovers_from_stall(tmp_path):
|
||||
downloader = _build_downloader(responses, max_retries=1)
|
||||
downloader.stall_timeout = 0.05
|
||||
|
||||
success, result_path = await downloader.download_file("https://example.com/file", str(target_path))
|
||||
success, result_path = await downloader.download_file(
|
||||
"https://example.com/file", str(target_path)
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert Path(result_path).read_bytes() == payload
|
||||
assert downloader._session._get_calls == 2
|
||||
assert not Path(str(target_path) + ".part").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_resumes_after_incomplete_integrity_check(tmp_path):
|
||||
target_path = tmp_path / "model" / "file.bin"
|
||||
target_path.parent.mkdir()
|
||||
|
||||
responses = [
|
||||
lambda: FakeResponse(
|
||||
status=200,
|
||||
headers={"content-length": "6"},
|
||||
chunks=[b"abc"],
|
||||
),
|
||||
lambda: FakeResponse(
|
||||
status=206,
|
||||
headers={"content-length": "3", "Content-Range": "bytes 3-5/6"},
|
||||
chunks=[b"def"],
|
||||
),
|
||||
]
|
||||
|
||||
downloader = _build_downloader(responses, max_retries=1)
|
||||
|
||||
success, result_path = await downloader.download_file("https://example.com/file", str(target_path))
|
||||
|
||||
assert success is True
|
||||
assert Path(result_path).read_bytes() == b"abcdef"
|
||||
assert downloader._session._get_calls == 2
|
||||
assert downloader._session.requests[1]["headers"]["Range"] == "bytes=3-"
|
||||
assert not Path(str(target_path) + ".part").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_retries_redirected_url_when_range_not_honored(tmp_path):
|
||||
target_path = tmp_path / "model" / "file.bin"
|
||||
target_path.parent.mkdir()
|
||||
Path(str(target_path) + ".part").write_bytes(b"abc")
|
||||
|
||||
redirected_url = "https://download.example.com/file.bin"
|
||||
first_response = FakeResponse(
|
||||
status=200,
|
||||
headers={"content-length": "6"},
|
||||
chunks=[],
|
||||
url=redirected_url,
|
||||
history=[object()],
|
||||
)
|
||||
|
||||
responses = [
|
||||
lambda: first_response,
|
||||
lambda: FakeResponse(
|
||||
status=206,
|
||||
headers={"content-length": "3", "Content-Range": "bytes 3-5/6"},
|
||||
chunks=[b"def"],
|
||||
url=redirected_url,
|
||||
),
|
||||
]
|
||||
|
||||
downloader = _build_downloader(responses, max_retries=0)
|
||||
|
||||
success, result_path = await downloader.download_file("https://example.com/file", str(target_path))
|
||||
|
||||
assert success is True
|
||||
assert Path(result_path).read_bytes() == b"abcdef"
|
||||
assert first_response.released is True
|
||||
assert downloader._session.requests[0]["headers"]["Range"] == "bytes=3-"
|
||||
assert downloader._session.requests[1]["url"] == redirected_url
|
||||
assert downloader._session.requests[1]["headers"]["Range"] == "bytes=3-"
|
||||
|
||||
52
tests/services/test_model_scanner_base_models.py
Normal file
52
tests/services/test_model_scanner_base_models.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.model_scanner import ModelScanner
|
||||
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self, raw_data):
|
||||
self._cache = SimpleNamespace(raw_data=raw_data)
|
||||
|
||||
async def get_cached_data(self):
|
||||
return self._cache
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_base_models_limit_zero_returns_all_sorted():
|
||||
scanner = DummyScanner(
|
||||
[
|
||||
{"base_model": "SDXL"},
|
||||
{"base_model": "LTXV 2.3"},
|
||||
{"base_model": "SDXL"},
|
||||
{"base_model": ""},
|
||||
{},
|
||||
]
|
||||
)
|
||||
|
||||
result = await ModelScanner.get_base_models(scanner, limit=0)
|
||||
|
||||
assert result == [
|
||||
{"name": "SDXL", "count": 2},
|
||||
{"name": "LTXV 2.3", "count": 1},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_base_models_positive_limit_still_truncates():
|
||||
scanner = DummyScanner(
|
||||
[
|
||||
{"base_model": "SDXL"},
|
||||
{"base_model": "LTXV 2.3"},
|
||||
{"base_model": "Flux.1 D"},
|
||||
{"base_model": "SDXL"},
|
||||
]
|
||||
)
|
||||
|
||||
result = await ModelScanner.get_base_models(scanner, limit=2)
|
||||
|
||||
assert result == [
|
||||
{"name": "SDXL", "count": 2},
|
||||
{"name": "LTXV 2.3", "count": 1},
|
||||
]
|
||||
Reference in New Issue
Block a user