mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-07 00:46:44 -03:00
Compare commits
17 Commits
v1.0.5
...
1eeba666f5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1eeba666f5 | ||
|
|
89e26d9292 | ||
|
|
fc19a145ff | ||
|
|
34f03d6495 | ||
|
|
9443175abc | ||
|
|
dc5072628f | ||
|
|
ff4b8ec849 | ||
|
|
7ab271c752 | ||
|
|
5a7f4dc88b | ||
|
|
761108bfd1 | ||
|
|
24dd3a777c | ||
|
|
1c530ea013 | ||
|
|
0ced53c059 | ||
|
|
67ad68a23f | ||
|
|
d9ec9c512e | ||
|
|
0bcd8e09a9 | ||
|
|
fa049a28c8 |
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "Voreinstellung \"{name}\" existiert bereits. Überschreiben?",
|
||||
"presetNamePlaceholder": "Voreinstellungsname...",
|
||||
"baseModel": "Basis-Modell",
|
||||
"baseModelSearchPlaceholder": "Basismodelle durchsuchen...",
|
||||
"modelTags": "Tags (Top 20)",
|
||||
"modelTypes": "Modelltypen",
|
||||
"license": "Lizenz",
|
||||
"noCreditRequired": "Kein Credit erforderlich",
|
||||
"allowSellingGeneratedContent": "Verkauf erlaubt",
|
||||
"noTags": "Keine Tags",
|
||||
"noBaseModelMatches": "Keine Basismodelle entsprechen der aktuellen Suche.",
|
||||
"clearAll": "Alle Filter löschen",
|
||||
"any": "Beliebig",
|
||||
"all": "Alle",
|
||||
@@ -261,6 +263,19 @@
|
||||
"red": "civitai.red (uneingeschränkt)"
|
||||
}
|
||||
},
|
||||
"downloadBackend": {
|
||||
"label": "Download-Backend",
|
||||
"help": "Wähle aus, wie Modelldateien heruntergeladen werden. Python verwendet den eingebauten Downloader. aria2 verwendet den experimentellen externen Downloader-Prozess.",
|
||||
"options": {
|
||||
"python": "Python (integriert)",
|
||||
"aria2": "aria2 (experimentell)"
|
||||
}
|
||||
},
|
||||
"aria2cPath": {
|
||||
"label": "aria2c-Pfad",
|
||||
"help": "Optionaler Pfad zur ausführbaren aria2c-Datei. Leer lassen, um aria2c aus dem System-PATH zu verwenden.",
|
||||
"placeholder": "Leer lassen, um aria2c aus dem PATH zu verwenden"
|
||||
},
|
||||
"civitaiHostBanner": {
|
||||
"title": "Civitai-Host-Einstellung verfügbar",
|
||||
"content": "Civitai verwendet jetzt civitai.com für SFW-Inhalte und civitai.red für uneingeschränkte Inhalte. In den Einstellungen können Sie ändern, welche Seite standardmäßig geöffnet wird.",
|
||||
@@ -276,6 +291,7 @@
|
||||
},
|
||||
"sections": {
|
||||
"contentFiltering": "Inhaltsfilterung",
|
||||
"downloads": "Downloads",
|
||||
"videoSettings": "Video-Einstellungen",
|
||||
"layoutSettings": "Layout-Einstellungen",
|
||||
"misc": "Verschiedenes",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "Preset \"{name}\" already exists. Overwrite?",
|
||||
"presetNamePlaceholder": "Preset name...",
|
||||
"baseModel": "Base Model",
|
||||
"baseModelSearchPlaceholder": "Search base models...",
|
||||
"modelTags": "Tags (Top 20)",
|
||||
"modelTypes": "Model Types",
|
||||
"license": "License",
|
||||
"noCreditRequired": "No Credit Required",
|
||||
"allowSellingGeneratedContent": "Allow Selling",
|
||||
"noTags": "No tags",
|
||||
"noBaseModelMatches": "No base models match the current search.",
|
||||
"clearAll": "Clear All Filters",
|
||||
"any": "Any",
|
||||
"all": "All",
|
||||
@@ -261,6 +263,19 @@
|
||||
"red": "civitai.red (unrestricted)"
|
||||
}
|
||||
},
|
||||
"downloadBackend": {
|
||||
"label": "Download backend",
|
||||
"help": "Choose how model files are downloaded. Python uses the built-in downloader. aria2 uses the experimental external downloader process.",
|
||||
"options": {
|
||||
"python": "Python (built-in)",
|
||||
"aria2": "aria2 (experimental)"
|
||||
}
|
||||
},
|
||||
"aria2cPath": {
|
||||
"label": "aria2c path",
|
||||
"help": "Optional path to the aria2c executable. Leave empty to use aria2c from your system PATH.",
|
||||
"placeholder": "Leave empty to use aria2c from PATH"
|
||||
},
|
||||
"civitaiHostBanner": {
|
||||
"title": "Civitai host preference available",
|
||||
"content": "Civitai now uses civitai.com for SFW content and civitai.red for unrestricted content. You can change which site opens by default in Settings.",
|
||||
@@ -276,6 +291,7 @@
|
||||
},
|
||||
"sections": {
|
||||
"contentFiltering": "Content Filtering",
|
||||
"downloads": "Downloads",
|
||||
"videoSettings": "Video Settings",
|
||||
"layoutSettings": "Layout Settings",
|
||||
"misc": "Miscellaneous",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "El preset \"{name}\" ya existe. ¿Sobrescribir?",
|
||||
"presetNamePlaceholder": "Nombre del preajuste...",
|
||||
"baseModel": "Modelo base",
|
||||
"baseModelSearchPlaceholder": "Buscar modelos base...",
|
||||
"modelTags": "Etiquetas (Top 20)",
|
||||
"modelTypes": "Tipos de modelos",
|
||||
"license": "Licencia",
|
||||
"noCreditRequired": "Sin crédito requerido",
|
||||
"allowSellingGeneratedContent": "Venta permitida",
|
||||
"noTags": "Sin etiquetas",
|
||||
"noBaseModelMatches": "Ningún modelo base coincide con la búsqueda actual.",
|
||||
"clearAll": "Limpiar todos los filtros",
|
||||
"any": "Cualquiera",
|
||||
"all": "Todos",
|
||||
@@ -261,6 +263,19 @@
|
||||
"red": "civitai.red (sin restricciones)"
|
||||
}
|
||||
},
|
||||
"downloadBackend": {
|
||||
"label": "Backend de descarga",
|
||||
"help": "Elige cómo se descargan los archivos del modelo. Python usa el descargador integrado. aria2 usa el proceso externo experimental de descarga.",
|
||||
"options": {
|
||||
"python": "Python (integrado)",
|
||||
"aria2": "aria2 (experimental)"
|
||||
}
|
||||
},
|
||||
"aria2cPath": {
|
||||
"label": "Ruta de aria2c",
|
||||
"help": "Ruta opcional al ejecutable aria2c. Déjalo vacío para usar aria2c desde el PATH del sistema.",
|
||||
"placeholder": "Déjalo vacío para usar aria2c desde el PATH"
|
||||
},
|
||||
"civitaiHostBanner": {
|
||||
"title": "Preferencia de host de Civitai disponible",
|
||||
"content": "Civitai ahora usa civitai.com para contenido SFW y civitai.red para contenido sin restricciones. Puedes cambiar en Ajustes qué sitio se abre por defecto.",
|
||||
@@ -276,6 +291,7 @@
|
||||
},
|
||||
"sections": {
|
||||
"contentFiltering": "Filtrado de contenido",
|
||||
"downloads": "Descargas",
|
||||
"videoSettings": "Configuración de video",
|
||||
"layoutSettings": "Configuración de diseño",
|
||||
"misc": "Varios",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "Le préréglage \"{name}\" existe déjà. Remplacer?",
|
||||
"presetNamePlaceholder": "Nom du préréglage...",
|
||||
"baseModel": "Modèle de base",
|
||||
"baseModelSearchPlaceholder": "Rechercher des modèles de base...",
|
||||
"modelTags": "Tags (Top 20)",
|
||||
"modelTypes": "Types de modèles",
|
||||
"license": "Licence",
|
||||
"noCreditRequired": "Crédit non requis",
|
||||
"allowSellingGeneratedContent": "Vente autorisée",
|
||||
"noTags": "Aucun tag",
|
||||
"noBaseModelMatches": "Aucun modèle de base ne correspond à la recherche actuelle.",
|
||||
"clearAll": "Effacer tous les filtres",
|
||||
"any": "N'importe quel",
|
||||
"all": "Tous",
|
||||
@@ -261,6 +263,19 @@
|
||||
"red": "civitai.red (sans restriction)"
|
||||
}
|
||||
},
|
||||
"downloadBackend": {
|
||||
"label": "Moteur de téléchargement",
|
||||
"help": "Choisissez comment les fichiers de modèles sont téléchargés. Python utilise le téléchargeur intégré. aria2 utilise le processus externe expérimental de téléchargement.",
|
||||
"options": {
|
||||
"python": "Python (intégré)",
|
||||
"aria2": "aria2 (expérimental)"
|
||||
}
|
||||
},
|
||||
"aria2cPath": {
|
||||
"label": "Chemin vers aria2c",
|
||||
"help": "Chemin facultatif vers l’exécutable aria2c. Laissez vide pour utiliser aria2c depuis le PATH système.",
|
||||
"placeholder": "Laisser vide pour utiliser aria2c depuis le PATH"
|
||||
},
|
||||
"civitaiHostBanner": {
|
||||
"title": "Préférence d’hôte Civitai disponible",
|
||||
"content": "Civitai utilise désormais civitai.com pour le contenu SFW et civitai.red pour le contenu sans restriction. Vous pouvez modifier dans les paramètres le site ouvert par défaut.",
|
||||
@@ -276,6 +291,7 @@
|
||||
},
|
||||
"sections": {
|
||||
"contentFiltering": "Filtrage du contenu",
|
||||
"downloads": "Téléchargements",
|
||||
"videoSettings": "Paramètres vidéo",
|
||||
"layoutSettings": "Paramètres d'affichage",
|
||||
"misc": "Divers",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "הפריסט \"{name}\" כבר קיים. לדרוס?",
|
||||
"presetNamePlaceholder": "שם קביעה מראש...",
|
||||
"baseModel": "מודל בסיס",
|
||||
"baseModelSearchPlaceholder": "חפש מודלי בסיס...",
|
||||
"modelTags": "תגיות (20 המובילות)",
|
||||
"modelTypes": "סוגי מודלים",
|
||||
"license": "רישיון",
|
||||
"noCreditRequired": "ללא קרדיט נדרש",
|
||||
"allowSellingGeneratedContent": "אפשר מכירה",
|
||||
"noTags": "ללא תגיות",
|
||||
"noBaseModelMatches": "אין מודלי בסיס התואמים לחיפוש הנוכחי.",
|
||||
"clearAll": "נקה את כל המסננים",
|
||||
"any": "כלשהו",
|
||||
"all": "כל התגים",
|
||||
@@ -261,6 +263,19 @@
|
||||
"red": "civitai.red (ללא הגבלות)"
|
||||
}
|
||||
},
|
||||
"downloadBackend": {
|
||||
"label": "מנגנון הורדה",
|
||||
"help": "בחר כיצד יורדים קבצי המודל. Python משתמש במוריד המובנה. aria2 משתמש בתהליך הורדה חיצוני ניסיוני.",
|
||||
"options": {
|
||||
"python": "Python (מובנה)",
|
||||
"aria2": "aria2 (ניסיוני)"
|
||||
}
|
||||
},
|
||||
"aria2cPath": {
|
||||
"label": "נתיב aria2c",
|
||||
"help": "נתיב אופציונלי לקובץ ההפעלה aria2c. השאר ריק כדי להשתמש ב-aria2c מתוך ה-PATH של המערכת.",
|
||||
"placeholder": "השאר ריק כדי להשתמש ב-aria2c מתוך ה-PATH"
|
||||
},
|
||||
"civitaiHostBanner": {
|
||||
"title": "העדפת מארח Civitai זמינה",
|
||||
"content": "Civitai משתמש כעת ב-civitai.com עבור תוכן SFW וב-civitai.red עבור תוכן ללא הגבלות. ניתן לשנות בהגדרות איזה אתר ייפתח כברירת מחדל.",
|
||||
@@ -276,6 +291,7 @@
|
||||
},
|
||||
"sections": {
|
||||
"contentFiltering": "סינון תוכן",
|
||||
"downloads": "הורדות",
|
||||
"videoSettings": "הגדרות וידאו",
|
||||
"layoutSettings": "הגדרות פריסה",
|
||||
"misc": "שונות",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "プリセット「{name}」は既に存在します。上書きしますか?",
|
||||
"presetNamePlaceholder": "プリセット名...",
|
||||
"baseModel": "ベースモデル",
|
||||
"baseModelSearchPlaceholder": "ベースモデルを検索...",
|
||||
"modelTags": "タグ(上位20)",
|
||||
"modelTypes": "モデルタイプ",
|
||||
"license": "ライセンス",
|
||||
"noCreditRequired": "クレジット不要",
|
||||
"allowSellingGeneratedContent": "販売許可",
|
||||
"noTags": "タグなし",
|
||||
"noBaseModelMatches": "現在の検索に一致するベースモデルはありません。",
|
||||
"clearAll": "すべてのフィルタをクリア",
|
||||
"any": "いずれか",
|
||||
"all": "すべて",
|
||||
@@ -261,6 +263,19 @@
|
||||
"red": "civitai.red(制限なし)"
|
||||
}
|
||||
},
|
||||
"downloadBackend": {
|
||||
"label": "ダウンロードバックエンド",
|
||||
"help": "モデルファイルのダウンロード方法を選択します。Python は内蔵ダウンローダーを使用し、aria2 は実験的な外部ダウンローダープロセスを使用します。",
|
||||
"options": {
|
||||
"python": "Python(内蔵)",
|
||||
"aria2": "aria2(実験的)"
|
||||
}
|
||||
},
|
||||
"aria2cPath": {
|
||||
"label": "aria2c のパス",
|
||||
"help": "aria2c 実行ファイルへの任意のパスです。空欄のままにすると、システム PATH 上の aria2c を使用します。",
|
||||
"placeholder": "空欄のままにすると PATH 上の aria2c を使用します"
|
||||
},
|
||||
"civitaiHostBanner": {
|
||||
"title": "Civitai ホスト設定を利用できます",
|
||||
"content": "Civitai は現在、SFW コンテンツには civitai.com、制限なしコンテンツには civitai.red を使用しています。設定で既定で開くサイトを変更できます。",
|
||||
@@ -276,6 +291,7 @@
|
||||
},
|
||||
"sections": {
|
||||
"contentFiltering": "コンテンツフィルタリング",
|
||||
"downloads": "ダウンロード",
|
||||
"videoSettings": "動画設定",
|
||||
"layoutSettings": "レイアウト設定",
|
||||
"misc": "その他",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "프리셋 \"{name}\"이(가) 이미 존재합니다. 덮어쓰시겠습니까?",
|
||||
"presetNamePlaceholder": "프리셋 이름...",
|
||||
"baseModel": "베이스 모델",
|
||||
"baseModelSearchPlaceholder": "베이스 모델 검색...",
|
||||
"modelTags": "태그 (상위 20개)",
|
||||
"modelTypes": "모델 유형",
|
||||
"license": "라이선스",
|
||||
"noCreditRequired": "크레딧 표기 없음",
|
||||
"allowSellingGeneratedContent": "판매 허용",
|
||||
"noTags": "태그 없음",
|
||||
"noBaseModelMatches": "현재 검색과 일치하는 베이스 모델이 없습니다.",
|
||||
"clearAll": "모든 필터 지우기",
|
||||
"any": "아무",
|
||||
"all": "모두",
|
||||
@@ -261,6 +263,19 @@
|
||||
"red": "civitai.red(무제한)"
|
||||
}
|
||||
},
|
||||
"downloadBackend": {
|
||||
"label": "다운로드 백엔드",
|
||||
"help": "모델 파일을 다운로드하는 방식을 선택합니다. Python은 내장 다운로더를 사용하고, aria2는 실험적인 외부 다운로더 프로세스를 사용합니다.",
|
||||
"options": {
|
||||
"python": "Python(내장)",
|
||||
"aria2": "aria2(실험적)"
|
||||
}
|
||||
},
|
||||
"aria2cPath": {
|
||||
"label": "aria2c 경로",
|
||||
"help": "aria2c 실행 파일의 선택적 경로입니다. 비워 두면 시스템 PATH의 aria2c를 사용합니다.",
|
||||
"placeholder": "비워 두면 PATH의 aria2c를 사용합니다"
|
||||
},
|
||||
"civitaiHostBanner": {
|
||||
"title": "Civitai 호스트 기본 설정 사용 가능",
|
||||
"content": "이제 Civitai는 SFW 콘텐츠에 civitai.com을, 무제한 콘텐츠에 civitai.red를 사용합니다. 설정에서 기본으로 열 사이트를 변경할 수 있습니다.",
|
||||
@@ -276,6 +291,7 @@
|
||||
},
|
||||
"sections": {
|
||||
"contentFiltering": "콘텐츠 필터링",
|
||||
"downloads": "다운로드",
|
||||
"videoSettings": "비디오 설정",
|
||||
"layoutSettings": "레이아웃 설정",
|
||||
"misc": "기타",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "Пресет \"{name}\" уже существует. Перезаписать?",
|
||||
"presetNamePlaceholder": "Имя пресета...",
|
||||
"baseModel": "Базовая модель",
|
||||
"baseModelSearchPlaceholder": "Поиск базовых моделей...",
|
||||
"modelTags": "Теги (Топ 20)",
|
||||
"modelTypes": "Типы моделей",
|
||||
"license": "Лицензия",
|
||||
"noCreditRequired": "Без указания авторства",
|
||||
"allowSellingGeneratedContent": "Продажа разрешена",
|
||||
"noTags": "Без тегов",
|
||||
"noBaseModelMatches": "Нет базовых моделей, соответствующих текущему поиску.",
|
||||
"clearAll": "Очистить все фильтры",
|
||||
"any": "Любой",
|
||||
"all": "Все",
|
||||
@@ -261,6 +263,19 @@
|
||||
"red": "civitai.red (без ограничений)"
|
||||
}
|
||||
},
|
||||
"downloadBackend": {
|
||||
"label": "Бэкенд загрузки",
|
||||
"help": "Выберите способ загрузки файлов моделей. Python использует встроенный загрузчик. aria2 использует экспериментальный внешний процесс загрузки.",
|
||||
"options": {
|
||||
"python": "Python (встроенный)",
|
||||
"aria2": "aria2 (экспериментальный)"
|
||||
}
|
||||
},
|
||||
"aria2cPath": {
|
||||
"label": "Путь к aria2c",
|
||||
"help": "Необязательный путь к исполняемому файлу aria2c. Оставьте пустым, чтобы использовать aria2c из системного PATH.",
|
||||
"placeholder": "Оставьте пустым, чтобы использовать aria2c из PATH"
|
||||
},
|
||||
"civitaiHostBanner": {
|
||||
"title": "Доступна настройка хоста Civitai",
|
||||
"content": "Теперь Civitai использует civitai.com для контента SFW и civitai.red для контента без ограничений. В настройках можно изменить, какой сайт открывать по умолчанию.",
|
||||
@@ -276,6 +291,7 @@
|
||||
},
|
||||
"sections": {
|
||||
"contentFiltering": "Фильтрация контента",
|
||||
"downloads": "Загрузки",
|
||||
"videoSettings": "Настройки видео",
|
||||
"layoutSettings": "Настройки макета",
|
||||
"misc": "Разное",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "预设 \"{name}\" 已存在。是否覆盖?",
|
||||
"presetNamePlaceholder": "预设名称...",
|
||||
"baseModel": "基础模型",
|
||||
"baseModelSearchPlaceholder": "搜索基础模型...",
|
||||
"modelTags": "标签(前20)",
|
||||
"modelTypes": "模型类型",
|
||||
"license": "许可证",
|
||||
"noCreditRequired": "无需署名",
|
||||
"allowSellingGeneratedContent": "允许销售",
|
||||
"noTags": "无标签",
|
||||
"noBaseModelMatches": "没有基础模型符合当前搜索。",
|
||||
"clearAll": "清除所有筛选",
|
||||
"any": "任一",
|
||||
"all": "全部",
|
||||
@@ -261,6 +263,19 @@
|
||||
"red": "civitai.red(无限制)"
|
||||
}
|
||||
},
|
||||
"downloadBackend": {
|
||||
"label": "下载后端",
|
||||
"help": "选择模型文件的下载方式。Python 使用内置下载器。aria2 使用实验性的外部下载进程。",
|
||||
"options": {
|
||||
"python": "Python(内置)",
|
||||
"aria2": "aria2(实验性)"
|
||||
}
|
||||
},
|
||||
"aria2cPath": {
|
||||
"label": "aria2c 路径",
|
||||
"help": "可选的 aria2c 可执行文件路径。留空则使用系统 PATH 中的 aria2c。",
|
||||
"placeholder": "留空则使用 PATH 中的 aria2c"
|
||||
},
|
||||
"civitaiHostBanner": {
|
||||
"title": "已提供 Civitai 站点偏好设置",
|
||||
"content": "Civitai 现在使用 civitai.com 提供 SFW 内容,使用 civitai.red 提供无限制内容。你可以在设置中更改默认打开的站点。",
|
||||
@@ -276,6 +291,7 @@
|
||||
},
|
||||
"sections": {
|
||||
"contentFiltering": "内容过滤",
|
||||
"downloads": "下载",
|
||||
"videoSettings": "视频设置",
|
||||
"layoutSettings": "布局设置",
|
||||
"misc": "其他",
|
||||
|
||||
@@ -225,12 +225,14 @@
|
||||
"presetOverwriteConfirm": "預設 \"{name}\" 已存在。是否覆蓋?",
|
||||
"presetNamePlaceholder": "預設名稱...",
|
||||
"baseModel": "基礎模型",
|
||||
"baseModelSearchPlaceholder": "搜尋基礎模型...",
|
||||
"modelTags": "標籤(前 20)",
|
||||
"modelTypes": "模型類型",
|
||||
"license": "授權",
|
||||
"noCreditRequired": "無需署名",
|
||||
"allowSellingGeneratedContent": "允許銷售",
|
||||
"noTags": "無標籤",
|
||||
"noBaseModelMatches": "沒有基礎模型符合目前的搜尋。",
|
||||
"clearAll": "清除所有篩選",
|
||||
"any": "任一",
|
||||
"all": "全部",
|
||||
@@ -261,6 +263,19 @@
|
||||
"red": "civitai.red(無限制)"
|
||||
}
|
||||
},
|
||||
"downloadBackend": {
|
||||
"label": "下載後端",
|
||||
"help": "選擇模型檔案的下載方式。Python 使用內建下載器。aria2 使用實驗性的外部下載程序。",
|
||||
"options": {
|
||||
"python": "Python(內建)",
|
||||
"aria2": "aria2(實驗性)"
|
||||
}
|
||||
},
|
||||
"aria2cPath": {
|
||||
"label": "aria2c 路徑",
|
||||
"help": "可選的 aria2c 可執行檔路徑。留空則使用系統 PATH 中的 aria2c。",
|
||||
"placeholder": "留空則使用 PATH 中的 aria2c"
|
||||
},
|
||||
"civitaiHostBanner": {
|
||||
"title": "已提供 Civitai 站點偏好設定",
|
||||
"content": "Civitai 現在使用 civitai.com 提供 SFW 內容,使用 civitai.red 提供無限制內容。你可以在設定中變更預設開啟的站點。",
|
||||
@@ -276,6 +291,7 @@
|
||||
},
|
||||
"sections": {
|
||||
"contentFiltering": "內容過濾",
|
||||
"downloads": "下載",
|
||||
"videoSettings": "影片設定",
|
||||
"layoutSettings": "版面設定",
|
||||
"misc": "其他",
|
||||
|
||||
77
py/config.py
77
py/config.py
@@ -26,20 +26,44 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_valid_default_root(
|
||||
current: str, primary_paths: List[str], name: str
|
||||
current: str, primary_paths: List[str], allowed_paths: List[str], name: str
|
||||
) -> str:
|
||||
"""Return a valid default root from the current primary path set."""
|
||||
"""Return a valid default root from the current primary/extra path set."""
|
||||
|
||||
valid_paths = [path for path in primary_paths if isinstance(path, str) and path.strip()]
|
||||
if not valid_paths:
|
||||
return ""
|
||||
fallback_paths: List[str] = []
|
||||
seen: Set[str] = set()
|
||||
for path in allowed_paths:
|
||||
if not isinstance(path, str):
|
||||
continue
|
||||
stripped = path.strip()
|
||||
if not stripped or stripped in seen:
|
||||
continue
|
||||
seen.add(stripped)
|
||||
fallback_paths.append(stripped)
|
||||
|
||||
if current in valid_paths:
|
||||
allowed = set(fallback_paths)
|
||||
|
||||
if current and current in allowed:
|
||||
return current
|
||||
|
||||
if not valid_paths:
|
||||
if not fallback_paths:
|
||||
return ""
|
||||
if current:
|
||||
logger.info(
|
||||
"Repaired stale %s from '%s' to '%s' because it is not present in primary or extra roots",
|
||||
name,
|
||||
current,
|
||||
fallback_paths[0],
|
||||
)
|
||||
else:
|
||||
logger.info("Auto-setting %s to '%s'", name, fallback_paths[0])
|
||||
return fallback_paths[0]
|
||||
|
||||
if current:
|
||||
logger.info(
|
||||
"Repaired stale %s from '%s' to '%s'",
|
||||
"Repaired stale %s from '%s' to '%s' because it is not present in primary or extra roots",
|
||||
name,
|
||||
current,
|
||||
valid_paths[0],
|
||||
@@ -226,39 +250,76 @@ class Config:
|
||||
default_lora_root = _resolve_valid_default_root(
|
||||
comfy_library.get("default_lora_root", ""),
|
||||
list(self.loras_roots or []),
|
||||
list(self.loras_roots or [])
|
||||
+ list(comfy_library.get("extra_folder_paths", {}).get("loras", []) or []),
|
||||
"default_lora_root",
|
||||
)
|
||||
|
||||
default_checkpoint_root = _resolve_valid_default_root(
|
||||
comfy_library.get("default_checkpoint_root", ""),
|
||||
list(self.checkpoints_roots or []),
|
||||
list(self.checkpoints_roots or [])
|
||||
+ list(comfy_library.get("extra_folder_paths", {}).get("checkpoints", []) or []),
|
||||
"default_checkpoint_root",
|
||||
)
|
||||
|
||||
default_embedding_root = _resolve_valid_default_root(
|
||||
comfy_library.get("default_embedding_root", ""),
|
||||
list(self.embeddings_roots or []),
|
||||
list(self.embeddings_roots or [])
|
||||
+ list(comfy_library.get("extra_folder_paths", {}).get("embeddings", []) or []),
|
||||
"default_embedding_root",
|
||||
)
|
||||
|
||||
metadata = dict(comfy_library.get("metadata", {}))
|
||||
metadata.setdefault("display_name", "ComfyUI")
|
||||
metadata["source"] = "comfyui"
|
||||
extra_folder_paths = {}
|
||||
if isinstance(comfy_library, Mapping):
|
||||
existing_extra_paths = comfy_library.get("extra_folder_paths", {})
|
||||
if isinstance(existing_extra_paths, Mapping):
|
||||
extra_folder_paths = {
|
||||
key: list(value) if isinstance(value, list) else []
|
||||
for key, value in existing_extra_paths.items()
|
||||
}
|
||||
|
||||
active_library_name = settings_service.get_active_library_name()
|
||||
should_activate = (
|
||||
active_library_name == "comfyui"
|
||||
or self._should_activate_comfy_library(libraries, libraries_changed)
|
||||
)
|
||||
|
||||
settings_service.upsert_library(
|
||||
"comfyui",
|
||||
folder_paths=target_folder_paths,
|
||||
extra_folder_paths=extra_folder_paths,
|
||||
default_lora_root=default_lora_root,
|
||||
default_checkpoint_root=default_checkpoint_root,
|
||||
default_embedding_root=default_embedding_root,
|
||||
metadata=metadata,
|
||||
activate=True,
|
||||
activate=should_activate,
|
||||
)
|
||||
|
||||
logger.info("Updated 'comfyui' library with current folder paths")
|
||||
if should_activate:
|
||||
logger.info("Updated 'comfyui' library with current folder paths")
|
||||
else:
|
||||
logger.info(
|
||||
"Updated 'comfyui' library with current folder paths without activating it"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save folder paths: {e}")
|
||||
|
||||
def _should_activate_comfy_library(
|
||||
self, libraries: Mapping[str, Any], libraries_changed: bool
|
||||
) -> bool:
|
||||
"""Return whether startup sync should make the ComfyUI library active."""
|
||||
|
||||
if libraries_changed:
|
||||
return True
|
||||
if not libraries:
|
||||
return True
|
||||
return "comfyui" in libraries and len(libraries) == 1
|
||||
|
||||
def _is_link(self, path: str) -> bool:
|
||||
try:
|
||||
if os.path.islink(path):
|
||||
|
||||
@@ -16,6 +16,10 @@ import jinja2
|
||||
|
||||
from ...config import config
|
||||
from ...services.download_coordinator import DownloadCoordinator
|
||||
from ...services.connectivity_guard import (
|
||||
OFFLINE_FRIENDLY_MESSAGE,
|
||||
is_expected_offline_error,
|
||||
)
|
||||
from ...services.metadata_sync_service import MetadataSyncService
|
||||
from ...services.model_file_service import ModelMoveService
|
||||
from ...services.preview_asset_service import PreviewAssetService
|
||||
@@ -504,6 +508,11 @@ class ModelManagementHandler:
|
||||
formatted_metadata = await self._service.format_response(model_data)
|
||||
return web.json_response({"success": True, "metadata": formatted_metadata})
|
||||
except Exception as exc:
|
||||
if is_expected_offline_error(str(exc)):
|
||||
return web.json_response(
|
||||
{"success": False, "error": OFFLINE_FRIENDLY_MESSAGE},
|
||||
status=503,
|
||||
)
|
||||
self._logger.error("Error fetching from CivitAI: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
@@ -550,6 +559,11 @@ class ModelManagementHandler:
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
if is_expected_offline_error(str(exc)):
|
||||
return web.json_response(
|
||||
{"success": False, "error": OFFLINE_FRIENDLY_MESSAGE},
|
||||
status=503,
|
||||
)
|
||||
self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
@@ -910,7 +924,7 @@ class ModelQueryHandler:
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
limit = int(request.query.get("limit", "20"))
|
||||
if limit < 1 or limit > 100:
|
||||
if limit < 0 or limit > 100:
|
||||
limit = 20
|
||||
base_models = await self._service.get_base_models(limit)
|
||||
return web.json_response({"success": True, "base_models": base_models})
|
||||
@@ -1858,6 +1872,11 @@ class ModelUpdateHandler:
|
||||
status=429,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive log
|
||||
if is_expected_offline_error(str(exc)):
|
||||
return web.json_response(
|
||||
{"success": False, "error": OFFLINE_FRIENDLY_MESSAGE},
|
||||
status=503,
|
||||
)
|
||||
self._logger.error("Failed to fetch license info: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
@@ -1946,9 +1965,12 @@ class ModelUpdateHandler:
|
||||
{"success": False, "error": str(exc) or "Rate limited"}, status=429
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._logger.error(
|
||||
"Failed to refresh model updates: %s", exc, exc_info=True
|
||||
)
|
||||
if is_expected_offline_error(str(exc)):
|
||||
return web.json_response(
|
||||
{"success": False, "error": OFFLINE_FRIENDLY_MESSAGE},
|
||||
status=503,
|
||||
)
|
||||
self._logger.error("Failed to refresh model updates: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
serialized_records = []
|
||||
|
||||
@@ -329,6 +329,7 @@ class RecipeQueryHandler:
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
limit = int(request.query.get("limit", "20"))
|
||||
cache = await recipe_scanner.get_cached_data()
|
||||
|
||||
base_model_counts: Dict[str, int] = {}
|
||||
@@ -344,6 +345,8 @@ class RecipeQueryHandler:
|
||||
for model, count in base_model_counts.items()
|
||||
]
|
||||
sorted_models.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
if limit > 0:
|
||||
sorted_models = sorted_models[:limit]
|
||||
return web.json_response({"success": True, "base_models": sorted_models})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving base models: %s", exc, exc_info=True)
|
||||
|
||||
570
py/services/aria2_downloader.py
Normal file
570
py/services/aria2_downloader.py
Normal file
@@ -0,0 +1,570 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import shutil
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .downloader import DownloadProgress, get_downloader
|
||||
from .aria2_transfer_state import Aria2TransferStateStore
|
||||
from .settings_manager import get_settings_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CIVITAI_DOWNLOAD_URL_PREFIXES = (
|
||||
"https://civitai.com/api/download/",
|
||||
"https://civitai.red/api/download/",
|
||||
)
|
||||
|
||||
|
||||
class Aria2Error(RuntimeError):
|
||||
"""Raised when aria2 integration fails."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Aria2Transfer:
|
||||
"""Track an aria2 download registered by the Python coordinator."""
|
||||
|
||||
gid: str
|
||||
save_path: str
|
||||
|
||||
|
||||
class Aria2Downloader:
|
||||
"""Manage an aria2 RPC daemon for experimental model downloads."""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> "Aria2Downloader":
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if hasattr(self, "_initialized"):
|
||||
return
|
||||
|
||||
self._initialized = True
|
||||
self._process: Optional[asyncio.subprocess.Process] = None
|
||||
self._rpc_port: Optional[int] = None
|
||||
self._rpc_secret = ""
|
||||
self._rpc_url = ""
|
||||
self._rpc_session: Optional[aiohttp.ClientSession] = None
|
||||
self._rpc_session_lock = asyncio.Lock()
|
||||
self._process_lock = asyncio.Lock()
|
||||
self._transfers: Dict[str, Aria2Transfer] = {}
|
||||
self._poll_interval = 0.5
|
||||
self._state_store = Aria2TransferStateStore()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return self._process is not None and self._process.returncode is None
|
||||
|
||||
async def download_file(
|
||||
self,
|
||||
url: str,
|
||||
save_path: str,
|
||||
*,
|
||||
download_id: str,
|
||||
progress_callback=None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
"""Download a file using aria2 RPC and wait for completion."""
|
||||
|
||||
await self._ensure_process()
|
||||
save_path = os.path.abspath(save_path)
|
||||
transfer = self._transfers.get(download_id)
|
||||
if transfer is None or os.path.abspath(transfer.save_path) != save_path:
|
||||
gid = await self._schedule_download(
|
||||
url,
|
||||
save_path,
|
||||
download_id=download_id,
|
||||
headers=headers,
|
||||
)
|
||||
transfer = Aria2Transfer(gid=gid, save_path=save_path)
|
||||
self._transfers[download_id] = transfer
|
||||
|
||||
try:
|
||||
while True:
|
||||
status = await self.get_status(download_id)
|
||||
if status is None:
|
||||
return False, "aria2 download not found"
|
||||
|
||||
snapshot = self._build_progress_snapshot(status)
|
||||
if progress_callback is not None:
|
||||
await self._dispatch_progress(progress_callback, snapshot)
|
||||
|
||||
state = status.get("status", "")
|
||||
if state == "complete":
|
||||
completed_path = self._resolve_completed_path(status, save_path)
|
||||
return True, completed_path
|
||||
if state == "error":
|
||||
return False, status.get("errorMessage") or "aria2 download failed"
|
||||
if state == "removed":
|
||||
return False, "Download was cancelled"
|
||||
|
||||
await asyncio.sleep(self._poll_interval)
|
||||
finally:
|
||||
self._transfers.pop(download_id, None)
|
||||
|
||||
async def _schedule_download(
|
||||
self,
|
||||
url: str,
|
||||
save_path: str,
|
||||
*,
|
||||
download_id: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> str:
|
||||
save_dir = os.path.dirname(save_path)
|
||||
out_name = os.path.basename(save_path)
|
||||
|
||||
Path(save_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
resolved_url = url
|
||||
request_headers = headers
|
||||
if headers and url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES):
|
||||
resolved_url = await self._resolve_authenticated_redirect_url(url, headers)
|
||||
if resolved_url != url:
|
||||
request_headers = None
|
||||
logger.debug(
|
||||
"Resolved Civitai download %s to signed URL for aria2",
|
||||
download_id,
|
||||
)
|
||||
|
||||
options: Dict[str, str] = {
|
||||
"dir": save_dir,
|
||||
"out": out_name,
|
||||
"continue": "true",
|
||||
"max-connection-per-server": "4",
|
||||
"split": "4",
|
||||
"min-split-size": "1M",
|
||||
"allow-overwrite": "true",
|
||||
"auto-file-renaming": "false",
|
||||
"file-allocation": "none",
|
||||
}
|
||||
if request_headers:
|
||||
options["header"] = [
|
||||
f"{key}: {value}" for key, value in request_headers.items()
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
"Submitting aria2 download %s -> %s (auth=%s, civitai_signed=%s)",
|
||||
download_id,
|
||||
save_path,
|
||||
bool(request_headers),
|
||||
resolved_url != url,
|
||||
)
|
||||
|
||||
try:
|
||||
gid = await self._rpc_call("aria2.addUri", [[resolved_url], options])
|
||||
except Exception as exc:
|
||||
raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc
|
||||
|
||||
logger.debug("aria2 accepted download %s with gid %s", download_id, gid)
|
||||
await self._state_store.upsert(
|
||||
download_id,
|
||||
{
|
||||
"gid": gid,
|
||||
"save_path": save_path,
|
||||
"status": "downloading",
|
||||
"url": url,
|
||||
},
|
||||
)
|
||||
return gid
|
||||
|
||||
async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Return the raw aria2 status payload for a known download."""
|
||||
|
||||
transfer = self._transfers.get(download_id)
|
||||
if transfer is None:
|
||||
return None
|
||||
|
||||
keys = [
|
||||
"gid",
|
||||
"status",
|
||||
"totalLength",
|
||||
"completedLength",
|
||||
"downloadSpeed",
|
||||
"errorMessage",
|
||||
"files",
|
||||
]
|
||||
try:
|
||||
status = await self._rpc_call("aria2.tellStatus", [transfer.gid, keys])
|
||||
except Exception as exc:
|
||||
raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc
|
||||
|
||||
if isinstance(status, dict):
|
||||
return status
|
||||
return None
|
||||
|
||||
async def get_status_by_gid(self, gid: str) -> Optional[Dict[str, Any]]:
|
||||
keys = [
|
||||
"gid",
|
||||
"status",
|
||||
"totalLength",
|
||||
"completedLength",
|
||||
"downloadSpeed",
|
||||
"errorMessage",
|
||||
"files",
|
||||
]
|
||||
try:
|
||||
status = await self._rpc_call("aria2.tellStatus", [gid, keys])
|
||||
except Exception as exc:
|
||||
message = str(exc)
|
||||
if "cannot be found" in message.lower() or "not found" in message.lower():
|
||||
return None
|
||||
raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc
|
||||
|
||||
if isinstance(status, dict):
|
||||
return status
|
||||
return None
|
||||
|
||||
async def restore_transfer(self, download_id: str, gid: str, save_path: str) -> None:
|
||||
await self._ensure_process()
|
||||
self._transfers[download_id] = Aria2Transfer(
|
||||
gid=gid,
|
||||
save_path=os.path.abspath(save_path),
|
||||
)
|
||||
|
||||
async def reassign_transfer(
|
||||
self, from_download_id: str, to_download_id: str
|
||||
) -> Optional[Aria2Transfer]:
|
||||
transfer = self._transfers.get(from_download_id)
|
||||
if transfer is None:
|
||||
return None
|
||||
|
||||
self._transfers[to_download_id] = transfer
|
||||
if from_download_id != to_download_id:
|
||||
self._transfers.pop(from_download_id, None)
|
||||
return transfer
|
||||
|
||||
async def has_transfer(self, download_id: str) -> bool:
|
||||
return download_id in self._transfers
|
||||
|
||||
async def pause_download(self, download_id: str) -> Dict[str, Any]:
|
||||
transfer = self._transfers.get(download_id)
|
||||
if transfer is None:
|
||||
return {"success": False, "error": "Download task not found"}
|
||||
|
||||
try:
|
||||
await self._rpc_call("aria2.forcePause", [transfer.gid])
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
await self._state_store.upsert(download_id, {"status": "paused"})
|
||||
return {"success": True, "message": "Download paused successfully"}
|
||||
|
||||
async def resume_download(self, download_id: str) -> Dict[str, Any]:
|
||||
transfer = self._transfers.get(download_id)
|
||||
if transfer is None:
|
||||
return {"success": False, "error": "Download task not found"}
|
||||
|
||||
try:
|
||||
await self._rpc_call("aria2.unpause", [transfer.gid])
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
await self._state_store.upsert(download_id, {"status": "downloading"})
|
||||
return {"success": True, "message": "Download resumed successfully"}
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||
transfer = self._transfers.get(download_id)
|
||||
if transfer is None:
|
||||
return {"success": False, "error": "Download task not found"}
|
||||
|
||||
try:
|
||||
await self._rpc_call("aria2.forceRemove", [transfer.gid])
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
await self._state_store.remove(download_id)
|
||||
return {"success": True, "message": "Download cancelled successfully"}
|
||||
|
||||
    async def close(self) -> None:
        """Shut down the RPC process and session.

        Order matters: close the HTTP session first, then detach and clear all
        in-memory bookkeeping before touching the subprocess, so a concurrent
        caller never observes a half-torn-down state.
        """

        if self._rpc_session is not None:
            await self._rpc_session.close()
            self._rpc_session = None

        # Detach the process handle before terminating so re-entrant calls
        # see self._process as None immediately.
        process = self._process
        self._process = None
        self._transfers.clear()

        if process is None:
            return

        # returncode is None while the daemon is still running.
        if process.returncode is None:
            process.terminate()
            try:
                await asyncio.wait_for(process.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                # Graceful shutdown failed; escalate to SIGKILL.
                process.kill()
                await process.wait()
|
||||
|
||||
async def _dispatch_progress(self, callback, snapshot: DownloadProgress) -> None:
|
||||
try:
|
||||
result = callback(snapshot, snapshot)
|
||||
except TypeError:
|
||||
result = callback(snapshot.percent_complete)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
elif hasattr(result, "__await__"):
|
||||
await result
|
||||
|
||||
def _build_progress_snapshot(self, status: Dict[str, Any]) -> DownloadProgress:
|
||||
completed = self._parse_int(status.get("completedLength"))
|
||||
total = self._parse_int(status.get("totalLength"))
|
||||
speed = float(self._parse_int(status.get("downloadSpeed")))
|
||||
percent = 0.0
|
||||
if total > 0:
|
||||
percent = (completed / total) * 100.0
|
||||
|
||||
return DownloadProgress(
|
||||
percent_complete=max(0.0, min(percent, 100.0)),
|
||||
bytes_downloaded=completed,
|
||||
total_bytes=total or None,
|
||||
bytes_per_second=speed,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
)
|
||||
|
||||
def _resolve_completed_path(self, status: Dict[str, Any], default_path: str) -> str:
|
||||
files = status.get("files")
|
||||
if isinstance(files, list) and files:
|
||||
first = files[0]
|
||||
if isinstance(first, dict):
|
||||
candidate = first.get("path")
|
||||
if isinstance(candidate, str) and candidate:
|
||||
return candidate
|
||||
return default_path
|
||||
|
||||
@staticmethod
|
||||
def _parse_int(value: Any) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
    async def _resolve_authenticated_redirect_url(
        self,
        url: str,
        headers: Dict[str, str],
    ) -> str:
        """Follow an authenticated Civitai redirect and return the signed URL.

        Uses the shared downloader session so proxy settings and default
        headers are honored. Returns the redirect ``Location`` when the server
        redirects, or the original URL when it responds 200 directly.

        Raises:
            Aria2Error: On missing Location header, unexpected status, or
                transport failure.
        """
        downloader = await get_downloader()
        session = await downloader.session
        request_headers = dict(downloader.default_headers)
        request_headers.update(headers)
        # Disable compressed transfer so the redirect probe stays cheap and
        # byte ranges later behave predictably.
        request_headers["Accept-Encoding"] = "identity"

        try:
            # allow_redirects=False: we want the Location header itself, not
            # the final body.
            async with session.get(
                url,
                headers=request_headers,
                allow_redirects=False,
                proxy=downloader.proxy_url,
            ) as response:
                if response.status in {301, 302, 303, 307, 308}:
                    location = response.headers.get("Location")
                    if location:
                        return location
                    raise Aria2Error(
                        "Authenticated Civitai redirect did not include a Location header"
                    )

                # Server served the file directly; keep the original URL.
                if response.status == 200:
                    return url

                body = await response.text()
                # Truncate the body so error messages stay readable.
                raise Aria2Error(
                    f"Failed to resolve authenticated Civitai redirect: status={response.status} body={body[:300]}"
                )
        except aiohttp.ClientError as exc:
            raise Aria2Error(
                f"Failed to resolve authenticated Civitai redirect: {exc}"
            ) from exc
|
||||
|
||||
    async def _ensure_process(self) -> None:
        """Start (or restart) the aria2 RPC daemon if it is not responding.

        Serialized by ``self._process_lock`` so concurrent callers cannot
        spawn duplicate daemons. A fresh port and RPC secret are generated on
        every (re)start.
        """
        async with self._process_lock:
            # Fast path: daemon already up and answering RPC pings.
            if self.is_running and await self._ping():
                return

            # Tear down any stale process/session before relaunching.
            await self.close()

            executable = self._resolve_executable()
            self._rpc_port = self._find_free_port()
            self._rpc_secret = secrets.token_hex(16)
            self._rpc_url = f"http://127.0.0.1:{self._rpc_port}/jsonrpc"

            # Loopback-only RPC with a per-launch secret; stop-with-process
            # ties the daemon's lifetime to ours so it cannot be orphaned.
            command = [
                executable,
                "--enable-rpc=true",
                "--rpc-listen-all=false",
                f"--rpc-listen-port={self._rpc_port}",
                f"--rpc-secret={self._rpc_secret}",
                "--check-certificate=true",
                "--allow-overwrite=true",
                "--auto-file-renaming=false",
                "--file-allocation=none",
                "--max-concurrent-downloads=5",
                "--continue=true",
                "--daemon=false",
                "--quiet=true",
                f"--stop-with-process={os.getpid()}",
            ]

            logger.info("Starting aria2 RPC daemon from %s", executable)
            self._process = await asyncio.create_subprocess_exec(
                *command,
                stdout=asyncio.subprocess.DEVNULL,
                stderr=asyncio.subprocess.PIPE,
            )

            # Block until the daemon answers RPC or fails to start.
            await self._wait_until_ready()
|
||||
|
||||
def _resolve_executable(self) -> str:
|
||||
settings = get_settings_manager()
|
||||
configured_path = (settings.get("aria2c_path") or "").strip()
|
||||
candidate = configured_path or "aria2c"
|
||||
|
||||
resolved = shutil.which(candidate)
|
||||
if resolved:
|
||||
return resolved
|
||||
|
||||
if configured_path and os.path.isfile(configured_path) and os.access(
|
||||
configured_path, os.X_OK
|
||||
):
|
||||
return configured_path
|
||||
|
||||
raise Aria2Error(
|
||||
"aria2c executable was not found. Install aria2 or configure aria2c_path."
|
||||
)
|
||||
|
||||
    async def _wait_until_ready(self) -> None:
        """Poll the freshly-launched daemon until it answers RPC pings.

        Gives the process up to 10 seconds; detects early exit and surfaces
        the daemon's stderr in the error message.

        Raises:
            Aria2Error: If the process exits early or never becomes ready.
        """
        assert self._process is not None

        start_time = asyncio.get_running_loop().time()
        last_error = ""
        while asyncio.get_running_loop().time() - start_time < 10.0:
            # returncode set means the daemon died before becoming ready.
            if self._process.returncode is not None:
                stderr_output = ""
                if self._process.stderr is not None:
                    try:
                        # Short timeout: stderr may already be closed/empty.
                        stderr_output = (
                            await asyncio.wait_for(self._process.stderr.read(), timeout=0.2)
                        ).decode("utf-8", errors="replace")
                    except Exception:
                        stderr_output = ""
                raise Aria2Error(
                    f"aria2 RPC process exited early with code {self._process.returncode}: {stderr_output.strip()}"
                )

            try:
                if await self._ping():
                    return
            except Exception as exc:  # pragma: no cover - startup race
                # Remember the latest failure so the timeout error is useful.
                last_error = str(exc)

            await asyncio.sleep(0.2)

        raise Aria2Error(
            f"Timed out waiting for aria2 RPC to become ready{': ' + last_error if last_error else ''}"
        )
|
||||
|
||||
async def _ping(self) -> bool:
|
||||
try:
|
||||
result = await self._rpc_call("aria2.getVersion", [])
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
return isinstance(result, dict)
|
||||
|
||||
    async def _rpc_call(self, method: str, params: list[Any]) -> Any:
        """Issue a JSON-RPC 2.0 call to the aria2 daemon and return its result.

        The RPC secret token is prepended to ``params`` per aria2's token
        authentication scheme.

        Raises:
            Aria2Error: On unconfigured endpoint, non-JSON body, RPC-level
                error payload, or unexpected HTTP status.
        """
        if not self._rpc_url:
            raise Aria2Error("aria2 RPC endpoint is not initialized")

        session = await self._get_rpc_session()
        payload = {
            "jsonrpc": "2.0",
            # Random id: aria2 echoes it back; uniqueness aids debugging.
            "id": secrets.token_hex(8),
            "method": method,
            # aria2 expects the secret as the first positional parameter.
            "params": [f"token:{self._rpc_secret}", *params],
        }

        async with session.post(self._rpc_url, json=payload) as response:
            text = await response.text()

            # Parse manually so a non-JSON body produces a readable error
            # instead of an aiohttp content-type exception.
            try:
                body = json.loads(text)
            except json.JSONDecodeError:
                body = None

            if body is None:
                if response.status != 200:
                    raise Aria2Error(
                        f"aria2 RPC returned status {response.status} with non-JSON body: {text}"
                    )
                raise Aria2Error(f"Invalid aria2 RPC response: {text}")

            # JSON-RPC error object takes precedence over the HTTP status.
            if "error" in body:
                error = body["error"] or {}
                code = error.get("code") if isinstance(error, dict) else None
                message = error.get("message") if isinstance(error, dict) else str(error)
                logger.error(
                    "aria2 RPC %s failed with HTTP %s, code=%s, message=%s",
                    method,
                    response.status,
                    code,
                    message,
                )
                status_message = (
                    f"aria2 RPC {method} failed with status {response.status}: {message}"
                    if response.status != 200
                    else message
                )
                raise Aria2Error(status_message or "Unknown aria2 RPC error")

            # Valid JSON, no error object, but still a non-200 status.
            if response.status != 200:
                logger.error(
                    "aria2 RPC %s returned unexpected HTTP status %s without error payload: %s",
                    method,
                    response.status,
                    body,
                )
                raise Aria2Error(
                    f"aria2 RPC {method} returned unexpected status {response.status}"
                )

            return body.get("result")
|
||||
|
||||
async def _get_rpc_session(self) -> aiohttp.ClientSession:
|
||||
if self._rpc_session is None or self._rpc_session.closed:
|
||||
async with self._rpc_session_lock:
|
||||
if self._rpc_session is None or self._rpc_session.closed:
|
||||
timeout = aiohttp.ClientTimeout(total=30)
|
||||
self._rpc_session = aiohttp.ClientSession(timeout=timeout)
|
||||
return self._rpc_session
|
||||
|
||||
@staticmethod
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
sock.listen(1)
|
||||
return int(sock.getsockname()[1])
|
||||
|
||||
|
||||
async def get_aria2_downloader() -> Aria2Downloader:
    """Get the singleton aria2 downloader."""
    instance = await Aria2Downloader.get_instance()
    return instance
|
||||
108
py/services/aria2_transfer_state.py
Normal file
108
py/services/aria2_transfer_state.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ..utils.cache_paths import get_cache_base_dir
|
||||
|
||||
|
||||
def get_aria2_state_path() -> str:
    """Return the JSON file path used to persist aria2 transfer state."""
    state_dir = os.path.join(get_cache_base_dir(create=True), "aria2")
    os.makedirs(state_dir, exist_ok=True)
    return os.path.join(state_dir, "downloads.json")
|
||||
|
||||
|
||||
class Aria2TransferStateStore:
    """Persist aria2 transfer metadata needed for restart recovery.

    Entries are keyed by download id and stored as a single JSON file.
    All mutating operations serialize on a per-file asyncio lock and write
    atomically (temp file + rename), so corrupt partial writes cannot occur.
    """

    # One lock per state-file path, shared by every store instance that
    # targets the same file, so concurrent coroutines serialize their I/O.
    _locks_by_path: Dict[str, asyncio.Lock] = {}

    def __init__(self, state_path: Optional[str] = None) -> None:
        self._state_path = os.path.abspath(state_path or get_aria2_state_path())
        self._lock = self._locks_by_path.setdefault(self._state_path, asyncio.Lock())

    def _read_all_unlocked(self) -> Dict[str, Dict[str, Any]]:
        """Read and sanitize the state file. Caller must hold ``self._lock``.

        Missing or corrupt files degrade to an empty mapping.
        """
        try:
            with open(self._state_path, "r", encoding="utf-8") as handle:
                data = json.load(handle)
        except FileNotFoundError:
            return {}
        except json.JSONDecodeError:
            return {}

        if not isinstance(data, dict):
            return {}

        # Keep only well-formed entries: str download_id -> dict payload.
        normalized: Dict[str, Dict[str, Any]] = {}
        for download_id, entry in data.items():
            if isinstance(download_id, str) and isinstance(entry, dict):
                normalized[download_id] = entry
        return normalized

    def _write_all_unlocked(self, data: Dict[str, Dict[str, Any]]) -> None:
        """Atomically persist *data*. Caller must hold ``self._lock``."""
        directory = os.path.dirname(self._state_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Write to a sibling temp file, then rename over the target so
        # readers never observe a half-written file.
        temp_path = f"{self._state_path}.tmp"
        with open(temp_path, "w", encoding="utf-8") as handle:
            json.dump(data, handle, ensure_ascii=True, indent=2, sort_keys=True)
        os.replace(temp_path, self._state_path)

    async def load_all(self) -> Dict[str, Dict[str, Any]]:
        """Return a deep copy of every persisted entry keyed by download id."""
        async with self._lock:
            return deepcopy(self._read_all_unlocked())

    async def get(self, download_id: str) -> Optional[Dict[str, Any]]:
        """Return a deep copy of one entry, or None when unknown."""
        async with self._lock:
            return deepcopy(self._read_all_unlocked().get(download_id))

    async def upsert(self, download_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Merge *payload* into the stored entry and return the merged result."""
        async with self._lock:
            data = self._read_all_unlocked()
            current = data.get(download_id, {})
            current.update(payload)
            data[download_id] = current
            self._write_all_unlocked(data)
            return deepcopy(current)

    async def remove(self, download_id: str) -> None:
        """Delete an entry; a no-op when the id is not present."""
        async with self._lock:
            data = self._read_all_unlocked()
            if download_id in data:
                del data[download_id]
                self._write_all_unlocked(data)

    async def find_by_save_path(
        self, save_path: str, *, exclude_download_id: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Return the first entry whose save_path matches *save_path*.

        Paths are compared after ``os.path.abspath`` normalization. The
        returned dict additionally carries the matching ``download_id``.
        """
        normalized_target = os.path.abspath(save_path)
        async with self._lock:
            data = self._read_all_unlocked()
            for download_id, entry in data.items():
                if exclude_download_id and download_id == exclude_download_id:
                    continue
                candidate = entry.get("save_path")
                if isinstance(candidate, str) and os.path.abspath(candidate) == normalized_target:
                    result = dict(entry)
                    result["download_id"] = download_id
                    return result
            return None

    async def reassign(self, from_download_id: str, to_download_id: str) -> Optional[Dict[str, Any]]:
        """Move an entry to a new download id, returning the updated entry.

        Returns None when the source id does not exist. Reassigning an id to
        itself simply refreshes the stored ``download_id`` field.
        """
        async with self._lock:
            data = self._read_all_unlocked()
            existing = data.get(from_download_id)
            if existing is None:
                return None
            updated = dict(existing)
            updated["download_id"] = to_download_id
            data[to_download_id] = updated
            if from_download_id != to_download_id:
                data.pop(from_download_id, None)
            self._write_all_unlocked(data)
            return deepcopy(updated)
|
||||
@@ -3,6 +3,11 @@ import copy
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional, Dict, Tuple, List, Sequence
|
||||
from .connectivity_guard import (
|
||||
OFFLINE_FRIENDLY_MESSAGE,
|
||||
is_expected_offline_error,
|
||||
is_offline_cooldown_error,
|
||||
)
|
||||
from .model_metadata_provider import (
|
||||
CivitaiModelMetadataProvider,
|
||||
ModelMetadataProviderManager,
|
||||
@@ -65,6 +70,8 @@ class CivitaiClient:
|
||||
if result.provider is None:
|
||||
result.provider = "civitai_api"
|
||||
raise result
|
||||
if not success and is_offline_cooldown_error(result):
|
||||
return False, OFFLINE_FRIENDLY_MESSAGE
|
||||
return success, result
|
||||
|
||||
@staticmethod
|
||||
@@ -124,6 +131,8 @@ class CivitaiClient:
|
||||
)
|
||||
if not success:
|
||||
message = str(version)
|
||||
if is_expected_offline_error(message):
|
||||
return None, OFFLINE_FRIENDLY_MESSAGE
|
||||
if "not found" in message.lower():
|
||||
return None, "Model not found"
|
||||
|
||||
@@ -164,6 +173,9 @@ class CivitaiClient:
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
if is_expected_offline_error(str(e)):
|
||||
logger.debug("Preview download skipped due to offline state.")
|
||||
return False
|
||||
logger.error(f"Download Error: {str(e)}")
|
||||
return False
|
||||
|
||||
@@ -207,6 +219,9 @@ class CivitaiClient:
|
||||
message = self._extract_error_message(result)
|
||||
if message and "not found" in message.lower():
|
||||
raise ResourceNotFoundError(f"Resource not found for model {model_id}")
|
||||
if is_expected_offline_error(message):
|
||||
logger.info("Civitai request skipped: %s", OFFLINE_FRIENDLY_MESSAGE)
|
||||
return None
|
||||
if message:
|
||||
raise RuntimeError(message)
|
||||
return None
|
||||
@@ -357,6 +372,8 @@ class CivitaiClient:
|
||||
)
|
||||
if success:
|
||||
return data
|
||||
if is_expected_offline_error(data):
|
||||
return None
|
||||
logger.warning(f"Failed to fetch model data for model {model_id}")
|
||||
return None
|
||||
|
||||
@@ -371,6 +388,8 @@ class CivitaiClient:
|
||||
)
|
||||
if success:
|
||||
return version
|
||||
if is_expected_offline_error(version):
|
||||
return None
|
||||
|
||||
logger.warning(f"Failed to fetch version by id {version_id}")
|
||||
return None
|
||||
@@ -386,6 +405,8 @@ class CivitaiClient:
|
||||
)
|
||||
if success:
|
||||
return version
|
||||
if is_expected_offline_error(version):
|
||||
return None
|
||||
|
||||
logger.warning(f"Failed to fetch version by hash {model_hash}")
|
||||
return None
|
||||
@@ -473,6 +494,8 @@ class CivitaiClient:
|
||||
return result, None
|
||||
|
||||
# Handle specific error cases
|
||||
if is_expected_offline_error(result):
|
||||
return None, OFFLINE_FRIENDLY_MESSAGE
|
||||
if "not found" in str(result):
|
||||
error_msg = f"Model not found"
|
||||
logger.warning(f"Model version not found: {version_id} - {error_msg}")
|
||||
@@ -507,6 +530,8 @@ class CivitaiClient:
|
||||
success, result = await self._make_request("GET", url, use_auth=True)
|
||||
|
||||
if not success:
|
||||
if is_expected_offline_error(result):
|
||||
return None
|
||||
logger.error(
|
||||
"Failed to fetch image info for ID %s from civitai.red: %s",
|
||||
image_id,
|
||||
@@ -566,6 +591,9 @@ class CivitaiClient:
|
||||
)
|
||||
|
||||
if not success:
|
||||
if is_expected_offline_error(result):
|
||||
logger.info("User model fetch skipped: %s", OFFLINE_FRIENDLY_MESSAGE)
|
||||
return None
|
||||
logger.error("Failed to fetch models for %s: %s", username, result)
|
||||
return None
|
||||
|
||||
|
||||
204
py/services/connectivity_guard.py
Normal file
204
py/services/connectivity_guard.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""In-memory connectivity guard to suppress repeated network retries when offline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import errno
|
||||
import logging
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OFFLINE_COOLDOWN_ERROR = "offline_cooldown"
|
||||
OFFLINE_FRIENDLY_MESSAGE = "Network offline, will retry automatically later"
|
||||
|
||||
|
||||
def is_offline_cooldown_error(value: Any) -> bool:
    """Return True when *value* is the guard's cooldown sentinel string."""
    if not isinstance(value, str):
        return False
    return value == OFFLINE_COOLDOWN_ERROR
|
||||
|
||||
|
||||
def is_expected_offline_error(value: Any) -> bool:
    """Return True when payload is an expected offline-related result."""
    if is_offline_cooldown_error(value):
        return True
    if not isinstance(value, str):
        return False
    lowered = value.lower()
    return any(marker in lowered for marker in ("network offline", "offline"))
|
||||
|
||||
|
||||
class ConnectivityGuard:
    """Tracks network failures and gates outbound requests during cooldown.

    Failures are tracked per destination (a normalized host/label string);
    the module-level ``__global__`` destination backs the legacy property
    accessors. After ``failure_threshold`` consecutive failures a destination
    enters an exponential-backoff cooldown during which requests should be
    skipped.
    """

    # Process-wide singleton plus the lock guarding its lazy creation.
    _instance: "ConnectivityGuard | None" = None
    _instance_lock = asyncio.Lock()

    @classmethod
    async def get_instance(cls) -> "ConnectivityGuard":
        """Return the shared guard, creating it on first use."""
        async with cls._instance_lock:
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance

    def __init__(self) -> None:
        # Guard against repeated initialization of a reused instance.
        if hasattr(self, "_initialized"):
            return
        self._initialized = True
        self._default_destination = "__global__"
        self._destination_states: dict[str, _DestinationState] = {
            self._default_destination: _DestinationState()
        }
        # Cooldown grows from base_backoff_seconds, doubling per extra
        # failure, capped at max_backoff_seconds.
        self.base_backoff_seconds = 30
        self.max_backoff_seconds = 300
        self.failure_threshold = 3

    # The following properties proxy the default (global) destination state
    # so existing callers can keep using guard.online / failure_count / etc.
    @property
    def online(self) -> bool:
        return self._state_for_destination(None).online

    @online.setter
    def online(self, value: bool) -> None:
        self._state_for_destination(None).online = value

    @property
    def failure_count(self) -> int:
        return self._state_for_destination(None).failure_count

    @failure_count.setter
    def failure_count(self, value: int) -> None:
        self._state_for_destination(None).failure_count = value

    @property
    def cooldown_until(self) -> datetime | None:
        return self._state_for_destination(None).cooldown_until

    @cooldown_until.setter
    def cooldown_until(self, value: datetime | None) -> None:
        self._state_for_destination(None).cooldown_until = value

    def _now(self) -> datetime:
        # Isolated so tests can patch the clock.
        return datetime.now()

    def _normalize_destination(self, destination: str | None) -> str:
        """Map None/blank to the global key; otherwise lowercase and strip."""
        if destination is None or not destination.strip():
            return self._default_destination
        return destination.lower().strip()

    def _state_for_destination(self, destination: str | None) -> "_DestinationState":
        """Return (creating if needed) the state record for a destination."""
        destination_key = self._normalize_destination(destination)
        if destination_key not in self._destination_states:
            self._destination_states[destination_key] = _DestinationState()
        return self._destination_states[destination_key]

    def in_cooldown(self, destination: str | None = None) -> bool:
        """Return True while the destination's cooldown deadline is ahead."""
        state = self._state_for_destination(destination)
        if state.cooldown_until is None:
            return False
        return self._now() < state.cooldown_until

    def cooldown_remaining_seconds(self, destination: str | None = None) -> float:
        """Seconds left in cooldown; 0.0 when none is active."""
        state = self._state_for_destination(destination)
        if state.cooldown_until is None:
            return 0.0
        return max(0.0, (state.cooldown_until - self._now()).total_seconds())

    def should_block_request(self, destination: str | None = None) -> bool:
        """Return True when an outbound request should be skipped."""
        return self.in_cooldown(destination)

    def register_success(self, destination: str | None = None) -> None:
        """Record a successful request: reset failures and clear cooldown."""
        destination_key = self._normalize_destination(destination)
        state = self._state_for_destination(destination_key)
        was_offline = (not state.online) or state.cooldown_until is not None
        state.online = True
        state.failure_count = 0
        state.cooldown_until = None
        # Only announce recovery when we actually transitioned from offline.
        if was_offline:
            logger.info(
                "Connectivity restored for destination '%s'; requests resumed.",
                destination_key,
            )

    def register_network_failure(
        self, exc: Exception, destination: str | None = None
    ) -> None:
        """Record a network failure and start/extend cooldown past threshold."""
        destination_key = self._normalize_destination(destination)
        state = self._state_for_destination(destination_key)
        state.online = False
        state.failure_count += 1

        # Below the threshold: only track, do not start a cooldown yet.
        if state.failure_count < self.failure_threshold:
            logger.debug(
                "Network failure tracked for destination '%s' (%d/%d): %s",
                destination_key,
                state.failure_count,
                self.failure_threshold,
                exc,
            )
            return

        # Exponential backoff relative to how far past the threshold we are.
        retry_step = state.failure_count - self.failure_threshold
        backoff = min(
            self.max_backoff_seconds,
            self.base_backoff_seconds * (2**retry_step),
        )
        # Capture whether this is a new cooldown before extending it, so we
        # warn once and then only log at debug level.
        should_log_warning = not self.in_cooldown(destination_key)
        state.cooldown_until = self._now() + timedelta(seconds=backoff)

        if should_log_warning:
            logger.warning(
                "Connectivity offline for destination '%s'; enter cooldown for %ss after %d network failures.",
                destination_key,
                int(backoff),
                state.failure_count,
            )
        else:
            logger.debug(
                "Cooldown still active for destination '%s'; failure_count=%d, backoff=%ss.",
                destination_key,
                state.failure_count,
                int(backoff),
            )

    @staticmethod
    def is_network_unreachable_error(exc: Exception) -> bool:
        """Return whether the exception should count as connectivity failure."""
        # Cancellation is control flow, never a connectivity signal.
        if isinstance(exc, asyncio.CancelledError):
            return False

        # Timeouts, refused connections, DNS failures, and aiohttp
        # transport-level errors all indicate the network (not the payload).
        if isinstance(
            exc,
            (
                asyncio.TimeoutError,
                TimeoutError,
                ConnectionRefusedError,
                socket.gaierror,
                aiohttp.ServerTimeoutError,
                aiohttp.ConnectionTimeoutError,
                aiohttp.ClientConnectorError,
                aiohttp.ClientConnectionError,
            ),
        ):
            return True

        # Raw OS errors for unreachable hosts/networks and timeouts.
        if isinstance(exc, OSError) and exc.errno in {
            errno.ENETUNREACH,
            errno.EHOSTUNREACH,
            errno.ETIMEDOUT,
            errno.ECONNREFUSED,
        }:
            return True

        return False
|
||||
|
||||
|
||||
@dataclass
class _DestinationState:
    """Per-destination connectivity bookkeeping used by ConnectivityGuard."""

    # Whether the most recent request to this destination succeeded.
    online: bool = True
    # Consecutive network failures since the last success.
    failure_count: int = 0
    # Requests are blocked until this instant; None means no active cooldown.
    cooldown_until: datetime | None = None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -18,8 +18,14 @@ from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from email.utils import parsedate_to_datetime
|
||||
from urllib.parse import urlparse
|
||||
from typing import Optional, Dict, Tuple, Callable, Union, Awaitable
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from .connectivity_guard import (
|
||||
OFFLINE_COOLDOWN_ERROR,
|
||||
OFFLINE_FRIENDLY_MESSAGE,
|
||||
ConnectivityGuard,
|
||||
)
|
||||
from .errors import RateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -138,7 +144,7 @@ class Downloader:
|
||||
self.chunk_size = (
|
||||
16 * 1024 * 1024
|
||||
) # 16MB chunks to balance I/O reduction and memory usage
|
||||
self.max_retries = 5
|
||||
self.max_retries = self._resolve_max_retries()
|
||||
self.base_delay = 2.0 # Base delay for exponential backoff
|
||||
self.session_timeout = 300 # 5 minutes
|
||||
self.stall_timeout = self._resolve_stall_timeout()
|
||||
@@ -192,6 +198,18 @@ class Downloader:
|
||||
|
||||
return max(30.0, timeout_value)
|
||||
|
||||
def _resolve_max_retries(self) -> int:
|
||||
"""Determine max retry count from environment while preserving defaults."""
|
||||
default_retries = 5
|
||||
raw_value = os.environ.get("COMFYUI_DOWNLOAD_MAX_RETRIES")
|
||||
|
||||
try:
|
||||
retries = int(raw_value)
|
||||
except (TypeError, ValueError):
|
||||
retries = default_retries
|
||||
|
||||
return max(0, retries)
|
||||
|
||||
def _should_refresh_session(self) -> bool:
|
||||
"""Check if session should be refreshed"""
|
||||
if self._session is None:
|
||||
@@ -334,6 +352,7 @@ class Downloader:
|
||||
logger.info(f"Resuming download from offset {resume_offset} bytes")
|
||||
|
||||
total_size = 0
|
||||
range_redirect_retry_urls: set[str] = set()
|
||||
|
||||
while retry_count <= self.max_retries:
|
||||
try:
|
||||
@@ -372,6 +391,23 @@ class Downloader:
|
||||
if response.status == 200:
|
||||
# Full content response
|
||||
if resume_offset > 0:
|
||||
redirected_url = str(response.url)
|
||||
if (
|
||||
allow_resume
|
||||
and response.history
|
||||
and redirected_url
|
||||
and redirected_url != url
|
||||
and redirected_url not in range_redirect_retry_urls
|
||||
):
|
||||
range_redirect_retry_urls.add(redirected_url)
|
||||
logger.info(
|
||||
"Range request was not honored after redirect; retrying final URL directly: %s",
|
||||
redirected_url,
|
||||
)
|
||||
url = redirected_url
|
||||
response.release()
|
||||
continue
|
||||
|
||||
# Server doesn't support ranges, restart from beginning
|
||||
logger.warning(
|
||||
"Server doesn't support range requests, restarting download"
|
||||
@@ -571,37 +607,53 @@ class Downloader:
|
||||
expected_size = total_size if total_size > 0 else None
|
||||
|
||||
integrity_error: Optional[str] = None
|
||||
resumable_incomplete = False
|
||||
if final_size <= 0:
|
||||
integrity_error = "Downloaded file is empty"
|
||||
elif expected_size is not None and final_size != expected_size:
|
||||
integrity_error = f"File size mismatch. Expected: {expected_size}, Got: {final_size}"
|
||||
resumable_incomplete = (
|
||||
allow_resume
|
||||
and part_path != save_path
|
||||
and final_size > 0
|
||||
and final_size < expected_size
|
||||
)
|
||||
|
||||
if integrity_error is not None:
|
||||
logger.error(
|
||||
log_fn = logger.warning if resumable_incomplete else logger.error
|
||||
log_fn(
|
||||
"Download integrity check failed for %s: %s",
|
||||
save_path,
|
||||
integrity_error,
|
||||
)
|
||||
|
||||
# Remove the corrupted payload so future attempts start fresh
|
||||
if os.path.exists(part_path):
|
||||
try:
|
||||
os.remove(part_path)
|
||||
except OSError as remove_error:
|
||||
logger.warning(
|
||||
"Failed to delete corrupted download %s: %s",
|
||||
part_path,
|
||||
remove_error,
|
||||
)
|
||||
if part_path != save_path and os.path.exists(save_path):
|
||||
try:
|
||||
os.remove(save_path)
|
||||
except OSError as remove_error:
|
||||
logger.warning(
|
||||
"Failed to delete target file %s after integrity error: %s",
|
||||
save_path,
|
||||
remove_error,
|
||||
)
|
||||
if resumable_incomplete:
|
||||
logger.info(
|
||||
"Preserving incomplete download for resume: %s (%s/%s bytes)",
|
||||
part_path,
|
||||
final_size,
|
||||
expected_size,
|
||||
)
|
||||
else:
|
||||
# Remove corrupted payloads that cannot be safely resumed.
|
||||
if os.path.exists(part_path):
|
||||
try:
|
||||
os.remove(part_path)
|
||||
except OSError as remove_error:
|
||||
logger.warning(
|
||||
"Failed to delete corrupted download %s: %s",
|
||||
part_path,
|
||||
remove_error,
|
||||
)
|
||||
if part_path != save_path and os.path.exists(save_path):
|
||||
try:
|
||||
os.remove(save_path)
|
||||
except OSError as remove_error:
|
||||
logger.warning(
|
||||
"Failed to delete target file %s after integrity error: %s",
|
||||
save_path,
|
||||
remove_error,
|
||||
)
|
||||
|
||||
retry_count += 1
|
||||
if retry_count <= self.max_retries:
|
||||
@@ -611,8 +663,16 @@ class Downloader:
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
resume_offset = 0
|
||||
total_size = 0
|
||||
if resumable_incomplete and os.path.exists(part_path):
|
||||
resume_offset = os.path.getsize(part_path)
|
||||
total_size = expected_size or 0
|
||||
logger.info(
|
||||
"Will resume incomplete download from byte %s",
|
||||
resume_offset,
|
||||
)
|
||||
else:
|
||||
resume_offset = 0
|
||||
total_size = 0
|
||||
await self._create_session()
|
||||
continue
|
||||
|
||||
@@ -743,6 +803,11 @@ class Downloader:
|
||||
Returns:
|
||||
Tuple[bool, Union[bytes, str], Optional[Dict]]: (success, content or error message, response headers if requested)
|
||||
"""
|
||||
guard = await ConnectivityGuard.get_instance()
|
||||
destination = self._guard_destination(url)
|
||||
if guard.should_block_request(destination):
|
||||
return False, OFFLINE_FRIENDLY_MESSAGE, None
|
||||
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
@@ -765,6 +830,7 @@ class Downloader:
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
guard.register_success(destination)
|
||||
if return_headers:
|
||||
return True, content, dict(response.headers)
|
||||
else:
|
||||
@@ -783,6 +849,12 @@ class Downloader:
|
||||
return False, error_msg, None
|
||||
|
||||
except Exception as e:
|
||||
if guard.is_network_unreachable_error(e):
|
||||
guard.register_network_failure(e, destination)
|
||||
if guard.should_block_request(destination):
|
||||
return False, OFFLINE_FRIENDLY_MESSAGE, None
|
||||
logger.debug("Network unavailable during memory download: %s", e)
|
||||
return False, str(e), None
|
||||
logger.error(f"Error downloading to memory from {url}: {e}")
|
||||
return False, str(e), None
|
||||
|
||||
@@ -803,6 +875,11 @@ class Downloader:
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict, str]]: (success, headers dict or error message)
|
||||
"""
|
||||
guard = await ConnectivityGuard.get_instance()
|
||||
destination = self._guard_destination(url)
|
||||
if guard.should_block_request(destination):
|
||||
return False, OFFLINE_COOLDOWN_ERROR
|
||||
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
@@ -824,11 +901,18 @@ class Downloader:
|
||||
url, headers=headers, proxy=self.proxy_url
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
guard.register_success(destination)
|
||||
return True, dict(response.headers)
|
||||
else:
|
||||
return False, f"Head request failed with status {response.status}"
|
||||
|
||||
except Exception as e:
|
||||
if guard.is_network_unreachable_error(e):
|
||||
guard.register_network_failure(e, destination)
|
||||
if guard.should_block_request(destination):
|
||||
return False, OFFLINE_COOLDOWN_ERROR
|
||||
logger.debug("Network unavailable during header probe: %s", e)
|
||||
return False, str(e)
|
||||
logger.error(f"Error getting headers from {url}: {e}")
|
||||
return False, str(e)
|
||||
|
||||
@@ -853,6 +937,11 @@ class Downloader:
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict, str]]: (success, response data or error message)
|
||||
"""
|
||||
guard = await ConnectivityGuard.get_instance()
|
||||
destination = self._guard_destination(url)
|
||||
if guard.should_block_request(destination):
|
||||
return False, OFFLINE_COOLDOWN_ERROR
|
||||
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
@@ -876,6 +965,7 @@ class Downloader:
|
||||
method, url, headers=headers, **kwargs
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
guard.register_success(destination)
|
||||
# Try to parse as JSON, fall back to text
|
||||
try:
|
||||
data = await response.json()
|
||||
@@ -906,6 +996,12 @@ class Downloader:
|
||||
return False, f"Request failed with status {response.status}"
|
||||
|
||||
except Exception as e:
|
||||
if guard.is_network_unreachable_error(e):
|
||||
guard.register_network_failure(e, destination)
|
||||
if guard.should_block_request(destination):
|
||||
return False, OFFLINE_COOLDOWN_ERROR
|
||||
logger.debug("Network unavailable for %s %s: %s", method, url, e)
|
||||
return False, str(e)
|
||||
logger.error(f"Error making {method} request to {url}: {e}")
|
||||
return False, str(e)
|
||||
|
||||
@@ -956,6 +1052,14 @@ class Downloader:
|
||||
delta = retry_datetime - datetime.now(tz=retry_datetime.tzinfo)
|
||||
return max(0.0, delta.total_seconds())
|
||||
|
||||
@staticmethod
|
||||
def _guard_destination(url: str) -> str:
|
||||
"""Build per-destination connectivity guard scope from request URL."""
|
||||
parsed_url = urlparse(url)
|
||||
if parsed_url.hostname:
|
||||
return parsed_url.hostname.lower()
|
||||
return "unknown"
|
||||
|
||||
|
||||
# Global instance accessor
|
||||
async def get_downloader() -> Downloader:
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
|
||||
from ..services.settings_manager import SettingsManager
|
||||
from ..utils.civitai_utils import resolve_license_payload
|
||||
from ..utils.model_utils import determine_base_model
|
||||
from .connectivity_guard import OFFLINE_FRIENDLY_MESSAGE, is_expected_offline_error
|
||||
from .errors import RateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -274,11 +275,18 @@ class MetadataSyncService:
|
||||
else "No provider returned metadata"
|
||||
)
|
||||
|
||||
resolved_error = last_error or default_error
|
||||
if is_expected_offline_error(resolved_error):
|
||||
resolved_error = OFFLINE_FRIENDLY_MESSAGE
|
||||
|
||||
error_msg = (
|
||||
f"Error fetching metadata: {last_error or default_error} "
|
||||
f"Error fetching metadata: {resolved_error} "
|
||||
f"(model_name={model_data.get('model_name', '')})"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
if is_expected_offline_error(resolved_error):
|
||||
logger.info(error_msg)
|
||||
else:
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
model_data["from_civitai"] = True
|
||||
@@ -347,6 +355,9 @@ class MetadataSyncService:
|
||||
return False, error_msg
|
||||
except Exception as exc: # pragma: no cover - error path
|
||||
error_msg = f"Error fetching metadata: {exc}"
|
||||
if is_expected_offline_error(str(exc)):
|
||||
logger.info(OFFLINE_FRIENDLY_MESSAGE)
|
||||
return False, OFFLINE_FRIENDLY_MESSAGE
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return False, error_msg
|
||||
|
||||
|
||||
@@ -1535,7 +1535,7 @@ class ModelScanner:
|
||||
return sorted_tags[:limit]
|
||||
|
||||
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||
"""Get base models sorted by frequency"""
|
||||
"""Get base models sorted by count. If limit is 0, return all."""
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
base_model_counts = {}
|
||||
@@ -1546,7 +1546,9 @@ class ModelScanner:
|
||||
|
||||
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
|
||||
sorted_models.sort(key=lambda x: x['count'], reverse=True)
|
||||
|
||||
|
||||
if limit == 0:
|
||||
return sorted_models
|
||||
return sorted_models[:limit]
|
||||
|
||||
async def get_model_info_by_name(self, name):
|
||||
|
||||
@@ -55,6 +55,8 @@ DEFAULT_KEYS_CLEANUP_THRESHOLD = 10
|
||||
DEFAULT_SETTINGS: Dict[str, Any] = {
|
||||
"civitai_api_key": "",
|
||||
"civitai_host": "civitai.com",
|
||||
"download_backend": "python",
|
||||
"aria2c_path": "",
|
||||
"use_portable_settings": False,
|
||||
"hash_chunk_size_mb": DEFAULT_HASH_CHUNK_SIZE_MB,
|
||||
"language": "en",
|
||||
@@ -761,34 +763,29 @@ class SettingsManager:
|
||||
if self._preserve_disk_template:
|
||||
return
|
||||
|
||||
folder_paths = self.settings.get("folder_paths", {})
|
||||
updated = False
|
||||
|
||||
def _check_and_auto_set(key: str, setting_key: str) -> bool:
|
||||
"""Repair default roots when empty or no longer present."""
|
||||
current = self.settings.get(setting_key, "")
|
||||
candidates = folder_paths.get(key, [])
|
||||
if not isinstance(candidates, list) or not candidates:
|
||||
primary_candidates = self._get_valid_root_candidates(key)
|
||||
if not primary_candidates:
|
||||
return False
|
||||
|
||||
# Filter valid path strings
|
||||
valid_paths = [p for p in candidates if isinstance(p, str) and p.strip()]
|
||||
if not valid_paths:
|
||||
allowed_roots = self._get_allowed_roots(key)
|
||||
if current and current in allowed_roots:
|
||||
return False
|
||||
|
||||
if current in valid_paths:
|
||||
return False
|
||||
|
||||
self.settings[setting_key] = valid_paths[0]
|
||||
self.settings[setting_key] = primary_candidates[0]
|
||||
if current:
|
||||
logger.info(
|
||||
"Repaired stale %s from '%s' to '%s'",
|
||||
"Repaired stale %s from '%s' to '%s' because it is not present in primary or extra roots",
|
||||
setting_key,
|
||||
current,
|
||||
valid_paths[0],
|
||||
primary_candidates[0],
|
||||
)
|
||||
else:
|
||||
logger.info("Auto-set %s to '%s'", setting_key, valid_paths[0])
|
||||
logger.info("Auto-set %s to '%s'", setting_key, primary_candidates[0])
|
||||
return True
|
||||
|
||||
# Process all model types
|
||||
@@ -811,6 +808,33 @@ class SettingsManager:
|
||||
else:
|
||||
self._save_settings()
|
||||
|
||||
def _get_valid_root_candidates(self, key: str) -> List[str]:
|
||||
"""Return stable root candidates, preferring primary roots over extra roots."""
|
||||
|
||||
candidates: List[str] = []
|
||||
seen: set[str] = set()
|
||||
for mapping_key in ("folder_paths", "extra_folder_paths"):
|
||||
raw_paths = self.settings.get(mapping_key, {})
|
||||
if not isinstance(raw_paths, Mapping):
|
||||
continue
|
||||
values = raw_paths.get(key, [])
|
||||
if not isinstance(values, list):
|
||||
continue
|
||||
for value in values:
|
||||
if not isinstance(value, str):
|
||||
continue
|
||||
normalized = value.strip()
|
||||
if not normalized or normalized in seen:
|
||||
continue
|
||||
seen.add(normalized)
|
||||
candidates.append(normalized)
|
||||
return candidates
|
||||
|
||||
def _get_allowed_roots(self, key: str) -> set[str]:
|
||||
"""Return all valid roots for a model type, including extra roots."""
|
||||
|
||||
return set(self._get_valid_root_candidates(key))
|
||||
|
||||
def _check_environment_variables(self) -> None:
|
||||
"""Check for environment variables and update settings if needed"""
|
||||
env_api_key = os.environ.get("CIVITAI_API_KEY")
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
height: 100%;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
/* Responsive header container for larger screens */
|
||||
@@ -65,7 +66,6 @@
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
flex-shrink: 0;
|
||||
margin-right: 1rem;
|
||||
}
|
||||
|
||||
.nav-item {
|
||||
@@ -101,7 +101,6 @@
|
||||
.header-search {
|
||||
flex: 1;
|
||||
max-width: 400px;
|
||||
margin: 0 1rem;
|
||||
transition: opacity 0.2s ease;
|
||||
}
|
||||
|
||||
@@ -288,4 +287,4 @@
|
||||
.header-search {
|
||||
flex: 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -346,11 +346,13 @@
|
||||
.api-key-input input {
|
||||
width: 100%;
|
||||
padding: 6px 40px 6px 10px; /* Add left padding */
|
||||
height: 20px;
|
||||
height: 32px;
|
||||
box-sizing: border-box;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
color: var(--text-color);
|
||||
font-size: 0.95em;
|
||||
}
|
||||
|
||||
.api-key-input .toggle-visibility {
|
||||
@@ -379,7 +381,8 @@
|
||||
.text-input-wrapper input {
|
||||
width: 100%;
|
||||
padding: 6px 10px;
|
||||
height: 20px;
|
||||
height: 32px;
|
||||
box-sizing: border-box;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
@@ -760,10 +763,12 @@
|
||||
}
|
||||
|
||||
.setting-control {
|
||||
width: 60%; /* Decreased slightly from 65% */
|
||||
flex: 0 0 60%;
|
||||
max-width: 60%;
|
||||
margin-bottom: 0;
|
||||
display: flex;
|
||||
justify-content: flex-end; /* Right-align all controls */
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
/* Select Control Styles */
|
||||
@@ -773,6 +778,13 @@
|
||||
justify-content: flex-end;
|
||||
}
|
||||
|
||||
.setting-control select,
|
||||
.setting-control input[type="text"],
|
||||
.setting-control input[type="password"],
|
||||
.setting-control input[type="number"] {
|
||||
font-size: 0.95em;
|
||||
}
|
||||
|
||||
.select-control select {
|
||||
width: 100%;
|
||||
max-width: 100%; /* Increased from 200px */
|
||||
@@ -781,8 +793,8 @@
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
color: var(--text-color);
|
||||
font-size: 0.95em;
|
||||
height: 32px;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
/* Fix dark theme select dropdown text color */
|
||||
@@ -888,8 +900,8 @@ input:checked + .toggle-slider:before {
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
color: var(--text-color);
|
||||
font-size: 0.95em;
|
||||
height: 32px;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
/* Add warning text style for settings */
|
||||
|
||||
@@ -145,7 +145,7 @@
|
||||
position: fixed;
|
||||
right: 20px;
|
||||
top: 50px; /* Position below header */
|
||||
width: 320px;
|
||||
width: 366px;
|
||||
background-color: var(--card-bg);
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-base);
|
||||
@@ -197,6 +197,31 @@
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.filter-search-input {
|
||||
width: 100%;
|
||||
box-sizing: border-box;
|
||||
margin-bottom: 8px;
|
||||
padding: 8px 10px;
|
||||
border-radius: var(--border-radius-sm);
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
color: var(--text-color);
|
||||
font-size: 13px;
|
||||
}
|
||||
|
||||
.filter-search-input:focus {
|
||||
outline: none;
|
||||
border-color: var(--lora-accent);
|
||||
box-shadow: 0 0 0 2px rgba(var(--lora-accent-rgb, 76, 175, 80), 0.15);
|
||||
}
|
||||
|
||||
.filter-empty-state {
|
||||
margin-top: 8px;
|
||||
font-size: 13px;
|
||||
color: var(--text-color);
|
||||
opacity: 0.7;
|
||||
}
|
||||
|
||||
.filter-section h4 {
|
||||
margin: 0 0 8px 0;
|
||||
font-size: 14px;
|
||||
@@ -733,4 +758,4 @@
|
||||
right: 20px;
|
||||
top: 160px; /* Adjusted for mobile layout */
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,9 +240,7 @@ export class BulkManager {
|
||||
*/
|
||||
handleGlobalKeyboard(e) {
|
||||
// Skip if modal is open (handled by event manager conditions)
|
||||
// Skip if search input is focused
|
||||
const searchInput = document.getElementById('searchInput');
|
||||
if (searchInput && document.activeElement === searchInput) {
|
||||
if (this.isEditingTextInputContext(e.target)) {
|
||||
return false; // Don't handle, allow default behavior
|
||||
}
|
||||
|
||||
@@ -266,6 +264,26 @@ export class BulkManager {
|
||||
return false; // Continue with other handlers
|
||||
}
|
||||
|
||||
isEditingTextInputContext(target) {
|
||||
const activeElement = document.activeElement;
|
||||
const candidate = target instanceof Element ? target : activeElement;
|
||||
if (!candidate) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const tagName = candidate.tagName?.toLowerCase();
|
||||
if (
|
||||
candidate.isContentEditable
|
||||
|| tagName === 'input'
|
||||
|| tagName === 'textarea'
|
||||
|| tagName === 'select'
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return Boolean(candidate.closest?.('#filterPanel'));
|
||||
}
|
||||
|
||||
toggleBulkMode() {
|
||||
state.bulkMode = !state.bulkMode;
|
||||
|
||||
|
||||
@@ -20,6 +20,8 @@ export class FilterManager {
|
||||
this.filterPanel = document.getElementById('filterPanel');
|
||||
this.filterButton = document.getElementById('filterButton');
|
||||
this.activeFiltersCount = document.getElementById('activeFiltersCount');
|
||||
this.baseModelSearchInput = document.getElementById('baseModelSearchInput');
|
||||
this.baseModelOptions = [];
|
||||
this.tagsLoaded = false;
|
||||
|
||||
// Initialize preset manager
|
||||
@@ -49,6 +51,8 @@ export class FilterManager {
|
||||
}
|
||||
|
||||
initialize() {
|
||||
this.initializeFilterSearchInputs();
|
||||
|
||||
// Create base model filter tags if they exist
|
||||
if (document.getElementById('baseModelTags')) {
|
||||
this.createBaseModelTags();
|
||||
@@ -110,6 +114,18 @@ export class FilterManager {
|
||||
this.updateTagLogicToggleUI();
|
||||
}
|
||||
|
||||
initializeFilterSearchInputs() {
|
||||
if (this.baseModelSearchInput) {
|
||||
this.baseModelSearchInput.addEventListener('input', () => {
|
||||
this.renderBaseModelTags();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
getNormalizedSearchQuery(input) {
|
||||
return (input?.value || '').trim().toLowerCase();
|
||||
}
|
||||
|
||||
updateTagLogicToggleUI() {
|
||||
const toggleContainer = document.getElementById('tagLogicToggle');
|
||||
if (!toggleContainer) return;
|
||||
@@ -164,11 +180,6 @@ export class FilterManager {
|
||||
|
||||
tagsContainer.innerHTML = '';
|
||||
|
||||
if (!tags.length) {
|
||||
tagsContainer.innerHTML = `<div class="no-tags">No ${this.currentPage === 'recipes' ? 'recipe ' : ''}tags available</div>`;
|
||||
return;
|
||||
}
|
||||
|
||||
// Collect existing tag names from the API response
|
||||
const existingTagNames = new Set(tags.map(t => t.tag));
|
||||
|
||||
@@ -186,6 +197,11 @@ export class FilterManager {
|
||||
});
|
||||
}
|
||||
|
||||
if (!tags.length) {
|
||||
tagsContainer.innerHTML = `<div class="no-tags">No ${this.currentPage === 'recipes' ? 'recipe ' : ''}tags available</div>`;
|
||||
return;
|
||||
}
|
||||
|
||||
tags.forEach(tag => {
|
||||
const tagEl = document.createElement('div');
|
||||
tagEl.className = 'filter-tag tag-filter';
|
||||
@@ -212,7 +228,6 @@ export class FilterManager {
|
||||
await this.applyFilters(false);
|
||||
});
|
||||
|
||||
this.applyTagElementState(tagEl, (this.filters.tags && this.filters.tags[tagName]) || 'none');
|
||||
tagsContainer.appendChild(tagEl);
|
||||
});
|
||||
|
||||
@@ -235,8 +250,8 @@ export class FilterManager {
|
||||
await this.applyFilters(false);
|
||||
});
|
||||
|
||||
this.applyTagElementState(noTagsEl, (this.filters.tags && this.filters.tags[noTagsKey]) || 'none');
|
||||
tagsContainer.appendChild(noTagsEl);
|
||||
this.updateTagSelections();
|
||||
}
|
||||
|
||||
initializeLicenseFilters() {
|
||||
@@ -323,44 +338,15 @@ export class FilterManager {
|
||||
if (!baseModelTagsContainer) return;
|
||||
|
||||
// Set the API endpoint based on current page
|
||||
const apiEndpoint = `/api/lm/${this.currentPage}/base-models`;
|
||||
const apiEndpoint = `/api/lm/${this.currentPage}/base-models?limit=0`;
|
||||
|
||||
// Fetch base models
|
||||
fetch(apiEndpoint)
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.success && data.base_models) {
|
||||
baseModelTagsContainer.innerHTML = '';
|
||||
|
||||
data.base_models.forEach(model => {
|
||||
const tag = document.createElement('div');
|
||||
tag.className = `filter-tag base-model-tag`;
|
||||
tag.dataset.baseModel = model.name;
|
||||
tag.innerHTML = `${model.name} <span class="tag-count">${model.count}</span>`;
|
||||
|
||||
// Add click handler to toggle selection and automatically apply
|
||||
tag.addEventListener('click', async () => {
|
||||
tag.classList.toggle('active');
|
||||
|
||||
if (tag.classList.contains('active')) {
|
||||
if (!this.filters.baseModel.includes(model.name)) {
|
||||
this.filters.baseModel.push(model.name);
|
||||
}
|
||||
} else {
|
||||
this.filters.baseModel = this.filters.baseModel.filter(m => m !== model.name);
|
||||
}
|
||||
|
||||
this.updateActiveFiltersCount();
|
||||
|
||||
// Auto-apply filter when tag is clicked
|
||||
await this.applyFilters(false);
|
||||
});
|
||||
|
||||
baseModelTagsContainer.appendChild(tag);
|
||||
});
|
||||
|
||||
// Update selections based on stored filters
|
||||
this.updateTagSelections();
|
||||
this.baseModelOptions = data.base_models;
|
||||
this.renderBaseModelTags();
|
||||
}
|
||||
})
|
||||
.catch(error => {
|
||||
@@ -369,6 +355,57 @@ export class FilterManager {
|
||||
});
|
||||
}
|
||||
|
||||
renderBaseModelTags() {
|
||||
const baseModelTagsContainer = document.getElementById('baseModelTags');
|
||||
const emptyState = document.getElementById('baseModelEmptyState');
|
||||
if (!baseModelTagsContainer) return;
|
||||
|
||||
baseModelTagsContainer.innerHTML = '';
|
||||
|
||||
if (!this.baseModelOptions.length) {
|
||||
baseModelTagsContainer.innerHTML = '<div class="no-tags">No base models available</div>';
|
||||
if (emptyState) {
|
||||
emptyState.hidden = true;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const query = this.getNormalizedSearchQuery(this.baseModelSearchInput);
|
||||
const filteredModels = query
|
||||
? this.baseModelOptions.filter(model => model.name.toLowerCase().includes(query))
|
||||
: this.baseModelOptions;
|
||||
|
||||
filteredModels.forEach(model => {
|
||||
const tag = document.createElement('div');
|
||||
tag.className = 'filter-tag base-model-tag';
|
||||
tag.dataset.baseModel = model.name;
|
||||
tag.innerHTML = `${model.name} <span class="tag-count">${model.count}</span>`;
|
||||
|
||||
tag.addEventListener('click', async () => {
|
||||
tag.classList.toggle('active');
|
||||
|
||||
if (tag.classList.contains('active')) {
|
||||
if (!this.filters.baseModel.includes(model.name)) {
|
||||
this.filters.baseModel.push(model.name);
|
||||
}
|
||||
} else {
|
||||
this.filters.baseModel = this.filters.baseModel.filter(m => m !== model.name);
|
||||
}
|
||||
|
||||
this.updateActiveFiltersCount();
|
||||
await this.applyFilters(false);
|
||||
});
|
||||
|
||||
baseModelTagsContainer.appendChild(tag);
|
||||
});
|
||||
|
||||
if (emptyState) {
|
||||
emptyState.hidden = filteredModels.length > 0;
|
||||
}
|
||||
|
||||
this.updateTagSelections();
|
||||
}
|
||||
|
||||
async createModelTypeTags() {
|
||||
const modelTypeContainer = document.getElementById('modelTypeTags');
|
||||
if (!modelTypeContainer) return;
|
||||
@@ -453,6 +490,7 @@ export class FilterManager {
|
||||
|
||||
this.filterPanel.classList.remove('hidden');
|
||||
this.filterButton.classList.add('active');
|
||||
this.baseModelSearchInput?.focus();
|
||||
|
||||
// Load tags if they haven't been loaded yet
|
||||
if (!this.tagsLoaded) {
|
||||
|
||||
@@ -232,7 +232,7 @@ export class FilterPresetManager {
|
||||
|
||||
try {
|
||||
const fetchOptions = signal ? { signal } : {};
|
||||
const response = await fetch(`/api/lm/${this.currentPage}/base-models`, fetchOptions);
|
||||
const response = await fetch(`/api/lm/${this.currentPage}/base-models?limit=0`, fetchOptions);
|
||||
|
||||
if (!response.ok) throw new Error('Failed to fetch base models');
|
||||
|
||||
|
||||
@@ -807,6 +807,16 @@ export class SettingsManager {
|
||||
civitaiHostSelect.value = state.global.settings.civitai_host || 'civitai.com';
|
||||
}
|
||||
|
||||
const downloadBackendSelect = document.getElementById('downloadBackend');
|
||||
if (downloadBackendSelect) {
|
||||
downloadBackendSelect.value = state.global.settings.download_backend || 'python';
|
||||
}
|
||||
|
||||
const aria2cPathInput = document.getElementById('aria2cPath');
|
||||
if (aria2cPathInput) {
|
||||
aria2cPathInput.value = state.global.settings.aria2c_path || '';
|
||||
}
|
||||
|
||||
const recipesPathInput = document.getElementById('recipesPath');
|
||||
if (recipesPathInput) {
|
||||
recipesPathInput.value = state.global.settings.recipes_path || '';
|
||||
@@ -950,9 +960,36 @@ export class SettingsManager {
|
||||
languageSelect.value = currentLanguage;
|
||||
}
|
||||
|
||||
this.loadDownloadBackendSettings();
|
||||
this.loadProxySettings();
|
||||
}
|
||||
|
||||
loadDownloadBackendSettings() {
|
||||
const downloadBackendSelect = document.getElementById('downloadBackend');
|
||||
const aria2PathSetting = document.getElementById('aria2PathSetting');
|
||||
const updateVisibility = () => {
|
||||
if (!aria2PathSetting || !downloadBackendSelect) {
|
||||
return;
|
||||
}
|
||||
aria2PathSetting.style.display = downloadBackendSelect.value === 'aria2' ? 'block' : 'none';
|
||||
};
|
||||
|
||||
if (downloadBackendSelect) {
|
||||
downloadBackendSelect.value = state.global.settings.download_backend || 'python';
|
||||
downloadBackendSelect.onchange = () => {
|
||||
updateVisibility();
|
||||
this.saveSelectSetting('downloadBackend', 'download_backend');
|
||||
};
|
||||
}
|
||||
|
||||
const aria2cPathInput = document.getElementById('aria2cPath');
|
||||
if (aria2cPathInput) {
|
||||
aria2cPathInput.value = state.global.settings.aria2c_path || '';
|
||||
}
|
||||
|
||||
updateVisibility();
|
||||
}
|
||||
|
||||
setupPriorityTagInputs() {
|
||||
['lora', 'checkpoint', 'embedding'].forEach((modelType) => {
|
||||
const textarea = document.getElementById(`${modelType}PriorityTagsInput`);
|
||||
|
||||
@@ -6,6 +6,8 @@ import { DEFAULT_PATH_TEMPLATES, DEFAULT_PRIORITY_TAG_CONFIG } from '../utils/co
|
||||
const DEFAULT_SETTINGS_BASE = Object.freeze({
|
||||
civitai_api_key: '',
|
||||
civitai_host: 'civitai.com',
|
||||
download_backend: 'python',
|
||||
aria2c_path: '',
|
||||
use_portable_settings: false,
|
||||
language: 'en',
|
||||
show_only_sfw: false,
|
||||
|
||||
@@ -25,6 +25,7 @@ export function initializeEventManagement() {
|
||||
setupPageUnloadCleanup();
|
||||
|
||||
// Register global event handlers that need coordination
|
||||
registerGlobalEventHandlers();
|
||||
registerContextMenuEvents();
|
||||
registerGlobalClickHandlers();
|
||||
|
||||
@@ -148,6 +149,10 @@ function registerGlobalClickHandlers() {
|
||||
* Register common application-wide event handlers
|
||||
*/
|
||||
export function registerGlobalEventHandlers() {
|
||||
eventManager.removeHandler('keydown', 'global-escape');
|
||||
eventManager.removeHandler('focusin', 'global-focus');
|
||||
eventManager.removeHandler('click', 'global-analytics');
|
||||
|
||||
// Escape key handler for closing modals/panels
|
||||
eventManager.addHandler('keydown', 'global-escape', (e) => {
|
||||
if (e.key === 'Escape') {
|
||||
@@ -156,6 +161,14 @@ export function registerGlobalEventHandlers() {
|
||||
modalManager.closeCurrentModal();
|
||||
return true; // Stop propagation
|
||||
}
|
||||
|
||||
if (
|
||||
window.filterManager?.filterPanel
|
||||
&& !window.filterManager.filterPanel.classList.contains('hidden')
|
||||
) {
|
||||
window.filterManager.closeFilterPanel();
|
||||
return true; // Stop propagation
|
||||
}
|
||||
|
||||
// Check if node selector is active and close it
|
||||
if (eventManager.getState('nodeSelectorActive')) {
|
||||
|
||||
@@ -145,9 +145,22 @@
|
||||
|
||||
<div class="filter-section">
|
||||
<h4>{{ t('header.filter.baseModel') }}</h4>
|
||||
<input
|
||||
type="text"
|
||||
id="baseModelSearchInput"
|
||||
class="filter-search-input"
|
||||
placeholder="{{ t('header.filter.baseModelSearchPlaceholder') }}"
|
||||
autocomplete="off"
|
||||
autocorrect="off"
|
||||
autocapitalize="none"
|
||||
spellcheck="false"
|
||||
>
|
||||
<div class="filter-tags" id="baseModelTags">
|
||||
<!-- Tags will be dynamically inserted here -->
|
||||
</div>
|
||||
<div id="baseModelEmptyState" class="filter-empty-state" hidden>
|
||||
{{ t('header.filter.noBaseModelMatches') }}
|
||||
</div>
|
||||
</div>
|
||||
<div class="filter-section">
|
||||
<div class="filter-section-header">
|
||||
@@ -188,4 +201,4 @@
|
||||
{{ t('header.filter.clearAll') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -129,6 +129,43 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="settings-subsection">
|
||||
<div class="settings-subsection-header">
|
||||
<h4>{{ t('settings.sections.downloads') }}</h4>
|
||||
</div>
|
||||
<div class="setting-item">
|
||||
<div class="setting-row">
|
||||
<div class="setting-info">
|
||||
<label for="downloadBackend">{{ t('settings.downloadBackend.label') }}</label>
|
||||
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.downloadBackend.help') }}"></i>
|
||||
</div>
|
||||
<div class="setting-control select-control">
|
||||
<select id="downloadBackend" onchange="settingsManager.saveSelectSetting('downloadBackend', 'download_backend')">
|
||||
<option value="python">{{ t('settings.downloadBackend.options.python') }}</option>
|
||||
<option value="aria2">{{ t('settings.downloadBackend.options.aria2') }}</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="setting-item" id="aria2PathSetting" style="display: none;">
|
||||
<div class="setting-row">
|
||||
<div class="setting-info">
|
||||
<label for="aria2cPath">{{ t('settings.aria2cPath.label') }}</label>
|
||||
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.aria2cPath.help') }}"></i>
|
||||
</div>
|
||||
<div class="setting-control">
|
||||
<div class="text-input-wrapper">
|
||||
<input type="text"
|
||||
id="aria2cPath"
|
||||
placeholder="{{ t('settings.aria2cPath.placeholder') }}"
|
||||
onblur="settingsManager.saveInputSetting('aria2cPath', 'aria2c_path')"
|
||||
onkeydown="if(event.key === 'Enter') { this.blur(); }" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Backup -->
|
||||
<div class="settings-subsection">
|
||||
<div class="settings-subsection-header">
|
||||
|
||||
@@ -46,6 +46,7 @@ def test_save_paths_renames_default_library(monkeypatch: pytest.MonkeyPatch, tmp
|
||||
self.delete_calls = []
|
||||
self.upsert_calls = []
|
||||
self._renamed = False
|
||||
self.active_library = "default"
|
||||
|
||||
def get_libraries(self):
|
||||
if self._renamed:
|
||||
@@ -62,6 +63,11 @@ def test_save_paths_renames_default_library(monkeypatch: pytest.MonkeyPatch, tmp
|
||||
def rename_library(self, old_name: str, new_name: str):
|
||||
self.rename_calls.append((old_name, new_name))
|
||||
self._renamed = True
|
||||
if self.active_library == old_name:
|
||||
self.active_library = new_name
|
||||
|
||||
def get_active_library_name(self):
|
||||
return self.active_library
|
||||
|
||||
def delete_library(self, name: str): # pragma: no cover - defensive guard
|
||||
self.delete_calls.append(name)
|
||||
@@ -104,6 +110,7 @@ def test_save_paths_logs_warning_when_upsert_fails(
|
||||
class RaisingSettingsService:
|
||||
def __init__(self):
|
||||
self.upsert_attempts = []
|
||||
self.active_library = "comfyui"
|
||||
|
||||
def get_libraries(self):
|
||||
return {
|
||||
@@ -116,6 +123,9 @@ def test_save_paths_logs_warning_when_upsert_fails(
|
||||
def rename_library(self, *_):
|
||||
raise AssertionError("rename_library should not be invoked")
|
||||
|
||||
def get_active_library_name(self):
|
||||
return self.active_library
|
||||
|
||||
def upsert_library(self, name: str, **payload):
|
||||
self.upsert_attempts.append((name, payload))
|
||||
raise RuntimeError("boom")
|
||||
@@ -135,6 +145,8 @@ def test_save_paths_repairs_empty_default_roots(monkeypatch: pytest.MonkeyPatch,
|
||||
folder_paths = _setup_config_environment(monkeypatch, tmp_path)
|
||||
|
||||
class FakeSettingsService:
|
||||
active_library = "comfyui"
|
||||
|
||||
def get_libraries(self):
|
||||
return {
|
||||
"comfyui": {
|
||||
@@ -148,6 +160,9 @@ def test_save_paths_repairs_empty_default_roots(monkeypatch: pytest.MonkeyPatch,
|
||||
def rename_library(self, *_):
|
||||
raise AssertionError("rename_library should not be invoked")
|
||||
|
||||
def get_active_library_name(self):
|
||||
return self.active_library
|
||||
|
||||
def upsert_library(self, name: str, **payload):
|
||||
self.name = name
|
||||
self.payload = payload
|
||||
@@ -167,6 +182,8 @@ def test_save_paths_repairs_stale_default_roots(monkeypatch: pytest.MonkeyPatch,
|
||||
folder_paths = _setup_config_environment(monkeypatch, tmp_path)
|
||||
|
||||
class FakeSettingsService:
|
||||
active_library = "comfyui"
|
||||
|
||||
def get_libraries(self):
|
||||
return {
|
||||
"comfyui": {
|
||||
@@ -180,6 +197,9 @@ def test_save_paths_repairs_stale_default_roots(monkeypatch: pytest.MonkeyPatch,
|
||||
def rename_library(self, *_):
|
||||
raise AssertionError("rename_library should not be invoked")
|
||||
|
||||
def get_active_library_name(self):
|
||||
return self.active_library
|
||||
|
||||
def upsert_library(self, name: str, **payload):
|
||||
self.name = name
|
||||
self.payload = payload
|
||||
@@ -199,6 +219,8 @@ def test_save_paths_keeps_valid_default_roots(monkeypatch: pytest.MonkeyPatch, t
|
||||
folder_paths = _setup_config_environment(monkeypatch, tmp_path)
|
||||
|
||||
class FakeSettingsService:
|
||||
active_library = "comfyui"
|
||||
|
||||
def get_libraries(self):
|
||||
return {
|
||||
"comfyui": {
|
||||
@@ -212,6 +234,9 @@ def test_save_paths_keeps_valid_default_roots(monkeypatch: pytest.MonkeyPatch, t
|
||||
def rename_library(self, *_):
|
||||
raise AssertionError("rename_library should not be invoked")
|
||||
|
||||
def get_active_library_name(self):
|
||||
return self.active_library
|
||||
|
||||
def upsert_library(self, name: str, **payload):
|
||||
self.name = name
|
||||
self.payload = payload
|
||||
@@ -258,6 +283,7 @@ def test_save_paths_removes_template_default_library(monkeypatch, tmp_path):
|
||||
self.rename_calls = []
|
||||
self.delete_calls = []
|
||||
self.upsert_calls = []
|
||||
self.active_library = "default"
|
||||
|
||||
def get_libraries(self):
|
||||
return self.libraries
|
||||
@@ -265,6 +291,8 @@ def test_save_paths_removes_template_default_library(monkeypatch, tmp_path):
|
||||
def rename_library(self, old_name: str, new_name: str):
|
||||
self.rename_calls.append((old_name, new_name))
|
||||
self.libraries[new_name] = self.libraries.pop(old_name)
|
||||
if self.active_library == old_name:
|
||||
self.active_library = new_name
|
||||
|
||||
def delete_library(self, name: str):
|
||||
self.delete_calls.append(name)
|
||||
@@ -273,6 +301,11 @@ def test_save_paths_removes_template_default_library(monkeypatch, tmp_path):
|
||||
def upsert_library(self, name: str, **payload):
|
||||
self.upsert_calls.append((name, payload))
|
||||
self.libraries[name] = {**payload}
|
||||
if payload.get("activate"):
|
||||
self.active_library = name
|
||||
|
||||
def get_active_library_name(self):
|
||||
return self.active_library
|
||||
|
||||
fake_settings = FakeSettingsService()
|
||||
monkeypatch.setattr(settings_manager_module, "settings", fake_settings)
|
||||
@@ -313,6 +346,156 @@ def test_save_paths_removes_template_default_library(monkeypatch, tmp_path):
|
||||
assert payload["activate"] is True
|
||||
|
||||
|
||||
def test_save_paths_keeps_default_roots_in_extra_paths(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
folder_paths = _setup_config_environment(monkeypatch, tmp_path)
|
||||
extra_lora_dir = tmp_path / "extra_loras"
|
||||
extra_checkpoint_dir = tmp_path / "extra_checkpoints"
|
||||
extra_embedding_dir = tmp_path / "extra_embeddings"
|
||||
|
||||
for directory in (extra_lora_dir, extra_checkpoint_dir, extra_embedding_dir):
|
||||
directory.mkdir()
|
||||
|
||||
class FakeSettingsService:
|
||||
active_library = "comfyui"
|
||||
|
||||
def get_libraries(self):
|
||||
return {
|
||||
"comfyui": {
|
||||
"folder_paths": {key: list(value) for key, value in folder_paths.items()},
|
||||
"extra_folder_paths": {
|
||||
"loras": [str(extra_lora_dir)],
|
||||
"checkpoints": [str(extra_checkpoint_dir)],
|
||||
"embeddings": [str(extra_embedding_dir)],
|
||||
},
|
||||
"default_lora_root": str(extra_lora_dir),
|
||||
"default_checkpoint_root": str(extra_checkpoint_dir),
|
||||
"default_embedding_root": str(extra_embedding_dir),
|
||||
}
|
||||
}
|
||||
|
||||
def rename_library(self, *_):
|
||||
raise AssertionError("rename_library should not be invoked")
|
||||
|
||||
def get_active_library_name(self):
|
||||
return self.active_library
|
||||
|
||||
def upsert_library(self, name: str, **payload):
|
||||
self.name = name
|
||||
self.payload = payload
|
||||
|
||||
fake_settings = FakeSettingsService()
|
||||
monkeypatch.setattr(settings_manager_module, "settings", fake_settings)
|
||||
|
||||
config_module.Config()
|
||||
|
||||
assert fake_settings.name == "comfyui"
|
||||
assert fake_settings.payload["extra_folder_paths"]["loras"] == [str(extra_lora_dir).replace("\\", "/")]
|
||||
assert fake_settings.payload["extra_folder_paths"]["checkpoints"] == [
|
||||
str(extra_checkpoint_dir).replace("\\", "/")
|
||||
]
|
||||
assert fake_settings.payload["extra_folder_paths"]["embeddings"] == [
|
||||
str(extra_embedding_dir).replace("\\", "/")
|
||||
]
|
||||
assert fake_settings.payload["default_lora_root"] == str(extra_lora_dir).replace("\\", "/")
|
||||
assert fake_settings.payload["default_checkpoint_root"] == str(extra_checkpoint_dir).replace("\\", "/")
|
||||
assert fake_settings.payload["default_embedding_root"] == str(extra_embedding_dir).replace("\\", "/")
|
||||
assert fake_settings.payload["activate"] is True
|
||||
|
||||
|
||||
def test_save_paths_repairs_empty_default_roots_to_extra_paths_when_primary_missing(
|
||||
monkeypatch: pytest.MonkeyPatch, tmp_path
|
||||
):
|
||||
_setup_config_environment(monkeypatch, tmp_path)
|
||||
extra_lora_dir = tmp_path / "extra_loras"
|
||||
extra_lora_dir.mkdir()
|
||||
|
||||
monkeypatch.setattr(
|
||||
config_module.folder_paths,
|
||||
"get_folder_paths",
|
||||
lambda kind: [] if kind == "loras" else [],
|
||||
)
|
||||
|
||||
class FakeSettingsService:
|
||||
active_library = "comfyui"
|
||||
|
||||
def get_libraries(self):
|
||||
return {
|
||||
"comfyui": {
|
||||
"folder_paths": {
|
||||
"loras": [],
|
||||
"checkpoints": [],
|
||||
"unet": [],
|
||||
"embeddings": [],
|
||||
},
|
||||
"extra_folder_paths": {
|
||||
"loras": [str(extra_lora_dir)],
|
||||
},
|
||||
"default_lora_root": "",
|
||||
}
|
||||
}
|
||||
|
||||
def rename_library(self, *_):
|
||||
raise AssertionError("rename_library should not be invoked")
|
||||
|
||||
def get_active_library_name(self):
|
||||
return self.active_library
|
||||
|
||||
def upsert_library(self, name: str, **payload):
|
||||
self.name = name
|
||||
self.payload = payload
|
||||
|
||||
fake_settings = FakeSettingsService()
|
||||
monkeypatch.setattr(settings_manager_module, "settings", fake_settings)
|
||||
|
||||
config_module.Config()
|
||||
|
||||
assert fake_settings.name == "comfyui"
|
||||
assert fake_settings.payload["default_lora_root"] == str(extra_lora_dir).replace("\\", "/")
|
||||
|
||||
|
||||
def test_save_paths_does_not_activate_comfyui_library_when_another_library_is_active(
|
||||
monkeypatch: pytest.MonkeyPatch, tmp_path
|
||||
):
|
||||
folder_paths = _setup_config_environment(monkeypatch, tmp_path)
|
||||
|
||||
class FakeSettingsService:
|
||||
def __init__(self):
|
||||
self.active_library = "studio"
|
||||
self.upsert_calls = []
|
||||
|
||||
def get_libraries(self):
|
||||
return {
|
||||
"studio": {
|
||||
"folder_paths": {"loras": ["/studio/loras"]},
|
||||
},
|
||||
"comfyui": {
|
||||
"folder_paths": {key: list(value) for key, value in folder_paths.items()},
|
||||
"default_lora_root": folder_paths["loras"][0],
|
||||
"default_checkpoint_root": folder_paths["checkpoints"][0],
|
||||
"default_embedding_root": folder_paths["embeddings"][0],
|
||||
},
|
||||
}
|
||||
|
||||
def rename_library(self, *_):
|
||||
raise AssertionError("rename_library should not be invoked")
|
||||
|
||||
def get_active_library_name(self):
|
||||
return self.active_library
|
||||
|
||||
def upsert_library(self, name: str, **payload):
|
||||
self.upsert_calls.append((name, payload))
|
||||
|
||||
fake_settings = FakeSettingsService()
|
||||
monkeypatch.setattr(settings_manager_module, "settings", fake_settings)
|
||||
|
||||
config_module.Config()
|
||||
|
||||
assert len(fake_settings.upsert_calls) == 1
|
||||
name, payload = fake_settings.upsert_calls[0]
|
||||
assert name == "comfyui"
|
||||
assert payload["activate"] is False
|
||||
|
||||
|
||||
def test_apply_library_settings_merges_extra_paths(monkeypatch, tmp_path):
|
||||
"""Test that apply_library_settings correctly merges folder_paths with extra_folder_paths."""
|
||||
loras_dir = tmp_path / "loras"
|
||||
|
||||
@@ -110,7 +110,10 @@ function renderControlsDom(pageKey) {
|
||||
<div class="search-option-tag active" data-option="filename"></div>
|
||||
</div>
|
||||
<div id="filterPanel" class="filter-panel hidden">
|
||||
<input id="baseModelSearchInput" />
|
||||
<div id="baseModelTags" class="filter-tags"></div>
|
||||
<div id="baseModelEmptyState" hidden></div>
|
||||
<div id="filterPresets" class="filter-presets"></div>
|
||||
<div id="modelTagsFilter" class="filter-tags"></div>
|
||||
<button class="clear-filter"></button>
|
||||
</div>
|
||||
@@ -286,6 +289,8 @@ describe('FilterManager tag and base model filters', () => {
|
||||
|
||||
const manager = new FilterManager({ page: pageKey });
|
||||
|
||||
expect(global.fetch).toHaveBeenCalledWith(`/api/lm/${pageKey}/base-models?limit=0`);
|
||||
|
||||
await vi.waitFor(() => {
|
||||
const chip = document.querySelector('[data-base-model="SDXL"]');
|
||||
expect(chip).not.toBeNull();
|
||||
@@ -311,6 +316,259 @@ describe('FilterManager tag and base model filters', () => {
|
||||
expect(getCurrentPageState().filters.baseModel).toEqual([]);
|
||||
expect(baseModelChip.classList.contains('active')).toBe(false);
|
||||
});
|
||||
|
||||
it('filters base model chips locally without changing selected state', async () => {
|
||||
global.fetch = vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
success: true,
|
||||
base_models: [
|
||||
{ name: 'SDXL', count: 2 },
|
||||
{ name: 'LTXV 2.3', count: 1 },
|
||||
],
|
||||
}),
|
||||
});
|
||||
|
||||
renderControlsDom('loras');
|
||||
const stateModule = await import('../../../static/js/state/index.js');
|
||||
stateModule.initPageState('loras');
|
||||
const { getCurrentPageState } = stateModule;
|
||||
const { FilterManager } = await import('../../../static/js/managers/FilterManager.js');
|
||||
|
||||
new FilterManager({ page: 'loras' });
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(document.querySelector('[data-base-model="LTXV 2.3"]')).not.toBeNull();
|
||||
});
|
||||
|
||||
const searchInput = document.getElementById('baseModelSearchInput');
|
||||
const ltxvChip = document.querySelector('[data-base-model="LTXV 2.3"]');
|
||||
ltxvChip.dispatchEvent(new Event('click', { bubbles: true }));
|
||||
await vi.waitFor(() => expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledTimes(1));
|
||||
expect(getCurrentPageState().filters.baseModel).toEqual(['LTXV 2.3']);
|
||||
|
||||
loadMoreWithVirtualScrollMock.mockClear();
|
||||
searchInput.value = 'sdx';
|
||||
searchInput.dispatchEvent(new Event('input', { bubbles: true }));
|
||||
|
||||
expect(document.querySelector('[data-base-model="SDXL"]')).not.toBeNull();
|
||||
expect(document.querySelector('[data-base-model="LTXV 2.3"]')).toBeNull();
|
||||
expect(document.getElementById('baseModelEmptyState').hidden).toBe(true);
|
||||
expect(getCurrentPageState().filters.baseModel).toEqual(['LTXV 2.3']);
|
||||
|
||||
searchInput.value = 'zzz';
|
||||
searchInput.dispatchEvent(new Event('input', { bubbles: true }));
|
||||
expect(document.getElementById('baseModelEmptyState').hidden).toBe(false);
|
||||
|
||||
searchInput.value = 'ltx';
|
||||
searchInput.dispatchEvent(new Event('input', { bubbles: true }));
|
||||
const restoredChip = document.querySelector('[data-base-model="LTXV 2.3"]');
|
||||
expect(restoredChip).not.toBeNull();
|
||||
expect(restoredChip.classList.contains('active')).toBe(true);
|
||||
});
|
||||
|
||||
it('disables browser autocomplete helpers for the base model search input', async () => {
|
||||
renderControlsDom('loras');
|
||||
|
||||
const searchInput = document.getElementById('baseModelSearchInput');
|
||||
|
||||
searchInput.setAttribute('autocomplete', 'off');
|
||||
searchInput.setAttribute('autocorrect', 'off');
|
||||
searchInput.setAttribute('autocapitalize', 'none');
|
||||
searchInput.setAttribute('spellcheck', 'false');
|
||||
|
||||
expect(searchInput.getAttribute('autocomplete')).toBe('off');
|
||||
expect(searchInput.getAttribute('autocorrect')).toBe('off');
|
||||
expect(searchInput.getAttribute('autocapitalize')).toBe('none');
|
||||
expect(searchInput.getAttribute('spellcheck')).toBe('false');
|
||||
});
|
||||
|
||||
it('focuses the base model search input when opening the filter panel', async () => {
|
||||
global.fetch = vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
success: true,
|
||||
base_models: [{ name: 'SDXL', count: 2 }],
|
||||
}),
|
||||
});
|
||||
|
||||
renderControlsDom('loras');
|
||||
const stateModule = await import('../../../static/js/state/index.js');
|
||||
stateModule.initPageState('loras');
|
||||
const { FilterManager } = await import('../../../static/js/managers/FilterManager.js');
|
||||
|
||||
const manager = new FilterManager({ page: 'loras' });
|
||||
const searchInput = document.getElementById('baseModelSearchInput');
|
||||
|
||||
expect(document.activeElement).not.toBe(searchInput);
|
||||
|
||||
manager.toggleFilterPanel();
|
||||
|
||||
expect(document.activeElement).toBe(searchInput);
|
||||
});
|
||||
|
||||
it('does not let base model search trigger bulk shortcuts', async () => {
|
||||
global.fetch = vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
success: true,
|
||||
base_models: [{ name: 'SDXL', count: 2 }],
|
||||
}),
|
||||
});
|
||||
|
||||
renderControlsDom('loras');
|
||||
const stateModule = await import('../../../static/js/state/index.js');
|
||||
stateModule.initPageState('loras');
|
||||
const { BulkManager } = await import('../../../static/js/managers/BulkManager.js');
|
||||
const { FilterManager } = await import('../../../static/js/managers/FilterManager.js');
|
||||
|
||||
const filterManager = new FilterManager({ page: 'loras' });
|
||||
const bulkManager = new BulkManager();
|
||||
const searchInput = document.getElementById('baseModelSearchInput');
|
||||
window.filterManager = filterManager;
|
||||
|
||||
searchInput.focus();
|
||||
|
||||
const bulkEvent = new KeyboardEvent('keydown', {
|
||||
key: 'b',
|
||||
bubbles: true,
|
||||
cancelable: true,
|
||||
});
|
||||
Object.defineProperty(bulkEvent, 'target', { value: searchInput });
|
||||
expect(bulkManager.handleGlobalKeyboard(bulkEvent)).toBe(false);
|
||||
|
||||
const selectAllEvent = new KeyboardEvent('keydown', {
|
||||
key: 'a',
|
||||
ctrlKey: true,
|
||||
bubbles: true,
|
||||
cancelable: true,
|
||||
});
|
||||
Object.defineProperty(selectAllEvent, 'target', { value: searchInput });
|
||||
expect(bulkManager.handleGlobalKeyboard(selectAllEvent)).toBe(false);
|
||||
});
|
||||
|
||||
it('closes the filter panel on Escape', async () => {
|
||||
global.fetch = vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
success: true,
|
||||
base_models: [{ name: 'SDXL', count: 2 }],
|
||||
}),
|
||||
});
|
||||
|
||||
renderControlsDom('loras');
|
||||
const stateModule = await import('../../../static/js/state/index.js');
|
||||
stateModule.initPageState('loras');
|
||||
const { FilterManager } = await import('../../../static/js/managers/FilterManager.js');
|
||||
const { eventManager } = await import('../../../static/js/utils/EventManager.js');
|
||||
const { initializeEventManagement } = await import('../../../static/js/utils/eventManagementInit.js');
|
||||
|
||||
eventManager.cleanup();
|
||||
initializeEventManagement();
|
||||
|
||||
const manager = new FilterManager({ page: 'loras' });
|
||||
window.filterManager = manager;
|
||||
manager.toggleFilterPanel();
|
||||
expect(manager.filterPanel.classList.contains('hidden')).toBe(false);
|
||||
|
||||
document.dispatchEvent(new KeyboardEvent('keydown', { key: 'Escape', bubbles: true }));
|
||||
|
||||
expect(manager.filterPanel.classList.contains('hidden')).toBe(true);
|
||||
eventManager.cleanup();
|
||||
});
|
||||
|
||||
it('applies all base models from a preset using the full base model list', async () => {
|
||||
global.fetch = vi.fn((url) => {
|
||||
if (url.includes('/base-models?limit=0')) {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
success: true,
|
||||
base_models: [
|
||||
{ name: 'SDXL 1.0', count: 5 },
|
||||
{ name: 'SDXL Lightning', count: 3 },
|
||||
{ name: 'SDXL Hyper', count: 2 },
|
||||
],
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
if (url.includes('/base-models')) {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
success: true,
|
||||
base_models: [{ name: 'SDXL 1.0', count: 5 }],
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
if (url.includes('/top-tags')) {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
success: true,
|
||||
tags: [],
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
if (url.includes('/model-types')) {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
success: true,
|
||||
model_types: [],
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({ success: true }),
|
||||
});
|
||||
});
|
||||
|
||||
renderControlsDom('loras');
|
||||
const stateModule = await import('../../../static/js/state/index.js');
|
||||
stateModule.initPageState('loras');
|
||||
stateModule.state.global.settings.filter_presets = {
|
||||
loras: [
|
||||
{
|
||||
name: 'SDXL Family',
|
||||
filters: {
|
||||
baseModel: ['SDXL 1.0', 'SDXL Lightning', 'SDXL Hyper'],
|
||||
tags: {},
|
||||
license: {},
|
||||
modelTypes: [],
|
||||
tagLogic: 'any',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const { getCurrentPageState } = stateModule;
|
||||
const { FilterManager } = await import('../../../static/js/managers/FilterManager.js');
|
||||
|
||||
const manager = new FilterManager({ page: 'loras' });
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(document.querySelector('[data-base-model="SDXL Hyper"]')).not.toBeNull();
|
||||
});
|
||||
|
||||
await manager.presetManager.applyPreset('SDXL Family');
|
||||
|
||||
expect(manager.activePreset).toBe('SDXL Family');
|
||||
expect(manager.filters.baseModel).toEqual(['SDXL 1.0', 'SDXL Lightning', 'SDXL Hyper']);
|
||||
expect(getCurrentPageState().filters.baseModel).toEqual(['SDXL 1.0', 'SDXL Lightning', 'SDXL Hyper']);
|
||||
expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledWith(true, false);
|
||||
expect(showToastMock).toHaveBeenCalledWith(
|
||||
'Preset "SDXL Family" applied',
|
||||
{},
|
||||
'success',
|
||||
);
|
||||
});
|
||||
|
||||
});
|
||||
|
||||
describe('PageControls favorites, sorting, and duplicates scenarios', () => {
|
||||
|
||||
@@ -305,4 +305,39 @@ describe('SettingsManager library controls', () => {
|
||||
'success',
|
||||
);
|
||||
});
|
||||
|
||||
it('loads download backend settings and toggles the aria2 path field', () => {
|
||||
const manager = createManager();
|
||||
document.body.innerHTML = `
|
||||
<select id="downloadBackend">
|
||||
<option value="python">Python</option>
|
||||
<option value="aria2">aria2</option>
|
||||
</select>
|
||||
<div id="aria2PathSetting" style="display: none;"></div>
|
||||
<input id="aria2cPath" />
|
||||
`;
|
||||
|
||||
state.global.settings = {
|
||||
download_backend: 'aria2',
|
||||
aria2c_path: '/usr/bin/aria2c',
|
||||
};
|
||||
|
||||
const saveSpy = vi.spyOn(manager, 'saveSelectSetting').mockResolvedValue();
|
||||
|
||||
manager.loadDownloadBackendSettings();
|
||||
|
||||
const backendSelect = document.getElementById('downloadBackend');
|
||||
const aria2PathSetting = document.getElementById('aria2PathSetting');
|
||||
const aria2cPath = document.getElementById('aria2cPath');
|
||||
|
||||
expect(backendSelect.value).toBe('aria2');
|
||||
expect(aria2cPath.value).toBe('/usr/bin/aria2c');
|
||||
expect(aria2PathSetting.style.display).toBe('block');
|
||||
|
||||
backendSelect.value = 'python';
|
||||
backendSelect.onchange();
|
||||
|
||||
expect(aria2PathSetting.style.display).toBe('none');
|
||||
expect(saveSpy).toHaveBeenCalledWith('downloadBackend', 'download_backend');
|
||||
});
|
||||
});
|
||||
|
||||
38
tests/routes/test_model_query_handler.py
Normal file
38
tests/routes/test_model_query_handler.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import json
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from py.routes.handlers.model_handlers import ModelQueryHandler
|
||||
|
||||
|
||||
class DummyService:
|
||||
def __init__(self):
|
||||
self.received_limit = None
|
||||
|
||||
async def get_base_models(self, limit):
|
||||
self.received_limit = limit
|
||||
return [{"name": "SDXL", "count": 2}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_query_handler_accepts_limit_zero_for_base_models():
|
||||
service = DummyService()
|
||||
handler = ModelQueryHandler(service=service, logger=logging.getLogger(__name__))
|
||||
|
||||
response = await handler.get_base_models(SimpleNamespace(query={"limit": "0"}))
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is True
|
||||
assert service.received_limit == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_query_handler_rejects_negative_limit_for_base_models():
|
||||
service = DummyService()
|
||||
handler = ModelQueryHandler(service=service, logger=logging.getLogger(__name__))
|
||||
|
||||
await handler.get_base_models(SimpleNamespace(query={"limit": "-1"}))
|
||||
|
||||
assert service.received_limit == 20
|
||||
44
tests/routes/test_recipe_query_handler.py
Normal file
44
tests/routes/test_recipe_query_handler.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import json
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from py.routes.handlers.recipe_handlers import RecipeQueryHandler
|
||||
|
||||
|
||||
async def _noop():
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recipe_query_handler_base_models_limit_zero_returns_all():
|
||||
cache = SimpleNamespace(
|
||||
raw_data=[
|
||||
{"base_model": "SDXL"},
|
||||
{"base_model": "LTXV 2.3"},
|
||||
{"base_model": "SDXL"},
|
||||
]
|
||||
)
|
||||
scanner = SimpleNamespace(get_cached_data=lambda: None)
|
||||
|
||||
async def get_cached_data():
|
||||
return cache
|
||||
|
||||
scanner.get_cached_data = get_cached_data
|
||||
|
||||
handler = RecipeQueryHandler(
|
||||
ensure_dependencies_ready=_noop,
|
||||
recipe_scanner_getter=lambda: scanner,
|
||||
format_recipe_file_url=lambda value: value,
|
||||
logger=logging.getLogger(__name__),
|
||||
)
|
||||
|
||||
response = await handler.get_base_models(SimpleNamespace(query={"limit": "0"}))
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is True
|
||||
assert payload["base_models"] == [
|
||||
{"name": "SDXL", "count": 2},
|
||||
{"name": "LTXV 2.3", "count": 1},
|
||||
]
|
||||
354
tests/services/test_aria2_downloader.py
Normal file
354
tests/services/test_aria2_downloader.py
Normal file
@@ -0,0 +1,354 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.aria2_downloader import Aria2Downloader, Aria2Error
|
||||
from py.services.aria2_transfer_state import Aria2TransferStateStore
|
||||
from py.services import aria2_transfer_state
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolate_aria2_state(monkeypatch, tmp_path):
|
||||
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
|
||||
monkeypatch.setattr(
|
||||
aria2_transfer_state,
|
||||
"get_aria2_state_path",
|
||||
lambda: str(state_path),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_polls_until_complete(tmp_path, monkeypatch):
|
||||
downloader = Aria2Downloader()
|
||||
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
|
||||
downloader._rpc_secret = "secret"
|
||||
|
||||
save_path = tmp_path / "downloads" / "model.safetensors"
|
||||
progress_events = []
|
||||
rpc_calls = []
|
||||
statuses = iter(
|
||||
[
|
||||
{
|
||||
"gid": "gid-1",
|
||||
"status": "active",
|
||||
"completedLength": "5",
|
||||
"totalLength": "10",
|
||||
"downloadSpeed": "25",
|
||||
},
|
||||
{
|
||||
"gid": "gid-1",
|
||||
"status": "complete",
|
||||
"completedLength": "10",
|
||||
"totalLength": "10",
|
||||
"downloadSpeed": "0",
|
||||
"files": [{"path": str(save_path)}],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
async def fake_rpc_call(method, params):
|
||||
rpc_calls.append((method, params))
|
||||
if method == "aria2.addUri":
|
||||
return "gid-1"
|
||||
if method == "aria2.tellStatus":
|
||||
return next(statuses)
|
||||
raise AssertionError(f"Unexpected RPC method: {method}")
|
||||
|
||||
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
|
||||
monkeypatch.setattr(
|
||||
downloader,
|
||||
"_resolve_authenticated_redirect_url",
|
||||
AsyncMock(
|
||||
return_value="https://signed.example.com/model.safetensors?token=abc"
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
|
||||
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
|
||||
|
||||
async def progress_callback(progress, snapshot=None):
|
||||
progress_events.append(snapshot.percent_complete if snapshot else progress)
|
||||
|
||||
success, result = await downloader.download_file(
|
||||
"https://civitai.com/api/download/models/123",
|
||||
str(save_path),
|
||||
download_id="download-1",
|
||||
progress_callback=progress_callback,
|
||||
headers={"Authorization": "Bearer token"},
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert result == str(save_path)
|
||||
assert progress_events == [50.0, 100.0]
|
||||
assert downloader._transfers == {}
|
||||
assert rpc_calls[0][0] == "aria2.addUri"
|
||||
assert rpc_calls[0][1][0] == [
|
||||
"https://signed.example.com/model.safetensors?token=abc"
|
||||
]
|
||||
assert rpc_calls[0][1][1]["out"] == "model.safetensors"
|
||||
assert "header" not in rpc_calls[0][1][1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transfer_state_store_shares_lock_and_preserves_concurrent_updates(tmp_path):
|
||||
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
|
||||
store_a = Aria2TransferStateStore(str(state_path))
|
||||
store_b = Aria2TransferStateStore(str(state_path))
|
||||
|
||||
assert store_a._lock is store_b._lock
|
||||
|
||||
await asyncio.gather(
|
||||
store_a.upsert("download-1", {"status": "downloading", "gid": "gid-1"}),
|
||||
store_b.upsert("download-2", {"status": "paused", "gid": "gid-2"}),
|
||||
)
|
||||
|
||||
assert await store_a.get("download-1") == {"status": "downloading", "gid": "gid-1"}
|
||||
assert await store_b.get("download-2") == {"status": "paused", "gid": "gid-2"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_keeps_auth_headers_when_civitai_does_not_redirect(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
downloader = Aria2Downloader()
|
||||
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
|
||||
downloader._rpc_secret = "secret"
|
||||
|
||||
save_path = tmp_path / "downloads" / "model.safetensors"
|
||||
rpc_calls = []
|
||||
statuses = iter(
|
||||
[
|
||||
{
|
||||
"gid": "gid-1",
|
||||
"status": "complete",
|
||||
"completedLength": "10",
|
||||
"totalLength": "10",
|
||||
"downloadSpeed": "0",
|
||||
"files": [{"path": str(save_path)}],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
async def fake_rpc_call(method, params):
|
||||
rpc_calls.append((method, params))
|
||||
if method == "aria2.addUri":
|
||||
return "gid-1"
|
||||
if method == "aria2.tellStatus":
|
||||
return next(statuses)
|
||||
raise AssertionError(f"Unexpected RPC method: {method}")
|
||||
|
||||
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
|
||||
monkeypatch.setattr(
|
||||
downloader,
|
||||
"_resolve_authenticated_redirect_url",
|
||||
AsyncMock(return_value="https://civitai.com/api/download/models/123"),
|
||||
)
|
||||
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
|
||||
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
|
||||
|
||||
success, result = await downloader.download_file(
|
||||
"https://civitai.com/api/download/models/123",
|
||||
str(save_path),
|
||||
download_id="download-1",
|
||||
headers={"Authorization": "Bearer token"},
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert result == str(save_path)
|
||||
assert rpc_calls[0][1][0] == ["https://civitai.com/api/download/models/123"]
|
||||
assert rpc_calls[0][1][1]["header"] == ["Authorization: Bearer token"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_resume_cancel_forward_to_rpc(monkeypatch):
|
||||
downloader = Aria2Downloader()
|
||||
downloader._transfers["download-1"] = type(
|
||||
"Transfer", (), {"gid": "gid-1", "save_path": "/tmp/model.safetensors"}
|
||||
)()
|
||||
|
||||
calls = []
|
||||
|
||||
async def fake_rpc_call(method, params):
|
||||
calls.append((method, params))
|
||||
return "gid-1"
|
||||
|
||||
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
|
||||
|
||||
pause_result = await downloader.pause_download("download-1")
|
||||
resume_result = await downloader.resume_download("download-1")
|
||||
cancel_result = await downloader.cancel_download("download-1")
|
||||
|
||||
assert pause_result["success"] is True
|
||||
assert resume_result["success"] is True
|
||||
assert cancel_result["success"] is True
|
||||
assert calls == [
|
||||
("aria2.forcePause", ["gid-1"]),
|
||||
("aria2.unpause", ["gid-1"]),
|
||||
("aria2.forceRemove", ["gid-1"]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_reuses_existing_transfer_without_add_uri(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
downloader = Aria2Downloader()
|
||||
downloader._rpc_url = "http://127.0.0.1/jsonrpc"
|
||||
downloader._rpc_secret = "secret"
|
||||
|
||||
save_path = tmp_path / "downloads" / "model.safetensors"
|
||||
downloader._transfers["download-1"] = type(
|
||||
"Transfer", (), {"gid": "gid-1", "save_path": str(save_path)}
|
||||
)()
|
||||
|
||||
rpc_calls = []
|
||||
statuses = iter(
|
||||
[
|
||||
{
|
||||
"gid": "gid-1",
|
||||
"status": "active",
|
||||
"completedLength": "5",
|
||||
"totalLength": "10",
|
||||
"downloadSpeed": "25",
|
||||
},
|
||||
{
|
||||
"gid": "gid-1",
|
||||
"status": "complete",
|
||||
"completedLength": "10",
|
||||
"totalLength": "10",
|
||||
"downloadSpeed": "0",
|
||||
"files": [{"path": str(save_path)}],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
async def fake_rpc_call(method, params):
|
||||
rpc_calls.append((method, params))
|
||||
if method == "aria2.tellStatus":
|
||||
return next(statuses)
|
||||
raise AssertionError(f"Unexpected RPC method: {method}")
|
||||
|
||||
monkeypatch.setattr(downloader, "_ensure_process", AsyncMock())
|
||||
monkeypatch.setattr(downloader, "_rpc_call", fake_rpc_call)
|
||||
monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())
|
||||
|
||||
success, result = await downloader.download_file(
|
||||
"https://example.com/model.safetensors",
|
||||
str(save_path),
|
||||
download_id="download-1",
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert result == str(save_path)
|
||||
assert [call[0] for call in rpc_calls] == ["aria2.tellStatus", "aria2.tellStatus"]
|
||||
|
||||
|
||||
def test_build_progress_snapshot_normalizes_numeric_fields():
|
||||
downloader = Aria2Downloader()
|
||||
|
||||
snapshot = downloader._build_progress_snapshot(
|
||||
{
|
||||
"completedLength": "75",
|
||||
"totalLength": "100",
|
||||
"downloadSpeed": "512",
|
||||
}
|
||||
)
|
||||
|
||||
assert snapshot.percent_complete == 75.0
|
||||
assert snapshot.bytes_downloaded == 75
|
||||
assert snapshot.total_bytes == 100
|
||||
assert snapshot.bytes_per_second == 512.0
|
||||
|
||||
|
||||
def test_resolve_executable_raises_when_binary_missing(monkeypatch):
|
||||
downloader = Aria2Downloader()
|
||||
settings = type("Settings", (), {"get": lambda self, key, default=None: ""})()
|
||||
|
||||
monkeypatch.setattr("py.services.aria2_downloader.get_settings_manager", lambda: settings)
|
||||
monkeypatch.setattr("py.services.aria2_downloader.shutil.which", lambda _: None)
|
||||
|
||||
with pytest.raises(Aria2Error):
|
||||
downloader._resolve_executable()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rpc_call_surfaces_json_error_on_non_200(monkeypatch):
|
||||
downloader = Aria2Downloader()
|
||||
downloader._rpc_url = "http://127.0.0.1:6800/jsonrpc"
|
||||
downloader._rpc_secret = "secret"
|
||||
|
||||
class FakeResponse:
|
||||
status = 400
|
||||
|
||||
async def text(self):
|
||||
return (
|
||||
'{"jsonrpc":"2.0","id":"x","error":{"code":1,"message":"Unauthorized"}}'
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class FakeSession:
|
||||
def post(self, _url, json=None):
|
||||
return FakeResponse()
|
||||
|
||||
monkeypatch.setattr(downloader, "_get_rpc_session", AsyncMock(return_value=FakeSession()))
|
||||
|
||||
with pytest.raises(Aria2Error) as exc_info:
|
||||
await downloader._rpc_call("aria2.addUri", [["https://example.com/file"]])
|
||||
|
||||
assert "Unauthorized" in str(exc_info.value)
|
||||
assert "aria2.addUri" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_authenticated_redirect_url_returns_location(monkeypatch):
|
||||
downloader = Aria2Downloader()
|
||||
|
||||
class FakeResponse:
|
||||
status = 307
|
||||
headers = {"Location": "https://signed.example.com/file.safetensors"}
|
||||
|
||||
async def text(self):
|
||||
return ""
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class FakeSession:
|
||||
def get(self, _url, headers=None, allow_redirects=False, proxy=None):
|
||||
return FakeResponse()
|
||||
|
||||
class FakeDownloader:
|
||||
default_headers = {"User-Agent": "ComfyUI-LoRA-Manager/1.0"}
|
||||
proxy_url = None
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
async def _session():
|
||||
return FakeSession()
|
||||
|
||||
return _session()
|
||||
|
||||
fake_downloader = FakeDownloader()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"py.services.aria2_downloader.get_downloader",
|
||||
AsyncMock(return_value=fake_downloader),
|
||||
)
|
||||
|
||||
result = await downloader._resolve_authenticated_redirect_url(
|
||||
"https://civitai.com/api/download/models/123",
|
||||
{"Authorization": "Bearer token"},
|
||||
)
|
||||
|
||||
assert result == "https://signed.example.com/file.safetensors"
|
||||
@@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
from py.services import civitai_client as civitai_client_module
|
||||
from py.services.civitai_client import CivitaiClient
|
||||
from py.services.connectivity_guard import OFFLINE_COOLDOWN_ERROR, OFFLINE_FRIENDLY_MESSAGE
|
||||
from py.services.errors import RateLimitError, ResourceNotFoundError
|
||||
from py.services.model_metadata_provider import ModelMetadataProviderManager
|
||||
|
||||
@@ -115,6 +116,20 @@ async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader):
|
||||
assert error == "Model not found"
|
||||
|
||||
|
||||
async def test_get_model_by_hash_handles_offline_cooldown(downloader):
|
||||
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||
return False, OFFLINE_COOLDOWN_ERROR
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result, error = await client.get_model_by_hash("missing")
|
||||
|
||||
assert result is None
|
||||
assert error == OFFLINE_FRIENDLY_MESSAGE
|
||||
|
||||
|
||||
async def test_get_model_by_hash_propagates_rate_limit(monkeypatch, downloader):
|
||||
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||
return False, RateLimitError("limited", retry_after=4)
|
||||
|
||||
125
tests/services/test_connectivity_guard.py
Normal file
125
tests/services/test_connectivity_guard.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import asyncio
|
||||
import errno
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.connectivity_guard import (
|
||||
OFFLINE_COOLDOWN_ERROR,
|
||||
OFFLINE_FRIENDLY_MESSAGE,
|
||||
ConnectivityGuard,
|
||||
)
|
||||
from py.services.downloader import Downloader
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_connectivity_guard_singleton():
|
||||
ConnectivityGuard._instance = None
|
||||
yield
|
||||
ConnectivityGuard._instance = None
|
||||
|
||||
|
||||
async def test_connectivity_guard_enters_cooldown_after_threshold():
|
||||
guard = await ConnectivityGuard.get_instance()
|
||||
|
||||
assert guard.online is True
|
||||
assert guard.should_block_request() is False
|
||||
|
||||
guard.register_network_failure(OSError(errno.ENETUNREACH, "unreachable"))
|
||||
guard.register_network_failure(asyncio.TimeoutError("timeout"))
|
||||
|
||||
assert guard.should_block_request() is False
|
||||
assert guard.failure_count == 2
|
||||
|
||||
guard.register_network_failure(ConnectionRefusedError("refused"))
|
||||
|
||||
assert guard.online is False
|
||||
assert guard.failure_count == 3
|
||||
assert guard.should_block_request() is True
|
||||
assert guard.cooldown_remaining_seconds() > 0
|
||||
|
||||
|
||||
async def test_connectivity_guard_scopes_cooldown_to_destination():
|
||||
guard = await ConnectivityGuard.get_instance()
|
||||
|
||||
destination_a = "civitai.com"
|
||||
destination_b = "api.github.com"
|
||||
|
||||
guard.register_network_failure(
|
||||
OSError(errno.ENETUNREACH, "unreachable"),
|
||||
destination_a,
|
||||
)
|
||||
guard.register_network_failure(asyncio.TimeoutError("timeout"), destination_a)
|
||||
guard.register_network_failure(ConnectionRefusedError("refused"), destination_a)
|
||||
|
||||
assert guard.should_block_request(destination_a) is True
|
||||
assert guard.should_block_request(destination_b) is False
|
||||
|
||||
guard.register_success(destination_a)
|
||||
assert guard.should_block_request(destination_a) is False
|
||||
|
||||
|
||||
async def test_connectivity_guard_recovers_after_success():
|
||||
guard = await ConnectivityGuard.get_instance()
|
||||
guard.online = False
|
||||
guard.failure_count = 5
|
||||
guard.cooldown_until = datetime.now() + timedelta(seconds=90)
|
||||
|
||||
guard.register_success()
|
||||
|
||||
assert guard.online is True
|
||||
assert guard.failure_count == 0
|
||||
assert guard.cooldown_until is None
|
||||
assert guard.should_block_request() is False
|
||||
|
||||
|
||||
async def test_downloader_short_circuits_all_request_helpers_during_cooldown():
|
||||
guard = await ConnectivityGuard.get_instance()
|
||||
destination = "example.invalid"
|
||||
guard.register_network_failure(
|
||||
OSError(errno.ENETUNREACH, "unreachable"),
|
||||
destination,
|
||||
)
|
||||
guard.register_network_failure(asyncio.TimeoutError("timeout"), destination)
|
||||
guard.register_network_failure(
|
||||
ConnectionRefusedError("refused"),
|
||||
destination,
|
||||
)
|
||||
|
||||
downloader = Downloader()
|
||||
|
||||
ok, payload = await downloader.make_request("GET", f"https://{destination}")
|
||||
assert ok is False
|
||||
assert payload == OFFLINE_COOLDOWN_ERROR
|
||||
|
||||
ok, payload, headers = await downloader.download_to_memory(f"https://{destination}")
|
||||
assert ok is False
|
||||
assert payload == OFFLINE_FRIENDLY_MESSAGE
|
||||
assert headers is None
|
||||
|
||||
ok, payload = await downloader.get_response_headers(f"https://{destination}")
|
||||
assert ok is False
|
||||
assert payload == OFFLINE_COOLDOWN_ERROR
|
||||
|
||||
|
||||
async def test_downloader_only_short_circuits_requests_for_same_destination():
|
||||
guard = await ConnectivityGuard.get_instance()
|
||||
guard.register_network_failure(
|
||||
OSError(errno.ENETUNREACH, "unreachable"),
|
||||
"example.invalid",
|
||||
)
|
||||
guard.register_network_failure(asyncio.TimeoutError("timeout"), "example.invalid")
|
||||
guard.register_network_failure(
|
||||
ConnectionRefusedError("refused"),
|
||||
"example.invalid",
|
||||
)
|
||||
|
||||
downloader = Downloader()
|
||||
ok, payload = await downloader.make_request("GET", "https://example.invalid")
|
||||
assert ok is False
|
||||
assert payload == OFFLINE_COOLDOWN_ERROR
|
||||
|
||||
assert (
|
||||
guard.should_block_request(downloader._guard_destination("https://example.com"))
|
||||
is False
|
||||
)
|
||||
@@ -10,6 +10,7 @@ import pytest
|
||||
|
||||
from py.services.download_manager import DownloadManager
|
||||
from py.services import download_manager
|
||||
from py.services import aria2_transfer_state
|
||||
from py.services.service_registry import ServiceRegistry
|
||||
from py.services.settings_manager import SettingsManager, get_settings_manager
|
||||
|
||||
@@ -46,6 +47,16 @@ def isolate_settings(monkeypatch, tmp_path):
|
||||
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolate_aria2_state(monkeypatch, tmp_path):
|
||||
state_path = tmp_path / "cache" / "aria2" / "downloads.json"
|
||||
monkeypatch.setattr(
|
||||
aria2_transfer_state,
|
||||
"get_aria2_state_path",
|
||||
lambda: str(state_path),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def stub_metadata(monkeypatch):
|
||||
class _StubMetadata:
|
||||
@@ -179,6 +190,7 @@ async def test_successful_download_uses_defaults(
|
||||
progress_callback,
|
||||
model_type,
|
||||
download_id,
|
||||
transfer_backend=None,
|
||||
):
|
||||
captured.update(
|
||||
{
|
||||
@@ -268,6 +280,7 @@ async def test_download_uses_active_mirrors(
|
||||
progress_callback,
|
||||
model_type,
|
||||
download_id,
|
||||
transfer_backend=None,
|
||||
):
|
||||
captured["download_urls"] = download_urls
|
||||
return {"success": True}
|
||||
@@ -288,6 +301,644 @@ async def test_download_uses_active_mirrors(
|
||||
assert captured["download_urls"] == ["https://mirror.example/file.safetensors"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_resume_cancel_delegate_to_aria2_backend(monkeypatch):
|
||||
manager = DownloadManager()
|
||||
|
||||
task = asyncio.create_task(asyncio.sleep(60))
|
||||
manager._download_tasks["download-1"] = task
|
||||
manager._pause_events["download-1"] = download_manager.DownloadStreamControl()
|
||||
manager._active_downloads["download-1"] = {
|
||||
"transfer_backend": "aria2",
|
||||
"status": "downloading",
|
||||
}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def pause_download(self, download_id):
|
||||
self.calls.append(("pause", download_id))
|
||||
return {"success": True, "message": "paused"}
|
||||
|
||||
async def resume_download(self, download_id):
|
||||
self.calls.append(("resume", download_id))
|
||||
return {"success": True, "message": "resumed"}
|
||||
|
||||
async def cancel_download(self, download_id):
|
||||
self.calls.append(("cancel", download_id))
|
||||
return {"success": True, "message": "cancelled"}
|
||||
|
||||
async def has_transfer(self, download_id):
|
||||
self.calls.append(("has_transfer", download_id))
|
||||
return True
|
||||
|
||||
dummy_aria2 = DummyAria2Downloader()
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=dummy_aria2),
|
||||
)
|
||||
|
||||
pause_result = await manager.pause_download("download-1")
|
||||
assert pause_result["success"] is True
|
||||
assert manager._active_downloads["download-1"]["status"] == "paused"
|
||||
|
||||
resume_result = await manager.resume_download("download-1")
|
||||
assert resume_result["success"] is True
|
||||
assert manager._active_downloads["download-1"]["status"] == "downloading"
|
||||
|
||||
cancel_result = await manager.cancel_download("download-1")
|
||||
assert cancel_result["success"] is True
|
||||
assert task.cancelled() or task.done()
|
||||
assert dummy_aria2.calls == [
|
||||
("has_transfer", "download-1"),
|
||||
("pause", "download-1"),
|
||||
("has_transfer", "download-1"),
|
||||
("resume", "download-1"),
|
||||
("cancel", "download-1"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_allows_queued_aria2_task_without_transfer(monkeypatch):
|
||||
manager = DownloadManager()
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def blocked_task():
|
||||
started.set()
|
||||
await asyncio.sleep(60)
|
||||
|
||||
task = asyncio.create_task(blocked_task())
|
||||
await started.wait()
|
||||
|
||||
manager._download_tasks["download-queued"] = task
|
||||
manager._active_downloads["download-queued"] = {
|
||||
"transfer_backend": "aria2",
|
||||
"status": "queued",
|
||||
}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
async def cancel_download(self, download_id):
|
||||
return {"success": False, "error": "Download task not found"}
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=DummyAria2Downloader()),
|
||||
)
|
||||
|
||||
result = await manager.cancel_download("download-queued")
|
||||
|
||||
assert result["success"] is True
|
||||
assert task.cancelled() or task.done()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_resume_queued_aria2_task_without_transfer(monkeypatch):
|
||||
manager = DownloadManager()
|
||||
|
||||
task = asyncio.create_task(asyncio.sleep(60))
|
||||
manager._download_tasks["download-queued"] = task
|
||||
manager._pause_events["download-queued"] = download_manager.DownloadStreamControl()
|
||||
manager._active_downloads["download-queued"] = {
|
||||
"transfer_backend": "aria2",
|
||||
"status": "waiting",
|
||||
"bytes_per_second": 12.0,
|
||||
}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def has_transfer(self, download_id):
|
||||
self.calls.append(("has_transfer", download_id))
|
||||
return False
|
||||
|
||||
async def pause_download(self, download_id):
|
||||
self.calls.append(("pause", download_id))
|
||||
return {"success": True, "message": "paused"}
|
||||
|
||||
async def resume_download(self, download_id):
|
||||
self.calls.append(("resume", download_id))
|
||||
return {"success": True, "message": "resumed"}
|
||||
|
||||
dummy_aria2 = DummyAria2Downloader()
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=dummy_aria2),
|
||||
)
|
||||
|
||||
pause_result = await manager.pause_download("download-queued")
|
||||
assert pause_result == {"success": True, "message": "Download paused successfully"}
|
||||
assert manager._active_downloads["download-queued"]["status"] == "paused"
|
||||
assert manager._pause_events["download-queued"].is_paused() is True
|
||||
|
||||
resume_result = await manager.resume_download("download-queued")
|
||||
assert resume_result == {"success": True, "message": "Download resumed successfully"}
|
||||
assert manager._active_downloads["download-queued"]["status"] == "downloading"
|
||||
assert manager._pause_events["download-queued"].is_set() is True
|
||||
assert dummy_aria2.calls == [
|
||||
("has_transfer", "download-queued"),
|
||||
("has_transfer", "download-queued"),
|
||||
]
|
||||
|
||||
task.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_download_restores_persisted_aria2_task(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
save_path = save_dir / "file.safetensors"
|
||||
save_path.write_text("partial")
|
||||
(save_dir / "file.safetensors.aria2").write_text("control")
|
||||
|
||||
await manager._aria2_state_store.upsert(
|
||||
"download-1",
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"transfer_backend": "aria2",
|
||||
"status": "paused",
|
||||
"save_dir": str(save_dir),
|
||||
"relative_path": "",
|
||||
"use_default_paths": False,
|
||||
"save_path": str(save_path),
|
||||
"file_path": str(save_path),
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
},
|
||||
)
|
||||
|
||||
created = {}
|
||||
|
||||
async def fake_download_with_semaphore(
|
||||
self,
|
||||
task_id,
|
||||
model_id,
|
||||
model_version_id,
|
||||
save_dir,
|
||||
relative_path,
|
||||
progress_callback=None,
|
||||
use_default_paths=False,
|
||||
source=None,
|
||||
file_params=None,
|
||||
):
|
||||
created.update(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"model_id": model_id,
|
||||
"model_version_id": model_version_id,
|
||||
"save_dir": save_dir,
|
||||
}
|
||||
)
|
||||
return {"success": True}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def get_status_by_gid(self, gid):
|
||||
return None
|
||||
|
||||
async def has_transfer(self, download_id):
|
||||
self.calls.append(("has_transfer", download_id))
|
||||
return False
|
||||
|
||||
async def resume_download(self, download_id):
|
||||
self.calls.append(("resume", download_id))
|
||||
return {"success": True, "message": "resumed"}
|
||||
|
||||
async def restore_transfer(self, download_id, gid, save_path):
|
||||
self.calls.append(("restore_transfer", download_id, gid, save_path))
|
||||
|
||||
dummy_aria2 = DummyAria2Downloader()
|
||||
monkeypatch.setattr(
|
||||
download_manager, "_download_with_semaphore", None, raising=False
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
DownloadManager,
|
||||
"_download_with_semaphore",
|
||||
fake_download_with_semaphore,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=dummy_aria2),
|
||||
)
|
||||
|
||||
result = await manager.resume_download("download-1")
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert result == {"success": True, "message": "Download resumed successfully"}
|
||||
assert created["task_id"] == "download-1"
|
||||
assert created["model_version_id"] == 34
|
||||
assert manager._active_downloads["download-1"]["status"] == "downloading"
|
||||
assert manager._pause_events["download-1"].is_set() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_downloads_restores_persisted_aria2_entries(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
save_path = save_dir / "file.safetensors"
|
||||
save_path.write_text("partial")
|
||||
(save_dir / "file.safetensors.aria2").write_text("control")
|
||||
|
||||
await manager._aria2_state_store.upsert(
|
||||
"download-1",
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"transfer_backend": "aria2",
|
||||
"status": "paused",
|
||||
"save_path": str(save_path),
|
||||
"file_path": str(save_path),
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
},
|
||||
)
|
||||
|
||||
class DummyAria2Downloader:
|
||||
async def get_status_by_gid(self, gid):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=DummyAria2Downloader()),
|
||||
)
|
||||
|
||||
downloads = await manager.get_active_downloads()
|
||||
|
||||
assert downloads["downloads"] == [
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
"progress": 0,
|
||||
"status": "paused",
|
||||
"error": None,
|
||||
"bytes_downloaded": 0,
|
||||
"total_bytes": None,
|
||||
"bytes_per_second": 0.0,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_downloads_restores_orphaned_aria2_partial_as_paused(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
save_path = save_dir / "file.safetensors"
|
||||
save_path.write_text("partial")
|
||||
(save_dir / "file.safetensors.aria2").write_text("control")
|
||||
|
||||
await manager._aria2_state_store.upsert(
|
||||
"download-1",
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"transfer_backend": "aria2",
|
||||
"status": "downloading",
|
||||
"save_path": str(save_path),
|
||||
"file_path": str(save_path),
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
"gid": "missing-gid",
|
||||
},
|
||||
)
|
||||
|
||||
class DummyAria2Downloader:
|
||||
async def get_status_by_gid(self, gid):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=DummyAria2Downloader()),
|
||||
)
|
||||
|
||||
downloads = await manager.get_active_downloads()
|
||||
persisted = await manager._aria2_state_store.get("download-1")
|
||||
|
||||
assert downloads["downloads"] == [
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
"progress": 0,
|
||||
"status": "paused",
|
||||
"error": None,
|
||||
"bytes_downloaded": 0,
|
||||
"total_bytes": None,
|
||||
"bytes_per_second": 0.0,
|
||||
}
|
||||
]
|
||||
assert manager._pause_events["download-1"].is_paused() is True
|
||||
assert persisted["status"] == "paused"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_downloads_restarts_from_resume_context_for_active_restored_aria2(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
save_path = save_dir / "file.safetensors"
|
||||
save_path.write_text("partial")
|
||||
|
||||
await manager._aria2_state_store.upsert(
|
||||
"download-1",
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"transfer_backend": "aria2",
|
||||
"status": "downloading",
|
||||
"save_path": str(save_path),
|
||||
"file_path": str(save_path),
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
"gid": "gid-1",
|
||||
"resume_context": {
|
||||
"version_info": {
|
||||
"id": 34,
|
||||
"modelId": 12,
|
||||
"model": {"id": 12, "type": "LoRA", "tags": ["fantasy"]},
|
||||
"images": [],
|
||||
},
|
||||
"file_info": {
|
||||
"name": "file.safetensors",
|
||||
"type": "Model",
|
||||
"primary": True,
|
||||
"downloadUrl": "https://example.com/file.safetensors",
|
||||
},
|
||||
"model_type": "lora",
|
||||
"relative_path": "",
|
||||
"save_dir": str(save_dir),
|
||||
"download_urls": ["https://example.com/file.safetensors"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
restarted = {}
|
||||
|
||||
class DummyAria2Downloader:
|
||||
async def get_status_by_gid(self, gid):
|
||||
return {"gid": gid, "status": "active"}
|
||||
|
||||
async def restore_transfer(self, download_id, gid, restored_path):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=DummyAria2Downloader()),
|
||||
)
|
||||
|
||||
async def fake_resume_restored_aria2_download(self, download_id, record):
|
||||
restarted.update(
|
||||
{
|
||||
"download_id": download_id,
|
||||
"model_id": record.get("model_id"),
|
||||
"model_version_id": record.get("model_version_id"),
|
||||
"save_dir": record.get("save_dir"),
|
||||
"resume_context": record.get("resume_context"),
|
||||
}
|
||||
)
|
||||
return {"success": True}
|
||||
|
||||
monkeypatch.setattr(
|
||||
DownloadManager,
|
||||
"_resume_restored_aria2_download",
|
||||
fake_resume_restored_aria2_download,
|
||||
)
|
||||
execute_original = AsyncMock(side_effect=AssertionError("should not refetch metadata"))
|
||||
monkeypatch.setattr(
|
||||
DownloadManager,
|
||||
"_execute_original_download",
|
||||
execute_original,
|
||||
)
|
||||
|
||||
downloads = await manager.get_active_downloads()
|
||||
assert downloads["downloads"][0]["status"] == "downloading"
|
||||
restarted_task = manager._download_tasks["download-1"]
|
||||
await restarted_task
|
||||
|
||||
assert restarted["download_id"] == "download-1"
|
||||
assert restarted["model_id"] == 12
|
||||
assert restarted["model_version_id"] == 34
|
||||
assert restarted["save_dir"] is None
|
||||
assert restarted["resume_context"]["model_type"] == "lora"
|
||||
assert execute_original.await_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_downloads_restores_persisted_aria2_without_initial_save_path(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
save_path = save_dir / "file.safetensors"
|
||||
save_path.write_text("partial")
|
||||
(save_dir / "file.safetensors.aria2").write_text("control")
|
||||
|
||||
await manager._aria2_state_store.upsert(
|
||||
"download-1",
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"transfer_backend": "aria2",
|
||||
"status": "paused",
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
"resume_context": {
|
||||
"version_info": {
|
||||
"id": 34,
|
||||
"modelId": 12,
|
||||
"model": {"id": 12, "type": "LoRA"},
|
||||
"images": [],
|
||||
},
|
||||
"file_info": {
|
||||
"name": "file.safetensors",
|
||||
"type": "Model",
|
||||
"primary": True,
|
||||
"downloadUrl": "https://example.com/file.safetensors",
|
||||
},
|
||||
"model_type": "lora",
|
||||
"relative_path": "",
|
||||
"save_dir": str(save_dir),
|
||||
"download_urls": ["https://example.com/file.safetensors"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
class DummyAria2Downloader:
|
||||
async def get_status_by_gid(self, gid):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_manager,
|
||||
"get_aria2_downloader",
|
||||
AsyncMock(return_value=DummyAria2Downloader()),
|
||||
)
|
||||
|
||||
downloads = await manager.get_active_downloads()
|
||||
persisted = await manager._aria2_state_store.get("download-1")
|
||||
|
||||
assert downloads["downloads"] == [
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
"progress": 0,
|
||||
"status": "paused",
|
||||
"error": None,
|
||||
"bytes_downloaded": 0,
|
||||
"total_bytes": None,
|
||||
"bytes_per_second": 0.0,
|
||||
}
|
||||
]
|
||||
assert manager._active_downloads["download-1"]["file_path"] == str(save_path)
|
||||
assert persisted["save_path"] == str(save_path)
|
||||
assert persisted["file_path"] == str(save_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_restored_aria2_download_updates_terminal_status_and_cleanup(monkeypatch):
|
||||
manager = DownloadManager()
|
||||
manager._active_downloads["download-1"] = {
|
||||
"transfer_backend": "aria2",
|
||||
"status": "paused",
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
"bytes_per_second": 10.0,
|
||||
}
|
||||
|
||||
persist_state = AsyncMock()
|
||||
cleanup_record = AsyncMock(return_value=None)
|
||||
execute_download = AsyncMock(return_value={"success": True})
|
||||
record_history = AsyncMock(return_value=None)
|
||||
sync_version = AsyncMock(return_value=None)
|
||||
|
||||
monkeypatch.setattr(manager, "_persist_aria2_state", persist_state)
|
||||
monkeypatch.setattr(manager, "_cleanup_download_record", cleanup_record)
|
||||
monkeypatch.setattr(manager, "_execute_download", execute_download)
|
||||
monkeypatch.setattr(manager, "_record_downloaded_version_history", record_history)
|
||||
monkeypatch.setattr(manager, "_sync_downloaded_version", sync_version)
|
||||
|
||||
scheduled_tasks = []
|
||||
original_create_task = asyncio.create_task
|
||||
|
||||
def tracking_create_task(coro):
|
||||
task = original_create_task(coro)
|
||||
scheduled_tasks.append(task)
|
||||
return task
|
||||
|
||||
monkeypatch.setattr(download_manager.asyncio, "create_task", tracking_create_task)
|
||||
|
||||
result = await manager._resume_restored_aria2_download(
|
||||
"download-1",
|
||||
{
|
||||
"download_id": "download-1",
|
||||
"save_path": "/tmp/file.safetensors",
|
||||
"file_path": "/tmp/file.safetensors",
|
||||
"model_id": 12,
|
||||
"model_version_id": 34,
|
||||
"resume_context": {
|
||||
"version_info": {
|
||||
"id": 34,
|
||||
"modelId": 12,
|
||||
"model": {"id": 12},
|
||||
"images": [],
|
||||
},
|
||||
"file_info": {
|
||||
"name": "file.safetensors",
|
||||
"downloadUrl": "https://example.com/file.safetensors",
|
||||
},
|
||||
"model_type": "lora",
|
||||
"relative_path": "",
|
||||
"save_dir": "/tmp",
|
||||
"download_urls": ["https://example.com/file.safetensors"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert result == {"success": True}
|
||||
assert manager._active_downloads["download-1"]["status"] == "completed"
|
||||
assert manager._active_downloads["download-1"]["bytes_per_second"] == 0.0
|
||||
assert persist_state.await_count == 2
|
||||
assert len(scheduled_tasks) == 1
|
||||
await asyncio.gather(*scheduled_tasks)
|
||||
cleanup_record.assert_awaited_once_with("download-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_uses_captured_backend_when_settings_change(
|
||||
monkeypatch, scanners, metadata_provider, tmp_path
|
||||
):
|
||||
manager = DownloadManager()
|
||||
settings = get_settings_manager()
|
||||
settings.settings["download_backend"] = "aria2"
|
||||
|
||||
semaphore = asyncio.Semaphore(0)
|
||||
manager._download_semaphore = semaphore
|
||||
|
||||
captured = {}
|
||||
|
||||
async def fake_execute_original_download(
|
||||
self,
|
||||
model_id,
|
||||
model_version_id,
|
||||
save_dir,
|
||||
relative_path,
|
||||
progress_callback,
|
||||
use_default_paths,
|
||||
download_id=None,
|
||||
transfer_backend="python",
|
||||
source=None,
|
||||
file_params=None,
|
||||
):
|
||||
captured["transfer_backend"] = transfer_backend
|
||||
return {"success": True}
|
||||
|
||||
monkeypatch.setattr(
|
||||
DownloadManager,
|
||||
"_execute_original_download",
|
||||
fake_execute_original_download,
|
||||
)
|
||||
|
||||
download_task = asyncio.create_task(
|
||||
manager.download_from_civitai(
|
||||
model_version_id=99,
|
||||
save_dir=str(tmp_path),
|
||||
use_default_paths=True,
|
||||
progress_callback=None,
|
||||
source=None,
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.sleep(0)
|
||||
assert len(manager._active_downloads) == 1
|
||||
download_id = next(iter(manager._active_downloads))
|
||||
assert manager._active_downloads[download_id]["transfer_backend"] == "aria2"
|
||||
|
||||
settings.settings["download_backend"] = "python"
|
||||
semaphore.release()
|
||||
|
||||
result = await download_task
|
||||
|
||||
assert result["success"] is True
|
||||
assert captured["transfer_backend"] == "aria2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_aborts_when_version_exists(
|
||||
monkeypatch, scanners, metadata_provider
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -30,10 +30,21 @@ class FakeStream:
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, status, headers, chunks):
|
||||
def __init__(
|
||||
self,
|
||||
status,
|
||||
headers,
|
||||
chunks,
|
||||
*,
|
||||
url="https://example.com/file",
|
||||
history=None,
|
||||
):
|
||||
self.status = status
|
||||
self.headers = headers
|
||||
self.content = FakeStream(chunks)
|
||||
self.url = url
|
||||
self.history = history or []
|
||||
self.released = False
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
@@ -41,14 +52,25 @@ class FakeResponse:
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def release(self):
|
||||
self.released = True
|
||||
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, responses):
|
||||
self._responses = list(responses)
|
||||
self._get_calls = 0
|
||||
self.requests = []
|
||||
|
||||
def get(self, url, headers=None, allow_redirects=True, proxy=None): # noqa: D401 - signature mirrors aiohttp
|
||||
del url, headers, allow_redirects, proxy
|
||||
self.requests.append(
|
||||
{
|
||||
"url": url,
|
||||
"headers": headers or {},
|
||||
"allow_redirects": allow_redirects,
|
||||
"proxy": proxy,
|
||||
}
|
||||
)
|
||||
response_factory = self._responses[self._get_calls]
|
||||
self._get_calls += 1
|
||||
return response_factory()
|
||||
@@ -75,7 +97,7 @@ def _build_downloader(responses, *, max_retries=0):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_fails_when_size_mismatch(tmp_path):
|
||||
async def test_download_file_preserves_incomplete_part_when_size_mismatch(tmp_path):
|
||||
target_path = tmp_path / "model" / "file.bin"
|
||||
target_path.parent.mkdir()
|
||||
|
||||
@@ -94,7 +116,7 @@ async def test_download_file_fails_when_size_mismatch(tmp_path):
|
||||
assert success is False
|
||||
assert "mismatch" in message.lower()
|
||||
assert not target_path.exists()
|
||||
assert not Path(str(target_path) + ".part").exists()
|
||||
assert Path(str(target_path) + ".part").read_bytes() == b"abc"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -136,7 +158,9 @@ async def test_download_file_succeeds_when_sizes_match(tmp_path):
|
||||
|
||||
downloader = _build_downloader(responses)
|
||||
|
||||
success, result_path = await downloader.download_file("https://example.com/file", str(target_path))
|
||||
success, result_path = await downloader.download_file(
|
||||
"https://example.com/file", str(target_path)
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert Path(result_path).read_bytes() == payload
|
||||
@@ -166,9 +190,77 @@ async def test_download_file_recovers_from_stall(tmp_path):
|
||||
downloader = _build_downloader(responses, max_retries=1)
|
||||
downloader.stall_timeout = 0.05
|
||||
|
||||
success, result_path = await downloader.download_file("https://example.com/file", str(target_path))
|
||||
success, result_path = await downloader.download_file(
|
||||
"https://example.com/file", str(target_path)
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert Path(result_path).read_bytes() == payload
|
||||
assert downloader._session._get_calls == 2
|
||||
assert not Path(str(target_path) + ".part").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_resumes_after_incomplete_integrity_check(tmp_path):
|
||||
target_path = tmp_path / "model" / "file.bin"
|
||||
target_path.parent.mkdir()
|
||||
|
||||
responses = [
|
||||
lambda: FakeResponse(
|
||||
status=200,
|
||||
headers={"content-length": "6"},
|
||||
chunks=[b"abc"],
|
||||
),
|
||||
lambda: FakeResponse(
|
||||
status=206,
|
||||
headers={"content-length": "3", "Content-Range": "bytes 3-5/6"},
|
||||
chunks=[b"def"],
|
||||
),
|
||||
]
|
||||
|
||||
downloader = _build_downloader(responses, max_retries=1)
|
||||
|
||||
success, result_path = await downloader.download_file("https://example.com/file", str(target_path))
|
||||
|
||||
assert success is True
|
||||
assert Path(result_path).read_bytes() == b"abcdef"
|
||||
assert downloader._session._get_calls == 2
|
||||
assert downloader._session.requests[1]["headers"]["Range"] == "bytes=3-"
|
||||
assert not Path(str(target_path) + ".part").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_retries_redirected_url_when_range_not_honored(tmp_path):
|
||||
target_path = tmp_path / "model" / "file.bin"
|
||||
target_path.parent.mkdir()
|
||||
Path(str(target_path) + ".part").write_bytes(b"abc")
|
||||
|
||||
redirected_url = "https://download.example.com/file.bin"
|
||||
first_response = FakeResponse(
|
||||
status=200,
|
||||
headers={"content-length": "6"},
|
||||
chunks=[],
|
||||
url=redirected_url,
|
||||
history=[object()],
|
||||
)
|
||||
|
||||
responses = [
|
||||
lambda: first_response,
|
||||
lambda: FakeResponse(
|
||||
status=206,
|
||||
headers={"content-length": "3", "Content-Range": "bytes 3-5/6"},
|
||||
chunks=[b"def"],
|
||||
url=redirected_url,
|
||||
),
|
||||
]
|
||||
|
||||
downloader = _build_downloader(responses, max_retries=0)
|
||||
|
||||
success, result_path = await downloader.download_file("https://example.com/file", str(target_path))
|
||||
|
||||
assert success is True
|
||||
assert Path(result_path).read_bytes() == b"abcdef"
|
||||
assert first_response.released is True
|
||||
assert downloader._session.requests[0]["headers"]["Range"] == "bytes=3-"
|
||||
assert downloader._session.requests[1]["url"] == redirected_url
|
||||
assert downloader._session.requests[1]["headers"]["Range"] == "bytes=3-"
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.connectivity_guard import OFFLINE_COOLDOWN_ERROR, OFFLINE_FRIENDLY_MESSAGE
|
||||
from py.services.errors import RateLimitError
|
||||
from py.services.metadata_sync_service import MetadataSyncService
|
||||
|
||||
@@ -243,17 +244,32 @@ async def test_fetch_and_update_model_handles_missing_remote_metadata(tmp_path):
|
||||
|
||||
assert not ok
|
||||
assert "Model not found" in error
|
||||
assert model_data["from_civitai"] is False
|
||||
assert model_data["civitai_deleted"] is True
|
||||
|
||||
helpers.metadata_manager.hydrate_model_data.assert_not_awaited()
|
||||
assert model_data["hydrated"] is True
|
||||
|
||||
helpers.metadata_manager.save_metadata.assert_awaited_once()
|
||||
call_args = helpers.metadata_manager.save_metadata.await_args
|
||||
assert call_args.args[0].endswith("model.safetensors")
|
||||
assert "folder" not in call_args.args[1]
|
||||
assert call_args.args[1]["hydrated"] is True
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_update_model_returns_friendly_offline_message(tmp_path):
|
||||
helpers = build_service()
|
||||
helpers.default_provider.get_model_by_hash.return_value = (None, OFFLINE_COOLDOWN_ERROR)
|
||||
|
||||
model_path = tmp_path / "model.safetensors"
|
||||
model_data = {
|
||||
"model_name": "Local",
|
||||
"folder": "root",
|
||||
"file_path": str(model_path),
|
||||
}
|
||||
update_cache = AsyncMock(return_value=True)
|
||||
|
||||
ok, error = await helpers.service.fetch_and_update_model(
|
||||
sha256="abc",
|
||||
file_path=str(model_path),
|
||||
model_data=model_data,
|
||||
update_cache_func=update_cache,
|
||||
)
|
||||
|
||||
assert ok is False
|
||||
assert error is not None
|
||||
assert OFFLINE_FRIENDLY_MESSAGE in error
|
||||
update_cache.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
52
tests/services/test_model_scanner_base_models.py
Normal file
52
tests/services/test_model_scanner_base_models.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.model_scanner import ModelScanner
|
||||
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self, raw_data):
|
||||
self._cache = SimpleNamespace(raw_data=raw_data)
|
||||
|
||||
async def get_cached_data(self):
|
||||
return self._cache
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_base_models_limit_zero_returns_all_sorted():
|
||||
scanner = DummyScanner(
|
||||
[
|
||||
{"base_model": "SDXL"},
|
||||
{"base_model": "LTXV 2.3"},
|
||||
{"base_model": "SDXL"},
|
||||
{"base_model": ""},
|
||||
{},
|
||||
]
|
||||
)
|
||||
|
||||
result = await ModelScanner.get_base_models(scanner, limit=0)
|
||||
|
||||
assert result == [
|
||||
{"name": "SDXL", "count": 2},
|
||||
{"name": "LTXV 2.3", "count": 1},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_base_models_positive_limit_still_truncates():
|
||||
scanner = DummyScanner(
|
||||
[
|
||||
{"base_model": "SDXL"},
|
||||
{"base_model": "LTXV 2.3"},
|
||||
{"base_model": "Flux.1 D"},
|
||||
{"base_model": "SDXL"},
|
||||
]
|
||||
)
|
||||
|
||||
result = await ModelScanner.get_base_models(scanner, limit=2)
|
||||
|
||||
assert result == [
|
||||
{"name": "SDXL", "count": 2},
|
||||
{"name": "LTXV 2.3", "count": 1},
|
||||
]
|
||||
@@ -147,6 +147,11 @@ def test_environment_variable_overrides_settings(tmp_path, monkeypatch):
|
||||
assert mgr.get("civitai_api_key") == "secret"
|
||||
|
||||
|
||||
def test_default_download_backend_is_python(manager):
|
||||
assert manager.get("download_backend") == "python"
|
||||
assert manager.get("aria2c_path") == ""
|
||||
|
||||
|
||||
def _create_manager_with_settings(
|
||||
tmp_path, monkeypatch, initial_settings, *, save_spy=None
|
||||
):
|
||||
@@ -327,6 +332,43 @@ def test_auto_set_default_roots_keeps_valid_values(manager):
|
||||
assert manager.get("default_embedding_root") == "/embeddings"
|
||||
|
||||
|
||||
def test_auto_set_default_roots_keeps_valid_extra_values(manager):
|
||||
manager.settings["default_lora_root"] = "/extra-loras"
|
||||
manager.settings["default_checkpoint_root"] = "/extra-checkpoints"
|
||||
manager.settings["default_embedding_root"] = "/extra-embeddings"
|
||||
manager.settings["default_unet_root"] = "/extra-unet"
|
||||
|
||||
manager.settings["folder_paths"] = {
|
||||
"loras": ["/loras"],
|
||||
"checkpoints": ["/checkpoints"],
|
||||
"unet": ["/unet"],
|
||||
"embeddings": ["/embeddings"],
|
||||
}
|
||||
manager.settings["extra_folder_paths"] = {
|
||||
"loras": ["/extra-loras"],
|
||||
"checkpoints": ["/extra-checkpoints"],
|
||||
"unet": ["/extra-unet"],
|
||||
"embeddings": ["/extra-embeddings"],
|
||||
}
|
||||
|
||||
manager._auto_set_default_roots()
|
||||
|
||||
assert manager.get("default_lora_root") == "/extra-loras"
|
||||
assert manager.get("default_checkpoint_root") == "/extra-checkpoints"
|
||||
assert manager.get("default_unet_root") == "/extra-unet"
|
||||
assert manager.get("default_embedding_root") == "/extra-embeddings"
|
||||
|
||||
|
||||
def test_auto_set_default_roots_falls_back_to_extra_when_primary_missing(manager):
|
||||
manager.settings["default_lora_root"] = ""
|
||||
manager.settings["folder_paths"] = {"loras": []}
|
||||
manager.settings["extra_folder_paths"] = {"loras": ["/extra-loras"]}
|
||||
|
||||
manager._auto_set_default_roots()
|
||||
|
||||
assert manager.get("default_lora_root") == "/extra-loras"
|
||||
|
||||
|
||||
def test_delete_setting(manager):
|
||||
manager.set("example", 1)
|
||||
manager.delete("example")
|
||||
|
||||
Reference in New Issue
Block a user