mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-07-03 07:51:16 -03:00
Compare commits
469 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
63785f82b5 | ||
|
|
cf898da193 | ||
|
|
fe90f7f9b1 | ||
|
|
8b344ea39f | ||
|
|
8348a0cef8 | ||
|
|
7cf785b72f | ||
|
|
e8913f4481 | ||
|
|
f9c3d8dc97 | ||
|
|
09ca91fc0e | ||
|
|
16f5222efd | ||
|
|
28e7c04b37 | ||
|
|
28f99c46d3 | ||
|
|
205194f4e6 | ||
|
|
402d8b07cf | ||
|
|
3e303ab316 | ||
|
|
e9e8c31ad1 | ||
|
|
703a6a4ea0 | ||
|
|
283730cf38 | ||
|
|
20417797e8 | ||
|
|
004c69b9ef | ||
|
|
47fe2d3783 | ||
|
|
36ef840a22 | ||
|
|
09c2445ac9 | ||
|
|
8a6d23f9c7 | ||
|
|
3d207b6744 | ||
|
|
b3edda62ad | ||
|
|
a429e6b1c3 | ||
|
|
c1bf9c6221 | ||
|
|
75fffc1e25 | ||
|
|
f264bab65c | ||
|
|
154fcd803b | ||
|
|
4ef32d3a96 | ||
|
|
d2d109a69c | ||
|
|
3a2941d751 | ||
|
|
0ac10dfd42 | ||
|
|
9c95856b2f | ||
|
|
5ce4667d32 | ||
|
|
be53fda6df | ||
|
|
f48de05102 | ||
|
|
93ad81ed87 | ||
|
|
ea14d211be | ||
|
|
8052cefd46 | ||
|
|
845815b9b7 | ||
|
|
609dc5d783 | ||
|
|
7a71b34b54 | ||
|
|
71a459422f | ||
|
|
cd2628a0ee | ||
|
|
85da7175bc | ||
|
|
d3bf0a164b | ||
|
|
afb6ca1b8d | ||
|
|
94f43426d7 | ||
|
|
2b361f4f5d | ||
|
|
7438072f8c | ||
|
|
26c54fd358 | ||
|
|
7cb6b04c63 | ||
|
|
fc29cde82a | ||
|
|
559ca946dc | ||
|
|
2b8e7c7504 | ||
|
|
6816d75933 | ||
|
|
b58abbad7c | ||
|
|
999814ca87 | ||
|
|
3c2760a803 | ||
|
|
0edbd7bcca | ||
|
|
21e89fa7de | ||
|
|
968d6d1d1f | ||
|
|
cf0fd0e0ad | ||
|
|
16e5dcf7b2 | ||
|
|
ab6bb25d46 | ||
|
|
07f49559be | ||
|
|
b24b1a7e57 | ||
|
|
faf64f8986 | ||
|
|
a617487a43 | ||
|
|
3012a7aef3 | ||
|
|
499e19de34 | ||
|
|
9161762ca9 | ||
|
|
9bbd26efe6 | ||
|
|
258b2622d5 | ||
|
|
80ec9085dd | ||
|
|
c5c7373e10 | ||
|
|
b7721866e5 | ||
|
|
8314b9bedb | ||
|
|
75298a402f | ||
|
|
92b5efd414 | ||
|
|
33ee392b7b | ||
|
|
5237f8b7dc | ||
|
|
5107313fd1 | ||
|
|
95bbc66919 | ||
|
|
e268e59419 | ||
|
|
547e1f9498 | ||
|
|
bf32d8b6fd | ||
|
|
8299881024 | ||
|
|
da02268196 | ||
|
|
8c4b9a1e70 | ||
|
|
0906c484e9 | ||
|
|
4199c30fec | ||
|
|
4a8084cdbc | ||
|
|
6263e6848c | ||
|
|
58c266ad07 | ||
|
|
2939813e1a | ||
|
|
a9e5ee7e79 | ||
|
|
a17b0e9901 | ||
|
|
8f23d966bf | ||
|
|
7a76fc72d0 | ||
|
|
518a4dd5ee | ||
|
|
2b6d4e5d8b | ||
|
|
1f4edbeb9d | ||
|
|
a256558a0e | ||
|
|
818b9113f0 | ||
|
|
6a4fd020dc | ||
|
|
7a23040452 | ||
|
|
138024aefe | ||
|
|
a19ddc14f6 | ||
|
|
7001ced694 | ||
|
|
a5c861646c | ||
|
|
3e0bb73793 | ||
|
|
ac51f6a2f6 | ||
|
|
bef222c77d | ||
|
|
7cd6a53447 | ||
|
|
6850b35770 | ||
|
|
237a015cde | ||
|
|
1ae2778baa | ||
|
|
84fcdb5f20 | ||
|
|
8a0b368b44 | ||
|
|
3990535505 | ||
|
|
3e961a9860 | ||
|
|
d6669f1d04 | ||
|
|
519bafebc8 | ||
|
|
d87863b423 | ||
|
|
84e9fe2dfb | ||
|
|
46cbcf94c8 | ||
|
|
05f3018495 | ||
|
|
f565cc35ca | ||
|
|
dd1cdce16d | ||
|
|
a9e0e7dc8d | ||
|
|
b302d1db7d | ||
|
|
7cbddd9cf7 | ||
|
|
cb8c699224 | ||
|
|
451f74b874 | ||
|
|
a1d248baa6 | ||
|
|
18577fa336 | ||
|
|
5797ce9408 | ||
|
|
826f06255a | ||
|
|
84e16b5c5b | ||
|
|
eb22054580 | ||
|
|
08afb05ece | ||
|
|
f51f125cf1 | ||
|
|
24b2078f21 | ||
|
|
130fb5d2d5 | ||
|
|
23c6863a3a | ||
|
|
c0e2578640 | ||
|
|
e3c812367e | ||
|
|
4d239008a6 | ||
|
|
00177a06d0 | ||
|
|
568daa351e | ||
|
|
5a4664fa12 | ||
|
|
dd5b213adc | ||
|
|
d9ee9b3155 | ||
|
|
01dac57c35 | ||
|
|
7f92d09239 | ||
|
|
62f9e3f44a | ||
|
|
e55895786d | ||
|
|
82b77bf593 | ||
|
|
1beef5dea9 | ||
|
|
c8beaa64e1 | ||
|
|
fb443ed6ae | ||
|
|
151a467598 | ||
|
|
98e1d168b0 | ||
|
|
716f18e0ed | ||
|
|
b060dc99fc | ||
|
|
54bcdfab38 | ||
|
|
2e7532eecc | ||
|
|
7e5e3b1ec7 | ||
|
|
df67bd396a | ||
|
|
dd5d9cfcb2 | ||
|
|
d9fd60bec1 | ||
|
|
b633b22779 | ||
|
|
1ffa543160 | ||
|
|
cdc940586e | ||
|
|
ccf1c6f2ae | ||
|
|
bfe7b5e1c7 | ||
|
|
85c020cd12 | ||
|
|
1b202f8ec7 | ||
|
|
d02a0611d3 | ||
|
|
92166a161a | ||
|
|
b509f27cb7 | ||
|
|
5c2ef48917 | ||
|
|
ad2bd82c67 | ||
|
|
17ba350153 | ||
|
|
60175334b5 | ||
|
|
f65a01df00 | ||
|
|
430e24d70b | ||
|
|
14f0c48fdd | ||
|
|
34791c2ad7 | ||
|
|
3f6824eef6 | ||
|
|
3919dfa3f4 | ||
|
|
7124b5293f | ||
|
|
d2a04f8993 | ||
|
|
7027a7c270 | ||
|
|
0a1d7dfd4c | ||
|
|
3962b1a96d | ||
|
|
8b856276bf | ||
|
|
c97c802956 | ||
|
|
24e2909627 | ||
|
|
b768f1368f | ||
|
|
37ccd29fc0 | ||
|
|
7416080cfb | ||
|
|
26be187d42 | ||
|
|
d7caa1fa47 | ||
|
|
2629fcce23 | ||
|
|
438e7d07b9 | ||
|
|
e9932ea870 | ||
|
|
5dd8b96422 | ||
|
|
5e1cf68bbd | ||
|
|
1044fa3c83 | ||
|
|
397892bb7f | ||
|
|
f105500740 | ||
|
|
806555cf06 | ||
|
|
5cd7204101 | ||
|
|
3b602a3698 | ||
|
|
15dfaed462 | ||
|
|
0e51851025 | ||
|
|
0d0f4defca | ||
|
|
818fa34a48 | ||
|
|
78303b2a5e | ||
|
|
9ce56dd40c | ||
|
|
4e3ede23b7 | ||
|
|
33e5f3d85d | ||
|
|
031d5e4f40 | ||
|
|
4ff5774e34 | ||
|
|
94e1a8ac7b | ||
|
|
cc20d3b992 | ||
|
|
a74cbe7aa2 | ||
|
|
94edfaa190 | ||
|
|
31c54ff068 | ||
|
|
21872a8e9e | ||
|
|
612612f1c7 | ||
|
|
ff240db5b1 | ||
|
|
bcfed4b874 | ||
|
|
1352c6ecbe | ||
|
|
30b01b8a92 | ||
|
|
a105cb322b | ||
|
|
3bf396d003 | ||
|
|
60cfb3b8e0 | ||
|
|
6763abb83c | ||
|
|
5c53968caa | ||
|
|
b4f7dd75af | ||
|
|
86118d0654 | ||
|
|
df1410535e | ||
|
|
75f74d54d8 | ||
|
|
ab6100f596 | ||
|
|
5d3ab3bbf8 | ||
|
|
d9dc0dba8d | ||
|
|
3631c5eb10 | ||
|
|
6d5b4b7312 | ||
|
|
7803bd542d | ||
|
|
f0a86dbbc0 | ||
|
|
682e964f89 | ||
|
|
908464bc0a | ||
|
|
0ffee3a854 | ||
|
|
8aa9739c44 | ||
|
|
50739bbb43 | ||
|
|
e849303763 | ||
|
|
241b2e15d2 | ||
|
|
88da754504 | ||
|
|
b4a706651f | ||
|
|
ff7cc6d9bb | ||
|
|
454210a47c | ||
|
|
2d7c404ebb | ||
|
|
e23d803ecf | ||
|
|
0cc640cfaa | ||
|
|
2ac0eb0f9d | ||
|
|
f028625ce9 | ||
|
|
06acc7f576 | ||
|
|
d324b57274 | ||
|
|
502b7eab31 | ||
|
|
be75ad930e | ||
|
|
763c4f4dad | ||
|
|
d32c492bdb | ||
|
|
5dcfde36ea | ||
|
|
1d035361a4 | ||
|
|
25605c5e78 | ||
|
|
f3268a6179 | ||
|
|
055e94d77b | ||
|
|
47fcd530a0 | ||
|
|
3c32b9e088 | ||
|
|
ffe0670a27 | ||
|
|
cc147a1795 | ||
|
|
e81409bea4 | ||
|
|
b31fae4e51 | ||
|
|
c6e5467907 | ||
|
|
df0e5797d0 | ||
|
|
ebdbb36271 | ||
|
|
2eef629821 | ||
|
|
658a04736d | ||
|
|
ef7f677933 | ||
|
|
63f0942452 | ||
|
|
a1dff6dd47 | ||
|
|
7fa40023b0 | ||
|
|
3c8acdb65e | ||
|
|
1e9a7812d6 | ||
|
|
37f0e8f213 | ||
|
|
ecf7ea21e4 | ||
|
|
79dd9a1b29 | ||
|
|
ef4923fd94 | ||
|
|
1eeba666f5 | ||
|
|
89e26d9292 | ||
|
|
fc19a145ff | ||
|
|
34f03d6495 | ||
|
|
9443175abc | ||
|
|
dc5072628f | ||
|
|
ff4b8ec849 | ||
|
|
7ab271c752 | ||
|
|
5a7f4dc88b | ||
|
|
761108bfd1 | ||
|
|
24dd3a777c | ||
|
|
1c530ea013 | ||
|
|
0ced53c059 | ||
|
|
67ad68a23f | ||
|
|
d9ec9c512e | ||
|
|
0bcd8e09a9 | ||
|
|
fa049a28c8 | ||
|
|
89fd2b43d6 | ||
|
|
c53f44e7ef | ||
|
|
ae7bfdb517 | ||
|
|
68bf8442eb | ||
|
|
605fbf4117 | ||
|
|
406d5fea6a | ||
|
|
af2146f96c | ||
|
|
bdc8dec860 | ||
|
|
c4fa1631ee | ||
|
|
506d763dc2 | ||
|
|
a2cd09b619 | ||
|
|
cdd77029b6 | ||
|
|
439679e15f | ||
|
|
2640258902 | ||
|
|
b910388d54 | ||
|
|
083de395b1 | ||
|
|
4514ca94b7 | ||
|
|
62247bdd87 | ||
|
|
6d0d9600a7 | ||
|
|
70cd3f4e1b | ||
|
|
a95c518b30 | ||
|
|
ba1800095e | ||
|
|
39c083db79 | ||
|
|
55e9e4bb6f | ||
|
|
0253d001e6 | ||
|
|
9998da3241 | ||
|
|
6666a72775 | ||
|
|
5f1bd894b9 | ||
|
|
1817142a7b | ||
|
|
25fa175aa2 | ||
|
|
39643eb2bc | ||
|
|
4ac78f8aa8 | ||
|
|
0bcca0ba68 | ||
|
|
72f8e0d1be | ||
|
|
85b6c91192 | ||
|
|
908016cbd6 | ||
|
|
a5ac9cf81b | ||
|
|
32875042bd | ||
|
|
51fe7aa07e | ||
|
|
db4726a961 | ||
|
|
e13d70248a | ||
|
|
1c4919a3e8 | ||
|
|
18ddadc9ec | ||
|
|
b6dd6938b0 | ||
|
|
b711ac468a | ||
|
|
727d0ef043 | ||
|
|
9344d86332 | ||
|
|
d36b16c213 | ||
|
|
33a7f07558 | ||
|
|
4f599aeced | ||
|
|
30db8c3d1d | ||
|
|
05636712f0 | ||
|
|
d8e5fe1247 | ||
|
|
3e9210394a | ||
|
|
4dd2c0526f | ||
|
|
9bdb337962 | ||
|
|
f93baf5fc0 | ||
|
|
14cb7fec47 | ||
|
|
f3b3e0adad | ||
|
|
ba3f15dbc6 | ||
|
|
8dc2a2f76b | ||
|
|
316f17dd46 | ||
|
|
3dc10b1404 | ||
|
|
331889d872 | ||
|
|
06f1a82d4c | ||
|
|
267082c712 | ||
|
|
a4cb51e96c | ||
|
|
ca44c367b3 | ||
|
|
301ab14781 | ||
|
|
2626dbab8e | ||
|
|
12bbb0572d | ||
|
|
00f5c1e887 | ||
|
|
89b1675ec7 | ||
|
|
dcc7bd33b5 | ||
|
|
e5152108ba | ||
|
|
1ed5eef985 | ||
|
|
a82f89d14a | ||
|
|
16e30ea689 | ||
|
|
ad3bdddb72 | ||
|
|
9121306b06 | ||
|
|
ca0baf9462 | ||
|
|
20e50156a2 | ||
|
|
0b66bf5479 | ||
|
|
1e8aca4787 | ||
|
|
76ee59cdb9 | ||
|
|
a5191414cc | ||
|
|
5b065b47d4 | ||
|
|
ceeab0c998 | ||
|
|
3b001a6cd8 | ||
|
|
95e5bc26d1 | ||
|
|
de3d0571f8 | ||
|
|
6f2a01dc86 | ||
|
|
c5c1b8fd2a | ||
|
|
e97648c70b | ||
|
|
8b85e083e2 | ||
|
|
9112cd3b62 | ||
|
|
7df4e8d037 | ||
|
|
4000b7f7e7 | ||
|
|
76c15105e6 | ||
|
|
b11c90e19b | ||
|
|
9f5d2d0c18 | ||
|
|
a0dc5229f4 | ||
|
|
61c31ecbd0 | ||
|
|
1ae1b0d607 | ||
|
|
8dd849892d | ||
|
|
03e1fa75c5 | ||
|
|
fefcaa4a45 | ||
|
|
701a6a6c44 | ||
|
|
0ef414d17e | ||
|
|
75dccaef87 | ||
|
|
7e87ec9521 | ||
|
|
46522edb1b | ||
|
|
2dae4c1291 | ||
|
|
a32325402e | ||
|
|
70c150bd80 | ||
|
|
9e81c33f8a | ||
|
|
22c0dbd734 | ||
|
|
d0c58472be | ||
|
|
b3c530bf36 | ||
|
|
05ebd7493d | ||
|
|
90986bd795 | ||
|
|
b5a0725d2c | ||
|
|
ef38bda04f | ||
|
|
58713ea6e0 | ||
|
|
8b91920058 | ||
|
|
ee466113d5 | ||
|
|
f86651652c | ||
|
|
c89d4dae85 | ||
|
|
55a18d401b | ||
|
|
7570936c75 | ||
|
|
4fcf641d57 | ||
|
|
5c29e26c4e | ||
|
|
ee765a6d22 | ||
|
|
c02f603ed2 | ||
|
|
ee84b30023 | ||
|
|
97979d9e7c | ||
|
|
cda271890a | ||
|
|
2fbe6c8843 | ||
|
|
4fb07370dd | ||
|
|
43f6bfab36 | ||
|
|
a802a89ff9 | ||
|
|
343dd91e4b | ||
|
|
3756f88368 | ||
|
|
acc625ead3 | ||
|
|
f402505f97 | ||
|
|
4d8113464c | ||
|
|
1ed503a6b5 | ||
|
|
d67914e095 |
69
.agents/skills/lora-manager-runtime-context/SKILL.md
Normal file
69
.agents/skills/lora-manager-runtime-context/SKILL.md
Normal file
@@ -0,0 +1,69 @@
|
||||
---
|
||||
name: lora-manager-runtime-context
|
||||
description: Inspect ComfyUI LoRA Manager runtime configuration and local diagnostic state. Use when debugging LoRA Manager issues that require locating or reading settings.json, active library paths, model metadata JSON sidecars, recipe metadata JSON files, example image folders, SQLite caches, symlink maps, download history, aria2 state, or other cache files under the LoRA Manager user config directory.
|
||||
---
|
||||
|
||||
# LoRA Manager Runtime Context
|
||||
|
||||
## Core Rules
|
||||
|
||||
- Treat runtime state as local user data. Prefer read-only inspection unless the user explicitly asks for mutation.
|
||||
- Never print secret-like settings values. Redact keys containing `key`, `token`, `secret`, `password`, `auth`, or `credential`, including `civitai_api_key`.
|
||||
- Resolve paths from the runtime configuration before guessing. In this environment the settings file is normally `/home/miao/.config/ComfyUI-LoRA-Manager/settings.json`, but portable settings can override this through the repository `settings.json`.
|
||||
- Use the active library when selecting per-library caches and paths. Read `active_library` from settings; fall back to `default` if missing.
|
||||
- Normalize and expand `~` before comparing paths. Symlinks are common in this repo.
|
||||
|
||||
## Quick Start
|
||||
|
||||
Use the bundled helper for a safe first pass:
|
||||
|
||||
```bash
|
||||
python .agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py summary
|
||||
python .agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py caches
|
||||
```
|
||||
|
||||
The script redacts sensitive settings, opens SQLite databases read-only, and reports inaccessible or locked databases as warnings.
|
||||
|
||||
For focused checks:
|
||||
|
||||
```bash
|
||||
python .agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py recipes
|
||||
python .agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py model --path /path/to/model.safetensors
|
||||
python .agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py sqlite --db /path/to/cache.sqlite --limit 3
|
||||
```
|
||||
|
||||
## Runtime Path Rules
|
||||
|
||||
- Settings directory: use `py/utils/settings_paths.py`. Default platform path is `platformdirs.user_config_dir("ComfyUI-LoRA-Manager", appauthor=False)`.
|
||||
- Settings file: `<settings_dir>/settings.json`.
|
||||
- Cache root: `<settings_dir>/cache`.
|
||||
- Canonical cache files:
|
||||
- Model cache: `cache/model/<active_library>.sqlite`.
|
||||
- Recipe cache: `cache/recipe/<active_library>.sqlite`.
|
||||
- Model update cache: `cache/model_update/<active_library>.sqlite`.
|
||||
- Recipe FTS: `cache/fts/recipe_fts.sqlite`.
|
||||
- Tag FTS: `cache/fts/tag_fts.sqlite`.
|
||||
- Symlink map: `cache/symlink/symlink_map.json`.
|
||||
- Download history: `cache/download_history/downloaded_versions.sqlite`.
|
||||
- aria2 state: `cache/aria2/downloads.json`.
|
||||
- Legacy cache locations may exist; prefer canonical paths unless diagnosing migrations.
|
||||
|
||||
## Data Location Rules
|
||||
|
||||
- Model roots come from `settings.folder_paths` and the active library payload under `settings.libraries[active_library]`.
|
||||
- Model metadata JSON sidecars live next to the model file as `<model basename>.metadata.json`.
|
||||
- Recipes root is `settings.recipes_path` when it is a non-empty string. If empty, use the first configured LoRA root plus `/recipes`.
|
||||
- Recipe JSON files are named `*.recipe.json` under the recipes root and may be nested in folders.
|
||||
- Example image root is `settings.example_images_path`.
|
||||
- If multiple libraries are configured, example images are stored under `<example_images_path>/<sanitized_library>/<sha256>/`; otherwise they are under `<example_images_path>/<sha256>/`.
|
||||
|
||||
## Useful Cache Tables
|
||||
|
||||
- Model cache: `models`, `model_tags`, `hash_index`, `excluded_models`.
|
||||
- Recipe cache: `recipes`, `cache_metadata`.
|
||||
- Model update cache: `model_update_status`, `model_update_versions`.
|
||||
- Tag FTS cache: `tags`, `fts_metadata`, plus FTS internal tables.
|
||||
- Recipe FTS cache: `recipe_rowid`, `fts_metadata`, plus FTS internal tables.
|
||||
- Download history: `downloaded_model_versions`.
|
||||
|
||||
Prefer querying only counts, schema, and a few sample rows unless the user asks for full output.
|
||||
@@ -0,0 +1,4 @@
|
||||
interface:
|
||||
display_name: "LoRA Manager Runtime Context"
|
||||
short_description: "Inspect LoRA Manager runtime state"
|
||||
default_prompt: "Use $lora-manager-runtime-context to inspect LoRA Manager settings, metadata paths, and caches for debugging."
|
||||
381
.agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py
Executable file
381
.agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py
Executable file
@@ -0,0 +1,381 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sqlite3
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
SECRET_PATTERN = re.compile(r"(key|token|secret|password|auth|credential)", re.IGNORECASE)
|
||||
APP_NAME = "ComfyUI-LoRA-Manager"
|
||||
CACHE_SQLITE = {
|
||||
"model": ("model", "{library}.sqlite"),
|
||||
"recipe": ("recipe", "{library}.sqlite"),
|
||||
"model_update": ("model_update", "{library}.sqlite"),
|
||||
"recipe_fts": ("fts", "recipe_fts.sqlite"),
|
||||
"tag_fts": ("fts", "tag_fts.sqlite"),
|
||||
"download_history": ("download_history", "downloaded_versions.sqlite"),
|
||||
}
|
||||
CACHE_JSON = {
|
||||
"symlink": ("symlink", "symlink_map.json"),
|
||||
"aria2": ("aria2", "downloads.json"),
|
||||
}
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="Inspect LoRA Manager runtime state read-only.")
|
||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
subparsers.add_parser("summary", help="Print redacted settings and resolved paths.")
|
||||
subparsers.add_parser("caches", help="Print cache paths and SQLite table summaries.")
|
||||
subparsers.add_parser("recipes", help="Print resolved recipes root and recipe JSON count.")
|
||||
|
||||
model_parser = subparsers.add_parser("model", help="Inspect a model metadata sidecar path.")
|
||||
model_parser.add_argument("--path", required=True, help="Path to a model file or metadata JSON file.")
|
||||
|
||||
sqlite_parser = subparsers.add_parser("sqlite", help="Inspect a SQLite database read-only.")
|
||||
sqlite_parser.add_argument("--db", required=True, help="Path to the SQLite database.")
|
||||
sqlite_parser.add_argument("--limit", type=int, default=3, help="Rows to sample from each user table.")
|
||||
|
||||
args = parser.parse_args()
|
||||
context = build_context()
|
||||
|
||||
if args.command == "summary":
|
||||
print_json(summary_payload(context))
|
||||
elif args.command == "caches":
|
||||
print_json(caches_payload(context))
|
||||
elif args.command == "recipes":
|
||||
print_json(recipes_payload(context))
|
||||
elif args.command == "model":
|
||||
print_json(model_payload(args.path))
|
||||
elif args.command == "sqlite":
|
||||
print_json(sqlite_payload(Path(args.db).expanduser(), args.limit))
|
||||
return 0
|
||||
|
||||
|
||||
def build_context() -> dict[str, Any]:
|
||||
settings_path = resolve_settings_path()
|
||||
settings = load_json(settings_path)
|
||||
settings_dir = settings_path.parent
|
||||
active_library = settings.get("active_library") or "default"
|
||||
safe_library = sanitize_library_name(str(active_library))
|
||||
cache_root = settings_dir / "cache"
|
||||
return {
|
||||
"settings_path": str(settings_path),
|
||||
"settings_dir": str(settings_dir),
|
||||
"settings": settings,
|
||||
"active_library": active_library,
|
||||
"safe_library": safe_library,
|
||||
"cache_root": str(cache_root),
|
||||
"cache_paths": resolve_cache_paths(cache_root, safe_library),
|
||||
}
|
||||
|
||||
|
||||
def resolve_settings_path() -> Path:
|
||||
repo_root = find_repo_root()
|
||||
portable = repo_root / "settings.json"
|
||||
if portable.exists():
|
||||
payload = load_json(portable)
|
||||
if isinstance(payload, dict) and payload.get("use_portable_settings") is True:
|
||||
return portable
|
||||
|
||||
config_home = os.environ.get("XDG_CONFIG_HOME")
|
||||
if config_home:
|
||||
return Path(config_home).expanduser() / APP_NAME / "settings.json"
|
||||
return Path.home() / ".config" / APP_NAME / "settings.json"
|
||||
|
||||
|
||||
def find_repo_root() -> Path:
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "py").is_dir() and (parent / "standalone.py").exists():
|
||||
return parent
|
||||
return Path.cwd()
|
||||
|
||||
|
||||
def load_json(path: Path) -> dict[str, Any]:
|
||||
try:
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
payload = json.load(handle)
|
||||
except FileNotFoundError:
|
||||
return {}
|
||||
except json.JSONDecodeError as exc:
|
||||
return {"_error": f"invalid JSON: {exc}"}
|
||||
except OSError as exc:
|
||||
return {"_error": f"unreadable: {exc}"}
|
||||
return payload if isinstance(payload, dict) else {"_error": "JSON root is not an object"}
|
||||
|
||||
|
||||
def resolve_cache_paths(cache_root: Path, library: str) -> dict[str, str]:
|
||||
paths: dict[str, str] = {}
|
||||
for name, (subdir, filename) in CACHE_SQLITE.items():
|
||||
paths[name] = str(cache_root / subdir / filename.format(library=library))
|
||||
for name, (subdir, filename) in CACHE_JSON.items():
|
||||
paths[name] = str(cache_root / subdir / filename)
|
||||
return paths
|
||||
|
||||
|
||||
def summary_payload(context: dict[str, Any]) -> dict[str, Any]:
|
||||
settings = context["settings"]
|
||||
return {
|
||||
"settings_path": context["settings_path"],
|
||||
"settings_dir": context["settings_dir"],
|
||||
"active_library": context["active_library"],
|
||||
"settings": redact(settings),
|
||||
"model_roots": model_roots(settings, context["active_library"]),
|
||||
"recipes_root": str(resolve_recipes_root(settings, context["active_library"]) or ""),
|
||||
"example_images": example_images_payload(settings, context["active_library"]),
|
||||
"cache_root": context["cache_root"],
|
||||
"cache_paths": context["cache_paths"],
|
||||
}
|
||||
|
||||
|
||||
def caches_payload(context: dict[str, Any]) -> dict[str, Any]:
|
||||
caches: dict[str, Any] = {}
|
||||
for name, path_string in context["cache_paths"].items():
|
||||
path = Path(path_string)
|
||||
item: dict[str, Any] = {
|
||||
"path": str(path),
|
||||
"exists": path.exists(),
|
||||
"size": path.stat().st_size if path.exists() else None,
|
||||
}
|
||||
if path.suffix == ".sqlite":
|
||||
item["sqlite"] = sqlite_payload(path, limit=0)
|
||||
elif path.suffix == ".json":
|
||||
item["json"] = json_file_summary(path)
|
||||
caches[name] = item
|
||||
return {"active_library": context["active_library"], "caches": caches}
|
||||
|
||||
|
||||
def recipes_payload(context: dict[str, Any]) -> dict[str, Any]:
|
||||
root = resolve_recipes_root(context["settings"], context["active_library"])
|
||||
files: list[str] = []
|
||||
if root and root.exists():
|
||||
files = [str(path) for path in sorted(root.rglob("*.recipe.json"))[:20]]
|
||||
return {
|
||||
"recipes_root": str(root or ""),
|
||||
"exists": bool(root and root.exists()),
|
||||
"recipe_json_count": count_recipe_files(root),
|
||||
"sample_recipe_json": files,
|
||||
"recipe_cache": context["cache_paths"].get("recipe"),
|
||||
}
|
||||
|
||||
|
||||
def model_payload(raw_path: str) -> dict[str, Any]:
|
||||
path = Path(raw_path).expanduser()
|
||||
metadata_path = path if path.name.endswith(".metadata.json") else path.with_suffix(".metadata.json")
|
||||
payload = {
|
||||
"input_path": str(path),
|
||||
"metadata_path": str(metadata_path),
|
||||
"model_exists": path.exists(),
|
||||
"metadata_exists": metadata_path.exists(),
|
||||
}
|
||||
if metadata_path.exists():
|
||||
data = load_json(metadata_path)
|
||||
payload["metadata_summary"] = redact(summarize_value(data))
|
||||
return payload
|
||||
|
||||
|
||||
def sqlite_payload(path: Path, limit: int = 3, allow_copy: bool = True) -> dict[str, Any]:
|
||||
result: dict[str, Any] = {"path": str(path), "exists": path.exists(), "tables": {}}
|
||||
if not path.exists():
|
||||
return result
|
||||
try:
|
||||
conn = connect_sqlite_readonly(path)
|
||||
except sqlite3.Error as exc:
|
||||
result["error"] = str(exc)
|
||||
return result
|
||||
try:
|
||||
table_rows = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
|
||||
).fetchall()
|
||||
for table_row in table_rows:
|
||||
table = table_row["name"]
|
||||
columns = [
|
||||
row["name"]
|
||||
for row in conn.execute(f"PRAGMA table_info({quote_identifier(table)})").fetchall()
|
||||
]
|
||||
table_info: dict[str, Any] = {"columns": columns}
|
||||
try:
|
||||
table_info["count"] = conn.execute(
|
||||
f"SELECT COUNT(*) FROM {quote_identifier(table)}"
|
||||
).fetchone()[0]
|
||||
except sqlite3.Error as exc:
|
||||
table_info["count_error"] = str(exc)
|
||||
if limit > 0 and columns and not is_internal_sqlite_table(table):
|
||||
try:
|
||||
rows = conn.execute(
|
||||
f"SELECT * FROM {quote_identifier(table)} LIMIT ?", (limit,)
|
||||
).fetchall()
|
||||
table_info["sample"] = [redact(dict(row)) for row in rows]
|
||||
except sqlite3.Error as exc:
|
||||
table_info["sample_error"] = str(exc)
|
||||
result["tables"][table] = table_info
|
||||
except sqlite3.Error as exc:
|
||||
fallback = sqlite_copy_payload(path, limit, str(exc)) if allow_copy else None
|
||||
if fallback is not None:
|
||||
result.update(fallback)
|
||||
else:
|
||||
result["error"] = str(exc)
|
||||
finally:
|
||||
conn.close()
|
||||
return result
|
||||
|
||||
|
||||
def connect_sqlite_readonly(path: Path) -> sqlite3.Connection:
|
||||
errors: list[str] = []
|
||||
for query in ("mode=ro", "mode=ro&immutable=1"):
|
||||
try:
|
||||
conn = sqlite3.connect(f"file:{path}?{query}", uri=True)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
except sqlite3.Error as exc:
|
||||
errors.append(f"{query}: {exc}")
|
||||
raise sqlite3.OperationalError("; ".join(errors))
|
||||
|
||||
|
||||
def sqlite_copy_payload(path: Path, limit: int, original_error: str) -> dict[str, Any] | None:
|
||||
try:
|
||||
with tempfile.TemporaryDirectory(prefix="lm-cache-inspect-") as temp_dir:
|
||||
copy_path = Path(temp_dir) / path.name
|
||||
shutil.copy2(path, copy_path)
|
||||
payload = sqlite_payload(copy_path, limit, allow_copy=False)
|
||||
payload["path"] = str(path)
|
||||
payload["inspected_copy"] = True
|
||||
payload["original_error"] = original_error
|
||||
return payload
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def json_file_summary(path: Path) -> dict[str, Any]:
|
||||
if not path.exists():
|
||||
return {"exists": False}
|
||||
data = load_json(path)
|
||||
return {"exists": True, "summary": redact(summarize_value(data))}
|
||||
|
||||
|
||||
def model_roots(settings: dict[str, Any], active_library: str) -> dict[str, list[str]]:
|
||||
roots: dict[str, list[str]] = {}
|
||||
sources = [settings]
|
||||
library = settings.get("libraries", {}).get(active_library)
|
||||
if isinstance(library, dict):
|
||||
sources.insert(0, library)
|
||||
for source in sources:
|
||||
folder_paths = source.get("folder_paths")
|
||||
if isinstance(folder_paths, dict):
|
||||
for key, value in folder_paths.items():
|
||||
roots.setdefault(key, []).extend(normalize_path_list(value))
|
||||
for default_key, folder_key in (
|
||||
("default_lora_root", "loras"),
|
||||
("default_checkpoint_root", "checkpoints"),
|
||||
("default_embedding_root", "embeddings"),
|
||||
("default_unet_root", "unet"),
|
||||
):
|
||||
value = settings.get(default_key)
|
||||
if isinstance(value, str) and value:
|
||||
roots.setdefault(folder_key, []).append(expand_path(value))
|
||||
return {key: dedupe(values) for key, values in roots.items()}
|
||||
|
||||
|
||||
def resolve_recipes_root(settings: dict[str, Any], active_library: str) -> Path | None:
|
||||
recipes_path = settings.get("recipes_path")
|
||||
library = settings.get("libraries", {}).get(active_library)
|
||||
if isinstance(library, dict) and isinstance(library.get("recipes_path"), str):
|
||||
recipes_path = library["recipes_path"] or recipes_path
|
||||
if isinstance(recipes_path, str) and recipes_path.strip():
|
||||
return Path(expand_path(recipes_path.strip()))
|
||||
lora_roots = model_roots(settings, active_library).get("loras") or []
|
||||
return Path(lora_roots[0]) / "recipes" if lora_roots else None
|
||||
|
||||
|
||||
def example_images_payload(settings: dict[str, Any], active_library: str) -> dict[str, Any]:
|
||||
root = settings.get("example_images_path") or ""
|
||||
libraries = settings.get("libraries")
|
||||
library_count = len(libraries) if isinstance(libraries, dict) else 0
|
||||
scoped = library_count > 1
|
||||
root_path = Path(expand_path(root)) if isinstance(root, str) and root else None
|
||||
library_root = root_path / sanitize_library_name(active_library) if root_path and scoped else root_path
|
||||
return {
|
||||
"root": str(root_path or ""),
|
||||
"uses_library_scoped_folders": scoped,
|
||||
"library_root": str(library_root or ""),
|
||||
}
|
||||
|
||||
|
||||
def count_recipe_files(root: Path | None) -> int:
|
||||
if not root or not root.exists():
|
||||
return 0
|
||||
return sum(1 for _ in root.rglob("*.recipe.json"))
|
||||
|
||||
|
||||
def normalize_path_list(value: Any) -> list[str]:
|
||||
if isinstance(value, str):
|
||||
return [expand_path(value)] if value else []
|
||||
if isinstance(value, list):
|
||||
return [expand_path(item) for item in value if isinstance(item, str) and item]
|
||||
return []
|
||||
|
||||
|
||||
def expand_path(value: str) -> str:
|
||||
return str(Path(value).expanduser().resolve(strict=False))
|
||||
|
||||
|
||||
def sanitize_library_name(name: str) -> str:
|
||||
safe = re.sub(r"[^A-Za-z0-9_.-]", "_", name or "default")
|
||||
return safe or "default"
|
||||
|
||||
|
||||
def dedupe(values: list[str]) -> list[str]:
|
||||
seen: set[str] = set()
|
||||
result: list[str] = []
|
||||
for value in values:
|
||||
if value not in seen:
|
||||
result.append(value)
|
||||
seen.add(value)
|
||||
return result
|
||||
|
||||
|
||||
def redact(value: Any, key: str = "") -> Any:
|
||||
if key and SECRET_PATTERN.search(key):
|
||||
return "<redacted>"
|
||||
if isinstance(value, dict):
|
||||
return {str(k): redact(v, str(k)) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [redact(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def summarize_value(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {key: summarize_value(item) for key, item in value.items()}
|
||||
if isinstance(value, list):
|
||||
return {
|
||||
"type": "array",
|
||||
"length": len(value),
|
||||
"first": summarize_value(value[0]) if value else None,
|
||||
}
|
||||
return value
|
||||
|
||||
|
||||
def quote_identifier(identifier: str) -> str:
|
||||
return '"' + identifier.replace('"', '""') + '"'
|
||||
|
||||
|
||||
def is_internal_sqlite_table(table: str) -> bool:
|
||||
return table.startswith("sqlite_") or table.endswith(("_data", "_idx", "_docsize", "_config", "_content"))
|
||||
|
||||
|
||||
def print_json(payload: Any) -> None:
|
||||
json.dump(payload, sys.stdout, indent=2, ensure_ascii=False)
|
||||
sys.stdout.write("\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
3
.github/ISSUE_TEMPLATE/feature_request.md
vendored
3
.github/ISSUE_TEMPLATE/feature_request.md
vendored
@@ -13,8 +13,5 @@ A clear and concise description of what the problem is. Ex. I'm always frustrate
|
||||
**Describe the solution you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Describe alternatives you've considered**
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
**Additional context**
|
||||
Add any other context or screenshots about the feature request here.
|
||||
|
||||
16
.gitignore
vendored
16
.gitignore
vendored
@@ -7,13 +7,24 @@ py/run_test.py
|
||||
.vscode/
|
||||
cache/
|
||||
civitai/
|
||||
stats/
|
||||
wildcards/
|
||||
backups/
|
||||
logs/
|
||||
node_modules/
|
||||
coverage/
|
||||
.coverage
|
||||
model_cache/
|
||||
|
||||
# agent
|
||||
# agent / dev tooling
|
||||
.opencode/
|
||||
.claude/
|
||||
.sisyphus/
|
||||
.codex
|
||||
.omo
|
||||
reasonix.toml
|
||||
.reasonix/
|
||||
.codegraph/
|
||||
|
||||
# Vue widgets development cache (but keep build output)
|
||||
vue-widgets/node_modules/
|
||||
@@ -22,3 +33,6 @@ vue-widgets/dist/
|
||||
|
||||
# Hypothesis test cache
|
||||
.hypothesis/
|
||||
|
||||
# Working/research notes (not committed)
|
||||
.docs/
|
||||
|
||||
181
.omo/plans/embeddings-hybrid-approach.md
Normal file
181
.omo/plans/embeddings-hybrid-approach.md
Normal file
@@ -0,0 +1,181 @@
|
||||
# Embeddings Usage Tracking — Hybrid Approach (Plan C)
|
||||
|
||||
> **Status**: Reference document for future implementation
|
||||
> **Current implementation**: Plan A (prompt text parsing only, see `usage_stats.py:_process_embeddings`)
|
||||
> **Next step**: Add Plan B as a supplement when edge-case coverage is needed
|
||||
|
||||
## Problem
|
||||
|
||||
Embeddings in ComfyUI are not loaded through dedicated ComfyUI nodes like LoRAs or
|
||||
Checkpoints. They are resolved during CLIP tokenization when the prompt text contains
|
||||
`embedding:<name>` syntax (see `comfy/sd1_clip.py:SDTokenizer.tokenize_with_weights`).
|
||||
|
||||
This means the existing metadata_collector hook (which intercepts node execution via
|
||||
`_map_node_over_list`) cannot capture embeddings the same way it captures LoRAs and
|
||||
checkpoints — there is no "EmbeddingLoader" node to intercept.
|
||||
|
||||
## Solution Architecture
|
||||
|
||||
The hybrid approach combines **two complementary mechanisms** to capture embedding
|
||||
usage from all possible paths.
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ Plan A (已实现) │
|
||||
│ │
|
||||
│ MetadataRegistry.prompt_metadata["prompts"] │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ _process_embeddings() │
|
||||
│ │ │
|
||||
│ ├─ Iterate all prompt node texts │
|
||||
│ ├─ regex extract "embedding:<name>" │
|
||||
│ ├─ resolve name → sha256 via EmbeddingScanner │
|
||||
│ └─ UsageStats.stats["embeddings"][sha256]++ │
|
||||
│ │
|
||||
│ Coverage: ~95% — all CLIPTextEncode/Flux/etc nodes │
|
||||
│ │
|
||||
│ Gap: Custom nodes that load embeddings programmatically │
|
||||
│ without putting embedding:name in prompt text │
|
||||
└─────────────────────────────────────────────────────────┘
|
||||
|
||||
+
|
||||
↓ (future: enable Plan B when needed)
|
||||
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ Plan B (未来 — monkey-patch) │
|
||||
│ │
|
||||
│ comfy/sd1_clip.py:load_embed() │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ Monkey-patch intercepts EVERY embedding file load │
|
||||
│ │ │
|
||||
│ ├─ Records embedding_name + success/failure │
|
||||
│ ├─ Associates with current prompt_id (via registry)│
|
||||
│ └─ Feeds into UsageStats same as Plan A │
|
||||
│ │
|
||||
│ Coverage: 100% — catches ALL embedding loads │
|
||||
│ │
|
||||
│ Cost: Requires patching into ComfyUI internals │
|
||||
│ (sd1_clip.py, sdxl_clip.py, some text_encoders) │
|
||||
└─────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Plan B Detail — Monkey-patch `load_embed`
|
||||
|
||||
### Target Function
|
||||
|
||||
**`comfy.sd1_clip.load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None)`**
|
||||
at line 415 of `sd1_clip.py`.
|
||||
|
||||
This is the **single choke point** for all embedding file loads in ComfyUI. Every
|
||||
CLIP variant (SD1, SDXL, SD3, Flux) calls this same function.
|
||||
|
||||
### Implementation Sketch
|
||||
|
||||
```python
|
||||
# In metadata_collector/metadata_hook.py (or a new module)
|
||||
import comfy.sd1_clip as sd1_clip
|
||||
|
||||
_original_load_embed = sd1_clip.load_embed
|
||||
|
||||
def _patched_load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
|
||||
result = _original_load_embed(
|
||||
embedding_name, embedding_directory, embedding_size, embed_key
|
||||
)
|
||||
if result is not None:
|
||||
_record_embedding_usage(embedding_name)
|
||||
return result
|
||||
|
||||
sd1_clip.load_embed = _patched_load_embed
|
||||
```
|
||||
|
||||
### Prompt ID Association
|
||||
|
||||
The challenge is associating the `load_embed` call with the current `prompt_id`.
|
||||
Options:
|
||||
|
||||
1. **Thread-local / contextvar**: Store current `prompt_id` in a `contextvars.ContextVar`
|
||||
that the metadata_collector sets at the start of each prompt execution.
|
||||
|
||||
2. **MetadataRegistry singleton**: The MetadataRegistry already has `current_prompt_id`.
|
||||
The patch can read it directly since both run in the same thread.
|
||||
|
||||
3. **Lazy aggregation**: Instead of associating with prompt_id at load time, collect
|
||||
all loaded embedding names in a global set during execution, then flush to
|
||||
UsageStats after the prompt completes.
|
||||
|
||||
### Files to Patch
|
||||
|
||||
| File | Function | Coverage |
|
||||
|------|----------|----------|
|
||||
| `comfy/sd1_clip.py:415` | `load_embed()` | Primary — SD1.x, SDXL, SD3, Flux |
|
||||
| `comfy/sdxl_clip.py` | Not needed (calls `sd1_clip.SDTokenizer`) | — |
|
||||
| `comfy/text_encoders/sd3_clip.py` | Not needed (calls `sd1_clip.SDTokenizer`) | — |
|
||||
| `comfy/text_encoders/flux.py` | Not needed (calls `sd1_clip.SDTokenizer`) | — |
|
||||
|
||||
The SD1 tokenizer is the base class for all CLIP variants' tokenizers, so patching
|
||||
`load_embed` covers them all.
|
||||
|
||||
### Edge Cases
|
||||
|
||||
| Edge Case | Plan A | Plan B |
|
||||
|-----------|--------|--------|
|
||||
| `embedding:name` in CLIPTextEncode | ✅ | ✅ |
|
||||
| `embedding:name` in CLIPTextEncodeFlux | ✅ | ✅ |
|
||||
| `embedding:name` in PromptLM (LoRA Manager) | ✅ | ✅ |
|
||||
| `embedding:name` in WAS_Text_to_Conditioning | ✅ | ✅ |
|
||||
| Custom node that loads embedding programmatically | ❌ | ✅ |
|
||||
| Embedding loaded multiple times in same prompt | ✅ (dedup via set) | ✅ (dedup via set) |
|
||||
| Embedding file not found | N/A | ✅ (can log) |
|
||||
| Embedding dimension mismatch | N/A | ✅ (can log) |
|
||||
| Text encoder with non-standard tokenizer (LLaMA, T5...) | Partial | ✅ (if it calls load_embed) |
|
||||
|
||||
## Migration Path: Standalone → Hybrid
|
||||
|
||||
### Phase 1 — Plan A (当前状态)
|
||||
- Prompt text parsing only
|
||||
- No monkey-patching required
|
||||
- Covers all standard workflows
|
||||
|
||||
### Phase 2 — Enable Plan B (未来工作)
|
||||
1. Add monkey-patch of `load_embed` in `metadata_collector/metadata_hook.py` (alongside
|
||||
the existing `_map_node_over_list` hook)
|
||||
2. Collect loaded embedding names in a `set()` on the registry
|
||||
3. In `UsageStats._process_embeddings()`, merge the Plan A results (from prompt text)
|
||||
with the Plan B results (from the patch)
|
||||
4. Add `prompt_data` field on MetadataRegistry to store loaded embeddings per prompt
|
||||
|
||||
### Deduplication
|
||||
|
||||
```python
|
||||
# Merge Plan A + Plan B results in _process_embeddings
|
||||
plan_a_names = extract_from_prompt_texts(prompts_data)
|
||||
plan_b_names = registry.get_loaded_embeddings(prompt_id)
|
||||
|
||||
all_names = plan_a_names | plan_b_names
|
||||
```
|
||||
|
||||
## Testing the Hybrid
|
||||
|
||||
| Scenario | What to verify |
|
||||
|----------|---------------|
|
||||
| Standard `embedding:name` in prompt | Plan A captures it |
|
||||
| Embedding loaded by custom node script | Plan B captures it |
|
||||
| Both paths fire for same embedding | No double-counting (dedup) |
|
||||
| Embedding name resolves to hash | EmbeddingScanner.get_hash_by_filename works |
|
||||
| No embedding scanner available | Graceful skip, no crash |
|
||||
| Missing embedding file | Plan B logs warning, Plan A skips gracefully |
|
||||
| Empty prompt | No crash, no entries |
|
||||
| Standalone mode | Both plans disabled gracefully |
|
||||
|
||||
## Key Files Reference
|
||||
|
||||
| File | Role |
|
||||
|------|------|
|
||||
| `py/utils/usage_stats.py` | Core — `_process_embeddings()` for Plan A |
|
||||
| `py/metadata_collector/constants.py` | `EMBEDDINGS` category constant |
|
||||
| `py/metadata_collector/metadata_hook.py` | Future — monkey-patch for Plan B |
|
||||
| `py/services/embedding_scanner.py` | Hash resolution service |
|
||||
| `py/routes/stats_routes.py` | Already handles `usage_data.get('embeddings', {})` |
|
||||
| `comfy/sd1_clip.py` (ComfyUI) | `load_embed()` — Plan B target |
|
||||
464
.specs/metadata.schema.json
Normal file
464
.specs/metadata.schema.json
Normal file
@@ -0,0 +1,464 @@
|
||||
{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"$id": "https://github.com/willmiao/ComfyUI-Lora-Manager/.specs/metadata.schema.json",
|
||||
"title": "ComfyUI LoRa Manager Model Metadata",
|
||||
"description": "Schema for .metadata.json sidecar files used by ComfyUI LoRa Manager",
|
||||
"type": "object",
|
||||
"oneOf": [
|
||||
{
|
||||
"title": "LoRA Model Metadata",
|
||||
"properties": {
|
||||
"file_name": {
|
||||
"type": "string",
|
||||
"description": "Filename without extension"
|
||||
},
|
||||
"model_name": {
|
||||
"type": "string",
|
||||
"description": "Display name of the model"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Full absolute path to the model file"
|
||||
},
|
||||
"size": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"description": "File size in bytes at time of import/download"
|
||||
},
|
||||
"modified": {
|
||||
"type": "number",
|
||||
"description": "Unix timestamp when model was imported/added (Date Added)"
|
||||
},
|
||||
"sha256": {
|
||||
"type": "string",
|
||||
"pattern": "^[a-f0-9]{64}$",
|
||||
"description": "SHA256 hash of the model file (lowercase)"
|
||||
},
|
||||
"base_model": {
|
||||
"type": "string",
|
||||
"description": "Base model type (SD1.5, SD2.1, SDXL, SD3, Flux, Unknown, etc.)"
|
||||
},
|
||||
"preview_url": {
|
||||
"type": "string",
|
||||
"description": "Path to preview image file"
|
||||
},
|
||||
"preview_nsfw_level": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "NSFW level using bitmask values: 0 (none), 1 (PG), 2 (PG13), 4 (R), 8 (X), 16 (XXX), 32 (Blocked)"
|
||||
},
|
||||
"notes": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "User-defined notes"
|
||||
},
|
||||
"from_civitai": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"description": "Whether the model originated from Civitai"
|
||||
},
|
||||
"civitai": {
|
||||
"$ref": "#/definitions/civitaiObject"
|
||||
},
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"default": [],
|
||||
"description": "Model tags"
|
||||
},
|
||||
"modelDescription": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
"description": "Full model description"
|
||||
},
|
||||
"civitai_deleted": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Whether the model was deleted from Civitai"
|
||||
},
|
||||
"favorite": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Whether the model is marked as favorite"
|
||||
},
|
||||
"exclude": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Whether to exclude from cache/scanning"
|
||||
},
|
||||
"db_checked": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Whether checked against archive database"
|
||||
},
|
||||
"skip_metadata_refresh": {
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Skip this model during bulk metadata refresh"
|
||||
},
|
||||
"metadata_source": {
|
||||
"type": ["string", "null"],
|
||||
"enum": ["civitai_api", "civarchive", "archive_db", null],
|
||||
"default": null,
|
||||
"description": "Last provider that supplied metadata"
|
||||
},
|
||||
"last_checked_at": {
|
||||
"type": "number",
|
||||
"default": 0,
|
||||
"description": "Unix timestamp of last metadata check"
|
||||
},
|
||||
"hash_status": {
|
||||
"type": "string",
|
||||
"enum": ["pending", "calculating", "completed", "failed"],
|
||||
"default": "completed",
|
||||
"description": "Hash calculation status"
|
||||
},
|
||||
"usage_tips": {
|
||||
"type": "string",
|
||||
"default": "{}",
|
||||
"description": "JSON string containing recommended usage parameters (LoRA only)"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"file_name",
|
||||
"model_name",
|
||||
"file_path",
|
||||
"size",
|
||||
"modified",
|
||||
"sha256",
|
||||
"base_model"
|
||||
],
|
||||
"additionalProperties": true
|
||||
},
|
||||
{
|
||||
"title": "Checkpoint Model Metadata",
|
||||
"properties": {
|
||||
"file_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"model_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"size": {
|
||||
"type": "integer",
|
||||
"minimum": 0
|
||||
},
|
||||
"modified": {
|
||||
"type": "number"
|
||||
},
|
||||
"sha256": {
|
||||
"type": "string",
|
||||
"pattern": "^[a-f0-9]{64}$"
|
||||
},
|
||||
"base_model": {
|
||||
"type": "string"
|
||||
},
|
||||
"preview_url": {
|
||||
"type": "string"
|
||||
},
|
||||
"preview_nsfw_level": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"maximum": 3,
|
||||
"default": 0
|
||||
},
|
||||
"notes": {
|
||||
"type": "string",
|
||||
"default": ""
|
||||
},
|
||||
"from_civitai": {
|
||||
"type": "boolean",
|
||||
"default": true
|
||||
},
|
||||
"civitai": {
|
||||
"$ref": "#/definitions/civitaiObject"
|
||||
},
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"default": []
|
||||
},
|
||||
"modelDescription": {
|
||||
"type": "string",
|
||||
"default": ""
|
||||
},
|
||||
"civitai_deleted": {
|
||||
"type": "boolean",
|
||||
"default": false
|
||||
},
|
||||
"favorite": {
|
||||
"type": "boolean",
|
||||
"default": false
|
||||
},
|
||||
"exclude": {
|
||||
"type": "boolean",
|
||||
"default": false
|
||||
},
|
||||
"db_checked": {
|
||||
"type": "boolean",
|
||||
"default": false
|
||||
},
|
||||
"skip_metadata_refresh": {
|
||||
"type": "boolean",
|
||||
"default": false
|
||||
},
|
||||
"metadata_source": {
|
||||
"type": ["string", "null"],
|
||||
"enum": ["civitai_api", "civarchive", "archive_db", null],
|
||||
"default": null
|
||||
},
|
||||
"last_checked_at": {
|
||||
"type": "number",
|
||||
"default": 0
|
||||
},
|
||||
"hash_status": {
|
||||
"type": "string",
|
||||
"enum": ["pending", "calculating", "completed", "failed"],
|
||||
"default": "completed"
|
||||
},
|
||||
"sub_type": {
|
||||
"type": "string",
|
||||
"default": "checkpoint",
|
||||
"description": "Model sub-type (checkpoint, diffusion_model, etc.)"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"file_name",
|
||||
"model_name",
|
||||
"file_path",
|
||||
"size",
|
||||
"modified",
|
||||
"sha256",
|
||||
"base_model"
|
||||
],
|
||||
"additionalProperties": true
|
||||
},
|
||||
{
|
||||
"title": "Embedding Model Metadata",
|
||||
"properties": {
|
||||
"file_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"model_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"size": {
|
||||
"type": "integer",
|
||||
"minimum": 0
|
||||
},
|
||||
"modified": {
|
||||
"type": "number"
|
||||
},
|
||||
"sha256": {
|
||||
"type": "string",
|
||||
"pattern": "^[a-f0-9]{64}$"
|
||||
},
|
||||
"base_model": {
|
||||
"type": "string"
|
||||
},
|
||||
"preview_url": {
|
||||
"type": "string"
|
||||
},
|
||||
"preview_nsfw_level": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"maximum": 3,
|
||||
"default": 0
|
||||
},
|
||||
"notes": {
|
||||
"type": "string",
|
||||
"default": ""
|
||||
},
|
||||
"from_civitai": {
|
||||
"type": "boolean",
|
||||
"default": true
|
||||
},
|
||||
"civitai": {
|
||||
"$ref": "#/definitions/civitaiObject"
|
||||
},
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"default": []
|
||||
},
|
||||
"modelDescription": {
|
||||
"type": "string",
|
||||
"default": ""
|
||||
},
|
||||
"civitai_deleted": {
|
||||
"type": "boolean",
|
||||
"default": false
|
||||
},
|
||||
"favorite": {
|
||||
"type": "boolean",
|
||||
"default": false
|
||||
},
|
||||
"exclude": {
|
||||
"type": "boolean",
|
||||
"default": false
|
||||
},
|
||||
"db_checked": {
|
||||
"type": "boolean",
|
||||
"default": false
|
||||
},
|
||||
"skip_metadata_refresh": {
|
||||
"type": "boolean",
|
||||
"default": false
|
||||
},
|
||||
"metadata_source": {
|
||||
"type": ["string", "null"],
|
||||
"enum": ["civitai_api", "civarchive", "archive_db", null],
|
||||
"default": null
|
||||
},
|
||||
"last_checked_at": {
|
||||
"type": "number",
|
||||
"default": 0
|
||||
},
|
||||
"hash_status": {
|
||||
"type": "string",
|
||||
"enum": ["pending", "calculating", "completed", "failed"],
|
||||
"default": "completed"
|
||||
},
|
||||
"sub_type": {
|
||||
"type": "string",
|
||||
"default": "embedding",
|
||||
"description": "Model sub-type"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"file_name",
|
||||
"model_name",
|
||||
"file_path",
|
||||
"size",
|
||||
"modified",
|
||||
"sha256",
|
||||
"base_model"
|
||||
],
|
||||
"additionalProperties": true
|
||||
}
|
||||
],
|
||||
"definitions": {
|
||||
"civitaiObject": {
|
||||
"type": "object",
|
||||
"default": {},
|
||||
"description": "Civitai/CivArchive API data and user-defined fields",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "integer",
|
||||
"description": "Version ID from Civitai"
|
||||
},
|
||||
"modelId": {
|
||||
"type": "integer",
|
||||
"description": "Model ID from Civitai"
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Version name"
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Version description"
|
||||
},
|
||||
"baseModel": {
|
||||
"type": "string",
|
||||
"description": "Base model type from Civitai"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"description": "Model type (checkpoint, embedding, etc.)"
|
||||
},
|
||||
"trainedWords": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Trigger words for the model (from API or user-defined)"
|
||||
},
|
||||
"customImages": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object"
|
||||
},
|
||||
"description": "Custom example images added by user"
|
||||
},
|
||||
"model": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"files": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object"
|
||||
}
|
||||
},
|
||||
"images": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object"
|
||||
}
|
||||
},
|
||||
"creator": {
|
||||
"type": "object"
|
||||
}
|
||||
},
|
||||
"additionalProperties": true
|
||||
},
|
||||
"usageTips": {
|
||||
"type": "object",
|
||||
"description": "Structure for usage_tips JSON string (LoRA models)",
|
||||
"properties": {
|
||||
"strength_min": {
|
||||
"type": "number",
|
||||
"description": "Minimum recommended model strength"
|
||||
},
|
||||
"strength_max": {
|
||||
"type": "number",
|
||||
"description": "Maximum recommended model strength"
|
||||
},
|
||||
"strength_range": {
|
||||
"type": "string",
|
||||
"description": "Human-readable strength range"
|
||||
},
|
||||
"strength": {
|
||||
"type": "number",
|
||||
"description": "Single recommended strength value"
|
||||
},
|
||||
"clip_strength": {
|
||||
"type": "number",
|
||||
"description": "Recommended CLIP/embedding strength"
|
||||
},
|
||||
"clip_skip": {
|
||||
"type": "integer",
|
||||
"description": "Recommended CLIP skip value"
|
||||
}
|
||||
},
|
||||
"additionalProperties": true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -135,9 +135,16 @@ npm run test:coverage # Generate coverage report
|
||||
- ALWAYS use English for comments (per copilot-instructions.md)
|
||||
- Dual mode: ComfyUI plugin (folder_paths) vs standalone (settings.json)
|
||||
- Detection: `os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"`
|
||||
- Run `python scripts/sync_translation_keys.py` after UI string updates
|
||||
- Run `python scripts/sync_translation_keys.py` after adding UI strings to `locales/en.json`
|
||||
- Symlinks require normalized paths
|
||||
|
||||
## Git / Commit Messages
|
||||
|
||||
- Follow the style of recent repository commits when writing commit messages
|
||||
- Prefer the repo's existing `feat(...)`, `fix(...)`, `chore:` style where applicable
|
||||
- If the user has provided a GitHub issue link or issue ID for the task, mention that issue in the commit message, for example `(#871)`
|
||||
- When unrelated local changes exist, stage and commit only the files relevant to the requested task
|
||||
|
||||
## Frontend UI Architecture
|
||||
|
||||
### 1. Standalone Web UI
|
||||
|
||||
25
__init__.py
25
__init__.py
@@ -1,10 +1,13 @@
|
||||
try: # pragma: no cover - import fallback for pytest collection
|
||||
from .py.lora_manager import LoraManager
|
||||
from .py.nodes.lora_loader import LoraLoaderLM, LoraTextLoaderLM
|
||||
from .py.nodes.checkpoint_loader import CheckpointLoaderLM
|
||||
from .py.nodes.unet_loader import UNETLoaderLM
|
||||
from .py.nodes.trigger_word_toggle import TriggerWordToggleLM
|
||||
from .py.nodes.prompt import PromptLM
|
||||
from .py.nodes.text import TextLM
|
||||
from .py.nodes.lora_stacker import LoraStackerLM
|
||||
from .py.nodes.lora_stack_combiner import LoraStackCombinerLM
|
||||
from .py.nodes.save_image import SaveImageLM
|
||||
from .py.nodes.debug_metadata import DebugMetadataLM
|
||||
from .py.nodes.wanvideo_lora_select import WanVideoLoraSelectLM
|
||||
@@ -27,16 +30,19 @@ except (
|
||||
PromptLM = importlib.import_module("py.nodes.prompt").PromptLM
|
||||
TextLM = importlib.import_module("py.nodes.text").TextLM
|
||||
LoraManager = importlib.import_module("py.lora_manager").LoraManager
|
||||
LoraLoaderLM = importlib.import_module(
|
||||
"py.nodes.lora_loader"
|
||||
).LoraLoaderLM
|
||||
LoraTextLoaderLM = importlib.import_module(
|
||||
"py.nodes.lora_loader"
|
||||
).LoraTextLoaderLM
|
||||
LoraLoaderLM = importlib.import_module("py.nodes.lora_loader").LoraLoaderLM
|
||||
LoraTextLoaderLM = importlib.import_module("py.nodes.lora_loader").LoraTextLoaderLM
|
||||
CheckpointLoaderLM = importlib.import_module(
|
||||
"py.nodes.checkpoint_loader"
|
||||
).CheckpointLoaderLM
|
||||
UNETLoaderLM = importlib.import_module("py.nodes.unet_loader").UNETLoaderLM
|
||||
TriggerWordToggleLM = importlib.import_module(
|
||||
"py.nodes.trigger_word_toggle"
|
||||
).TriggerWordToggleLM
|
||||
LoraStackerLM = importlib.import_module("py.nodes.lora_stacker").LoraStackerLM
|
||||
LoraStackCombinerLM = importlib.import_module(
|
||||
"py.nodes.lora_stack_combiner"
|
||||
).LoraStackCombinerLM
|
||||
SaveImageLM = importlib.import_module("py.nodes.save_image").SaveImageLM
|
||||
DebugMetadataLM = importlib.import_module("py.nodes.debug_metadata").DebugMetadataLM
|
||||
WanVideoLoraSelectLM = importlib.import_module(
|
||||
@@ -49,9 +55,7 @@ except (
|
||||
LoraRandomizerLM = importlib.import_module(
|
||||
"py.nodes.lora_randomizer"
|
||||
).LoraRandomizerLM
|
||||
LoraCyclerLM = importlib.import_module(
|
||||
"py.nodes.lora_cycler"
|
||||
).LoraCyclerLM
|
||||
LoraCyclerLM = importlib.import_module("py.nodes.lora_cycler").LoraCyclerLM
|
||||
init_metadata_collector = importlib.import_module("py.metadata_collector").init
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
@@ -59,8 +63,11 @@ NODE_CLASS_MAPPINGS = {
|
||||
TextLM.NAME: TextLM,
|
||||
LoraLoaderLM.NAME: LoraLoaderLM,
|
||||
LoraTextLoaderLM.NAME: LoraTextLoaderLM,
|
||||
CheckpointLoaderLM.NAME: CheckpointLoaderLM,
|
||||
UNETLoaderLM.NAME: UNETLoaderLM,
|
||||
TriggerWordToggleLM.NAME: TriggerWordToggleLM,
|
||||
LoraStackerLM.NAME: LoraStackerLM,
|
||||
LoraStackCombinerLM.NAME: LoraStackCombinerLM,
|
||||
SaveImageLM.NAME: SaveImageLM,
|
||||
DebugMetadataLM.NAME: DebugMetadataLM,
|
||||
WanVideoLoraSelectLM.NAME: WanVideoLoraSelectLM,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,183 +0,0 @@
|
||||
## Overview
|
||||
|
||||
The **LoRA Manager Civitai Extension** is a Browser extension designed to work seamlessly with [LoRA Manager](https://github.com/willmiao/ComfyUI-Lora-Manager) to significantly enhance your browsing experience on [Civitai](https://civitai.com). With this extension, you can:
|
||||
|
||||
✅ Instantly see which models are already present in your local library
|
||||
✅ Download new models with a single click
|
||||
✅ Manage downloads efficiently with queue and parallel download support
|
||||
✅ Keep your downloaded models automatically organized according to your custom settings
|
||||
|
||||

|
||||
|
||||
**Update:** It now also supports browsing on [CivArchive](https://civarchive.com/) (formerly CivitaiArchive).
|
||||
|
||||

|
||||
|
||||
---
|
||||
|
||||
## Why Supporter Access?
|
||||
|
||||
LoRA Manager is built with love for the Stable Diffusion and ComfyUI communities. Your support makes it possible for me to keep improving and maintaining the tool full-time.
|
||||
|
||||
Supporter-exclusive features help ensure the long-term sustainability of LoRA Manager, allowing continuous updates, new features, and better performance for everyone.
|
||||
|
||||
Every contribution directly fuels development and keeps the core LoRA Manager free and open-source. In addition to monthly supporters, one-time donation supporters will also receive a license key, with the duration scaling according to the contribution amount. Thank you for helping keep this project alive and growing. ❤️
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Installation
|
||||
|
||||
### Supported Browsers & Installation Methods
|
||||
|
||||
| Browser | Installation Method |
|
||||
|--------------------|-------------------------------------------------------------------------------------|
|
||||
| **Google Chrome** | [Chrome Web Store link](https://chromewebstore.google.com/detail/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb) |
|
||||
| **Microsoft Edge** | Install via Chrome Web Store (compatible) |
|
||||
| **Brave Browser** | Install via Chrome Web Store (compatible) |
|
||||
| **Opera** | Install via Chrome Web Store (compatible) |
|
||||
| **Firefox** | <div id="firefox-install" class="install-ok"><a href="https://github.com/willmiao/lm-civitai-extension-firefox/releases/latest/download/extension.xpi">📦 Install Firefox Extension (reviewed and verified by Mozilla)</a></div> |
|
||||
|
||||
For non-Chrome browsers (e.g., Microsoft Edge), you can typically install extensions from the Chrome Web Store by following these steps: open the extension’s Chrome Web Store page, click 'Get extension', then click 'Allow' when prompted to enable installations from other stores, and finally click 'Add extension' to complete the installation.
|
||||
|
||||
---
|
||||
|
||||
## Privacy & Security
|
||||
|
||||
I understand concerns around browser extensions and privacy, and I want to be fully transparent about how the **LM Civitai Extension** works:
|
||||
|
||||
- **Reviewed and Verified**
|
||||
This extension has been **manually reviewed and approved by the Chrome Web Store**. The Firefox version uses the **exact same code** (only the packaging format differs) and has passed **Mozilla’s Add-on review**.
|
||||
|
||||
- **Minimal Network Access**
|
||||
The only external server this extension connects to is:
|
||||
**`https://willmiao.shop`** — used solely for **license validation**.
|
||||
|
||||
It does **not collect, transmit, or store any personal or usage data**.
|
||||
No browsing history, no user IDs, no analytics, no hidden trackers.
|
||||
|
||||
- **Local-Only Model Detection**
|
||||
Model detection and LoRA Manager communication all happen **locally** within your browser, directly interacting with your local LoRA Manager backend.
|
||||
|
||||
I value your trust and are committed to keeping your local setup private and secure. If you have any questions, feel free to reach out!
|
||||
|
||||
---
|
||||
|
||||
## How to Use
|
||||
|
||||
After installing the extension, you'll automatically receive a **7-day trial** to explore all features.
|
||||
|
||||
When the extension is correctly installed and your license is valid:
|
||||
|
||||
- Open **Civitai**, and you'll see visual indicators added by the extension on model cards, showing:
|
||||
- ✅ Models already present in your local library
|
||||
- ⬇️ A download button for models not in your library
|
||||
|
||||
Clicking the download button adds the corresponding model version to the download queue, waiting to be downloaded. You can set up to **5 models to download simultaneously**.
|
||||
|
||||
### Visual Indicators Appear On:
|
||||
|
||||
- **Home Page** — Featured models
|
||||
- **Models Page**
|
||||
- **Creator Profiles** — If the creator has set their models to be visible
|
||||
- **Recommended Resources** — On individual model pages
|
||||
|
||||
### Version Buttons on Model Pages
|
||||
|
||||
On a specific model page, visual indicators also appear on version buttons, showing which versions are already in your local library.
|
||||
|
||||
**Starting from v0.4.8**, model pages use a dedicated download button for better compatibility. When switching to a specific version by clicking a version button:
|
||||
|
||||
- The new **dedicated download button** directly triggers download via **LoRA Manager**
|
||||
- The **original download button** remains unchanged for standard browser downloads
|
||||
|
||||

|
||||
|
||||
### Hide Models Already in Library (Beta)
|
||||
|
||||
**New in v0.4.8**: A new **Hide models already in library (Beta)** option makes it easier to focus on models you haven't added yet. It can be enabled from Settings, or toggled quickly using **Ctrl + Shift + H** (macOS: **Command + Shift + H**).
|
||||
|
||||
### Resources on Image Pages — now shows in-library indicators for image resources plus one-click recipe import
|
||||
|
||||
- **One-Click Import Civitai Image as Recipe** — Import any Civitai image as a recipe with a single click in the Resources Used panel.
|
||||
- **Auto-Queue Missing Assets** — In Settings you can decide if LoRAs or checkpoints referenced by that image should automatically be added to your download queue.
|
||||
- **More Accurate Metadata** — Importing directly from the page is faster than copying inside LM and keeps on-site tags and other metadata perfectly aligned.
|
||||
|
||||

|
||||
|
||||
[](https://github.com/user-attachments/assets/41fd4240-c949-4f83-bde7-8f3124c09494)
|
||||
|
||||
---
|
||||
|
||||
## Model Download Location & LoRA Manager Settings
|
||||
|
||||
To use the **one-click download function**, you must first set:
|
||||
|
||||
- Your **Default LoRAs Root**
|
||||
- Your **Default Checkpoints Root**
|
||||
|
||||
These are set within LoRA Manager's settings.
|
||||
|
||||
When everything is configured, downloaded model files will be placed in:
|
||||
|
||||
`<Default_Models_Root>/<Base_Model_of_the_Model>/<First_Tag_of_the_Model>`
|
||||
|
||||
|
||||
### Update: Default Path Customization (2025-07-21)
|
||||
|
||||
A new setting to customize the default download path has been added in the nightly version. You can now personalize where models are saved when downloading via the LM Civitai Extension.
|
||||
|
||||

|
||||
|
||||
The previous YAML path mapping file will be deprecated—settings will now be unified in settings.json to simplify configuration.
|
||||
|
||||
---
|
||||
|
||||
## Backend Port Configuration
|
||||
|
||||
If your **ComfyUI** or **LoRA Manager** backend is running on a port **other than the default 8188**, you must configure the backend port in the extension's settings.
|
||||
|
||||
After correctly setting and saving the port, you'll see in the extension's header area:
|
||||
- A **Healthy** status with the tooltip: `Connected to LoRA Manager on port xxxx`
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Connecting to a Remote LoRA Manager
|
||||
|
||||
If your LoRA Manager is running on another computer, you can still connect from your browser using port forwarding.
|
||||
|
||||
> **Why can't you set a remote IP directly?**
|
||||
>
|
||||
> For privacy and security, the extension only requests access to `http://127.0.0.1/*`. Supporting remote IPs would require much broader permissions, which may be rejected by browser stores and could raise user concerns.
|
||||
|
||||
**Solution: Port Forwarding with `socat`**
|
||||
|
||||
On your browser computer, run:
|
||||
|
||||
`socat TCP-LISTEN:8188,bind=127.0.0.1,fork TCP:REMOTE.IP.ADDRESS.HERE:8188`
|
||||
|
||||
- Replace `REMOTE.IP.ADDRESS.HERE` with the IP of the machine running LoRA Manager.
|
||||
- Adjust the port if needed.
|
||||
|
||||
This lets the extension connect to `127.0.0.1:8188` as usual, with traffic forwarded to your remote server.
|
||||
|
||||
_Thanks to user **Temikus** for sharing this solution!_
|
||||
|
||||
---
|
||||
|
||||
## Roadmap
|
||||
|
||||
The extension will evolve alongside **LoRA Manager** improvements. Planned features include:
|
||||
|
||||
- [x] Support for **additional model types** (e.g., embeddings)
|
||||
- [x] One-click **Recipe Import**
|
||||
- [x] Display of in-library status for all resources in the **Resources Used** section of the image page
|
||||
- [x] One-click **Auto-organize Models**
|
||||
- [x] **Hide models already in library (Beta)** - Focus on models you haven't added yet
|
||||
|
||||
**Stay tuned — and thank you for your support!**
|
||||
|
||||
---
|
||||
208
docs/agent_skills.md
Normal file
208
docs/agent_skills.md
Normal file
@@ -0,0 +1,208 @@
|
||||
# Agent Skills System
|
||||
|
||||
The LoRA Manager agent skills system enables LLM-powered metadata enrichment and other AI-driven tasks. Users configure their own LLM provider (BYOK), and skills are executed through right-click context menu actions.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────┐
|
||||
│ LoRA Manager Backend │
|
||||
│ │
|
||||
│ ┌──────────────┐ ┌────────────────┐ │
|
||||
│ │ LLMService │───▶│ LLM Provider │ │
|
||||
│ │ (BYOK config, │◀───│ (OpenAI/Ollama │ │
|
||||
│ │ API calls) │ │ /custom) │ │
|
||||
│ └───────┬───────┘ └────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌───────▼───────────────────────┐ │
|
||||
│ │ AgentService │ │
|
||||
│ │ (orchestration: validate │ │
|
||||
│ │ → LLM call → post-process │ │
|
||||
│ │ → WebSocket broadcast) │ │
|
||||
│ └───────┬───────────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌───────▼───────────────────────┐ │
|
||||
│ │ SkillRegistry │ │
|
||||
│ │ ┌─────────────────────────┐ │ │
|
||||
│ │ │ enrich_hf_metadata: │ │ │
|
||||
│ │ │ - skill.yaml │ │ │
|
||||
│ │ │ - prompt.md │ │ │
|
||||
│ │ │ - handler.py │ │ │
|
||||
│ │ └─────────────────────────┘ │ │
|
||||
│ └───────────────────────────────┘ │
|
||||
└──────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Key Design Principle
|
||||
|
||||
**Skills define *what* to do (prompt + post-processing). The AgentService handles *how* (LLM calls, validation, progress).**
|
||||
|
||||
Skills never call the LLM directly. This keeps BYOK configuration centralized and provider-agnostic.
|
||||
|
||||
## BYOK Configuration
|
||||
|
||||
Users configure their LLM provider in **Settings → AI Provider**:
|
||||
|
||||
| Setting | Description | Example |
|
||||
|---|---|---|
|
||||
| `llm_provider` | Provider type | `openai`, `ollama`, or `custom` |
|
||||
| `llm_api_key` | API key (not needed for local Ollama) | `sk-...` |
|
||||
| `llm_api_base` | Custom API base URL (empty = provider default) | `https://api.openai.com/v1` |
|
||||
| `llm_model` | Model name | `gpt-4o-mini` |
|
||||
|
||||
Environment variable overrides: `LLM_API_KEY`, `LLM_MODEL`, `LLM_API_BASE`, `LLM_PROVIDER`.
|
||||
|
||||
### Supported Providers
|
||||
|
||||
- **OpenAI**: Uses `https://api.openai.com/v1` by default
|
||||
- **Ollama** (local): Uses `http://localhost:11434/v1`, no API key required
|
||||
- **Custom**: Any OpenAI-compatible endpoint (vLLM, LM Studio, etc.) — set `llm_api_base` explicitly
|
||||
|
||||
## Available Skills
|
||||
|
||||
### enrich_hf_metadata
|
||||
|
||||
Enriches HuggingFace-downloaded models with metadata extracted by an LLM from the HF model card.
|
||||
|
||||
**Entry point**: Right-click context menu → "Enrich Metadata (Agent)"
|
||||
|
||||
**What it does**:
|
||||
1. Reads the model's `.metadata.json` to get the `hf_url`
|
||||
2. Fetches the README.md from the HuggingFace repository
|
||||
3. Sends the README + local metadata to the LLM for structured extraction
|
||||
4. Writes extracted fields to `.metadata.json`:
|
||||
- `base_model` — only if current value is empty
|
||||
- `trainedWords` — trigger words (LoRA only, if none exist)
|
||||
- `modelDescription` — concise summary (if none exists)
|
||||
- `tags` — merged with existing tags, deduplicated
|
||||
- `metadata_source` — audit trail: `agent:enrich_hf_metadata`
|
||||
- `llm_enriched_at` — ISO timestamp
|
||||
5. Downloads and optimizes preview image (if LLM found one in the README)
|
||||
6. Updates the scanner cache
|
||||
7. Broadcasts WebSocket progress events
|
||||
|
||||
**Model types**: LoRA, Checkpoint, Embedding
|
||||
|
||||
## Adding a New Skill
|
||||
|
||||
### 1. Create the skill directory
|
||||
|
||||
```
|
||||
py/services/agent/skills/<skill_name>/
|
||||
├── skill.yaml # Skill metadata and schemas
|
||||
├── prompt.md # LLM prompt template
|
||||
└── handler.py # Pre-processing and post-processing
|
||||
```
|
||||
|
||||
### 2. Write skill.yaml
|
||||
|
||||
```yaml
|
||||
name: my_skill
|
||||
title: "My Skill"
|
||||
description: "What this skill does"
|
||||
llm_required: true
|
||||
model_type_filter: ["lora"] # or null for all types
|
||||
input_schema:
|
||||
type: object
|
||||
properties:
|
||||
model_paths:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
required:
|
||||
- model_paths
|
||||
output_schema:
|
||||
type: object
|
||||
properties:
|
||||
# ... JSON schema for LLM output
|
||||
permissions:
|
||||
write_metadata: true
|
||||
write_previews: false
|
||||
network_domains:
|
||||
- "example.com"
|
||||
```
|
||||
|
||||
### 3. Write prompt.md
|
||||
|
||||
Use `{{variable}}` placeholders that will be replaced with data from the `prepare` function:
|
||||
|
||||
```markdown
|
||||
You are an expert assistant...
|
||||
|
||||
Model URL: {{hf_url}}
|
||||
README content:
|
||||
{{readme_content}}
|
||||
|
||||
Current metadata:
|
||||
{{current_metadata}}
|
||||
```
|
||||
|
||||
### 4. Write handler.py
|
||||
|
||||
```python
|
||||
async def prepare(model_path: str, input_data: dict) -> dict:
|
||||
"""Gather context for the LLM prompt. Returns variables for template rendering."""
|
||||
return {
|
||||
"model_path": model_path,
|
||||
# ... other variables used in prompt.md
|
||||
}
|
||||
|
||||
async def post_process(context) -> dict:
|
||||
"""Apply the LLM-extracted data to the model."""
|
||||
llm_response = context.llm_response
|
||||
# ... write metadata, download previews, update cache
|
||||
return {
|
||||
"success": True,
|
||||
"updated_fields": ["base_model", "tags"],
|
||||
"errors": [],
|
||||
}
|
||||
```
|
||||
|
||||
**Important**: Use absolute imports (`from py.utils.metadata_manager import MetadataManager`) because skills are loaded via `importlib.util.spec_from_file_location`, which doesn't support relative imports.
|
||||
|
||||
### 5. Test
|
||||
|
||||
The skill is automatically discovered by `SkillRegistry` on startup. Test with:
|
||||
|
||||
```python
|
||||
pytest tests/services/test_agent_service.py
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
| Method | Path | Description |
|
||||
|---|---|---|
|
||||
| GET | `/api/lm/agent/skills` | List available skills |
|
||||
| POST | `/api/lm/agent/execute/{skill_name}` | Execute a skill (body: `{"model_paths": [...]}`) |
|
||||
| POST | `/api/lm/agent/cancel` | Cancel running skill (stub) |
|
||||
|
||||
## WebSocket Events
|
||||
|
||||
| Type | When | Key fields |
|
||||
|---|---|---|
|
||||
| `agent_progress` | Skill started/processing | `skill`, `status`, `total`, `processed`, `success`, `current_path` |
|
||||
| `agent_progress` | Skill completed | `skill`, `status`, `updated_models`, `errors`, `summary` |
|
||||
| `agent_progress` | Skill error | `skill`, `status`, `error` |
|
||||
|
||||
## Security Model
|
||||
|
||||
Skills declare permissions in `skill.yaml`:
|
||||
- `write_metadata` — can write `.metadata.json` files
|
||||
- `write_previews` — can download/replace preview images
|
||||
- `network_domains` — allowed domains for HTTP requests
|
||||
|
||||
These are declarative constraints checked by `AgentService`. They are defense-in-depth, not a sandbox — the Python process can technically do anything, but the contract is clear and auditable.
|
||||
|
||||
## File Locations
|
||||
|
||||
| Component | Path |
|
||||
|---|---|
|
||||
| LLMService | `py/services/llm_service.py` |
|
||||
| AgentService | `py/services/agent/agent_service.py` |
|
||||
| SkillRegistry | `py/services/agent/skill_registry.py` |
|
||||
| SkillDefinition | `py/services/agent/skill_definition.py` |
|
||||
| Skills directory | `py/services/agent/skills/` |
|
||||
| Route handlers | `py/routes/handlers/agent_handlers.py` |
|
||||
| Frontend manager | `static/js/managers/AgentManager.js` |
|
||||
| Settings UI | `templates/components/modals/settings_modal.html` |
|
||||
| Context menu | `templates/components/context_menu.html` |
|
||||
363
docs/metadata-json-schema.md
Normal file
363
docs/metadata-json-schema.md
Normal file
@@ -0,0 +1,363 @@
|
||||
# metadata.json Schema Documentation
|
||||
|
||||
This document defines the complete schema for `.metadata.json` files used by Lora Manager. These sidecar files store model metadata alongside model files (LoRA, Checkpoint, Embedding).
|
||||
|
||||
## Overview
|
||||
|
||||
- **File naming**: `<model_name>.metadata.json` (e.g., `my_lora.safetensors` → `my_lora.metadata.json`)
|
||||
- **Format**: JSON with UTF-8 encoding
|
||||
- **Purpose**: Store model metadata, tags, descriptions, preview images, and Civitai/CivArchive integration data
|
||||
- **Extensibility**: Unknown fields are preserved via `_unknown_fields` mechanism for forward compatibility
|
||||
|
||||
---
|
||||
|
||||
## Base Fields (All Model Types)
|
||||
|
||||
These fields are present in all model metadata files.
|
||||
|
||||
| Field | Type | Required | Auto-Updated | Description |
|
||||
|-------|------|----------|--------------|-------------|
|
||||
| `file_name` | string | ✅ Yes | ✅ Yes | Filename without extension (e.g., `"my_lora"`) |
|
||||
| `model_name` | string | ✅ Yes | ❌ No | Display name of the model. **Default**: `file_name` if no other source |
|
||||
| `file_path` | string | ✅ Yes | ✅ Yes | Full absolute path to the model file (normalized with `/` separators) |
|
||||
| `size` | integer | ✅ Yes | ❌ No | File size in bytes. **Set at**: Initial scan or download completion. Does not change thereafter. |
|
||||
| `modified` | float | ✅ Yes | ❌ No | **Import timestamp** — Unix timestamp when the model was first imported/added to the system. Used for "Date Added" sorting. Does not change after initial creation. |
|
||||
| `sha256` | string | ⚠️ Conditional | ✅ Yes | SHA256 hash of the model file (lowercase). **LoRA**: Required. **Checkpoint**: May be empty when `hash_status="pending"` (lazy hash calculation) |
|
||||
| `base_model` | string | ❌ No | ❌ No | Base model type. **Examples**: `"SD 1.5"`, `"SDXL 1.0"`, `"SDXL Lightning"`, `"Flux.1 D"`, `"Flux.1 S"`, `"Flux.1 Krea"`, `"Illustrious"`, `"Pony"`, `"AuraFlow"`, `"Kolors"`, `"ZImageTurbo"`, `"Wan Video"`, etc. **Default**: `"Unknown"` or `""` |
|
||||
| `preview_url` | string | ❌ No | ✅ Yes | Path to preview image file |
|
||||
| `preview_nsfw_level` | integer | ❌ No | ❌ No | NSFW level using **bitmask values** from Civitai: `1` (PG), `2` (PG13), `4` (R), `8` (X), `16` (XXX), `32` (Blocked). **Default**: `0` (none) |
|
||||
| `notes` | string | ❌ No | ❌ No | User-defined notes |
|
||||
| `from_civitai` | boolean | ❌ No (default: `true`) | ❌ No | Whether the model originated from Civitai |
|
||||
| `civitai` | object | ❌ No | ⚠️ Partial | Civitai/CivArchive API data and user-defined fields |
|
||||
| `tags` | array[string] | ❌ No | ⚠️ Partial | Model tags (merged from API and user input) |
|
||||
| `modelDescription` | string | ❌ No | ⚠️ Partial | Full model description (from API or user) |
|
||||
| `civitai_deleted` | boolean | ❌ No (default: `false`) | ❌ No | Whether the model was deleted from Civitai |
|
||||
| `favorite` | boolean | ❌ No (default: `false`) | ❌ No | Whether the model is marked as favorite |
|
||||
| `exclude` | boolean | ❌ No (default: `false`) | ❌ No | Whether to exclude from cache/scanning. User can set from `false` to `true` (currently no UI to revert) |
|
||||
| `db_checked` | boolean | ❌ No (default: `false`) | ❌ No | Whether checked against archive database |
|
||||
| `skip_metadata_refresh` | boolean | ❌ No (default: `false`) | ❌ No | Skip this model during bulk metadata refresh |
|
||||
| `metadata_source` | string\|null | ❌ No | ✅ Yes | Last provider that supplied metadata (see below) |
|
||||
| `last_checked_at` | float | ❌ No (default: `0`) | ✅ Yes | Unix timestamp of last metadata check |
|
||||
| `hash_status` | string | ❌ No (default: `"completed"`) | ✅ Yes | Hash calculation status: `"pending"`, `"calculating"`, `"completed"`, `"failed"` |
|
||||
|
||||
---
|
||||
|
||||
## Model-Specific Fields
|
||||
|
||||
### LoRA Models
|
||||
|
||||
LoRA models do not have a `model_type` field in metadata.json. The type is inferred from context or `civitai.type` (e.g., `"LoRA"`, `"LoCon"`, `"DoRA"`).
|
||||
|
||||
| Field | Type | Required | Auto-Updated | Description |
|
||||
|-------|------|----------|--------------|-------------|
|
||||
| `usage_tips` | string (JSON) | ❌ No (default: `"{}"`) | ❌ No | JSON string containing recommended usage parameters |
|
||||
|
||||
**`usage_tips` JSON structure:**
|
||||
|
||||
```json
|
||||
{
|
||||
"strength_min": 0.3,
|
||||
"strength_max": 0.8,
|
||||
"strength_range": "0.3-0.8",
|
||||
"strength": 0.6,
|
||||
"clip_strength": 0.5,
|
||||
"clip_skip": 2
|
||||
}
|
||||
```
|
||||
|
||||
| Key | Type | Description |
|
||||
|-----|------|-------------|
|
||||
| `strength_min` | number | Minimum recommended model strength |
|
||||
| `strength_max` | number | Maximum recommended model strength |
|
||||
| `strength_range` | string | Human-readable strength range |
|
||||
| `strength` | number | Single recommended strength value |
|
||||
| `clip_strength` | number | Recommended CLIP/embedding strength |
|
||||
| `clip_skip` | integer | Recommended CLIP skip value |
|
||||
|
||||
---
|
||||
|
||||
### Checkpoint Models
|
||||
|
||||
| Field | Type | Required | Auto-Updated | Description |
|
||||
|-------|------|----------|--------------|-------------|
|
||||
| `model_type` | string | ❌ No (default: `"checkpoint"`) | ❌ No | Model type: `"checkpoint"`, `"diffusion_model"` |
|
||||
|
||||
---
|
||||
|
||||
### Embedding Models
|
||||
|
||||
| Field | Type | Required | Auto-Updated | Description |
|
||||
|-------|------|----------|--------------|-------------|
|
||||
| `model_type` | string | ❌ No (default: `"embedding"`) | ❌ No | Model type: `"embedding"` |
|
||||
|
||||
---
|
||||
|
||||
## The `civitai` Field Structure
|
||||
|
||||
The `civitai` object stores the complete Civitai/CivArchive API response. Lora Manager preserves all fields from the API for future compatibility and extracts specific fields for use in the application.
|
||||
|
||||
### Version-Level Fields (Civitai API)
|
||||
|
||||
**Fields Used by Lora Manager:**
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `id` | integer | Version ID |
|
||||
| `modelId` | integer | Parent model ID |
|
||||
| `name` | string | Version name (e.g., `"v1.0"`, `"v2.0-pruned"`) |
|
||||
| `nsfwLevel` | integer | NSFW level (bitmask: 1=PG, 2=PG13, 4=R, 8=X, 16=XXX, 32=Blocked) |
|
||||
| `baseModel` | string | Base model (e.g., `"SDXL 1.0"`, `"Flux.1 D"`, `"Illustrious"`, `"Pony"`) |
|
||||
| `trainedWords` | array[string] | **Trigger words** for the model |
|
||||
| `type` | string | Model type (`"LoRA"`, `"Checkpoint"`, `"TextualInversion"`) |
|
||||
| `earlyAccessEndsAt` | string\|null | Early access end date (used for update notifications) |
|
||||
| `description` | string | Version description (HTML) |
|
||||
| `model` | object | Parent model object (see Model-Level Fields below) |
|
||||
| `creator` | object | Creator information (see Creator Fields below) |
|
||||
| `files` | array[object] | File list with hashes, sizes, download URLs (used for metadata extraction) |
|
||||
| `images` | array[object] | Image list with metadata, prompts, NSFW levels (used for preview/examples) |
|
||||
|
||||
**Fields Stored but Not Currently Used:**
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `createdAt` | string (ISO 8601) | Creation timestamp |
|
||||
| `updatedAt` | string (ISO 8601) | Last update timestamp |
|
||||
| `status` | string | Version status (e.g., `"Published"`, `"Draft"`) |
|
||||
| `publishedAt` | string (ISO 8601) | Publication timestamp |
|
||||
| `baseModelType` | string | Base model type (e.g., `"Standard"`, `"Inpaint"`, `"Refiner"`) |
|
||||
| `earlyAccessConfig` | object | Early access configuration |
|
||||
| `uploadType` | string | Upload type (`"Created"`, `"FineTuned"`, etc.) |
|
||||
| `usageControl` | string | Usage control setting |
|
||||
| `air` | string | Artifact ID (URN format: `urn:air:sdxl:lora:civitai:122359@135867`) |
|
||||
| `stats` | object | Download count, ratings, thumbs up count |
|
||||
| `videos` | array[object] | Video list |
|
||||
| `downloadUrl` | string | Direct download URL |
|
||||
| `trainingStatus` | string\|null | Training status (for on-site training) |
|
||||
| `trainingDetails` | object\|null | Training configuration |
|
||||
|
||||
### Model-Level Fields (`civitai.model.*`)
|
||||
|
||||
**Fields Used by Lora Manager:**
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `name` | string | Model name |
|
||||
| `type` | string | Model type (`"LoRA"`, `"Checkpoint"`, `"TextualInversion"`) |
|
||||
| `description` | string | Model description (HTML, used for `modelDescription`) |
|
||||
| `tags` | array[string] | Model tags (used for `tags` field) |
|
||||
| `allowNoCredit` | boolean | License: allow use without credit |
|
||||
| `allowCommercialUse` | array[string] | License: allowed commercial uses. **Values**: `"Image"` (sell generated images), `"Video"` (sell generated videos), `"RentCivit"` (rent on Civitai), `"Rent"` (rent elsewhere) |
|
||||
| `allowDerivatives` | boolean | License: allow derivatives |
|
||||
| `allowDifferentLicense` | boolean | License: allow different license |
|
||||
|
||||
**Fields Stored but Not Currently Used:**
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `nsfw` | boolean | Model NSFW flag |
|
||||
| `poi` | boolean | Person of Interest flag |
|
||||
|
||||
### Creator Fields (`civitai.creator.*`)
|
||||
|
||||
Both fields are used by Lora Manager:
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `username` | string | Creator username (used for author display and search) |
|
||||
| `image` | string | Creator avatar URL (used for display) |
|
||||
|
||||
### Model Type Field (Top-Level, Outside `civitai`)
|
||||
|
||||
| Field | Type | Values | Description |
|
||||
|-------|------|--------|-------------|
|
||||
| `model_type` | string | `"checkpoint"`, `"diffusion_model"`, `"embedding"` | Stored in metadata.json for Checkpoint and Embedding models. **Note**: LoRA models do not have this field; type is inferred from `civitai.type` or context. |
|
||||
|
||||
### User-Defined Fields (Within `civitai`)
|
||||
|
||||
For models not from Civitai or user-added data:
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `trainedWords` | array[string] | **Trigger words** — manually added by user |
|
||||
| `customImages` | array[object] | Custom example images added by user |
|
||||
|
||||
### customImages Structure
|
||||
|
||||
Each custom image entry has the following structure:
|
||||
|
||||
```json
|
||||
{
|
||||
"url": "",
|
||||
"id": "short_id",
|
||||
"nsfwLevel": 0,
|
||||
"width": 832,
|
||||
"height": 1216,
|
||||
"type": "image",
|
||||
"meta": {
|
||||
"prompt": "...",
|
||||
"negativePrompt": "...",
|
||||
"steps": 20,
|
||||
"cfgScale": 7,
|
||||
"seed": 123456
|
||||
},
|
||||
"hasMeta": true,
|
||||
"hasPositivePrompt": true
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `url` | string | Empty for local custom images |
|
||||
| `id` | string | Short ID or filename |
|
||||
| `nsfwLevel` | integer | NSFW level (bitmask) |
|
||||
| `width` | integer | Image width in pixels |
|
||||
| `height` | integer | Image height in pixels |
|
||||
| `type` | string | `"image"` or `"video"` |
|
||||
| `meta` | object\|null | Generation metadata (prompt, seed, etc.) extracted from image |
|
||||
| `hasMeta` | boolean | Whether metadata is available |
|
||||
| `hasPositivePrompt` | boolean | Whether a positive prompt is available |
|
||||
|
||||
### Minimal Non-Civitai Example
|
||||
|
||||
```json
|
||||
{
|
||||
"civitai": {
|
||||
"trainedWords": ["my_trigger_word"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Non-Civitai Example Without Trigger Words
|
||||
|
||||
```json
|
||||
{
|
||||
"civitai": {}
|
||||
}
|
||||
```
|
||||
|
||||
### Example: User-Added Custom Images
|
||||
|
||||
```json
|
||||
{
|
||||
"civitai": {
|
||||
"trainedWords": ["custom_style"],
|
||||
"customImages": [
|
||||
{
|
||||
"url": "",
|
||||
"id": "example_1",
|
||||
"nsfwLevel": 0,
|
||||
"width": 832,
|
||||
"height": 1216,
|
||||
"type": "image",
|
||||
"meta": {
|
||||
"prompt": "example prompt",
|
||||
"seed": 12345
|
||||
},
|
||||
"hasMeta": true,
|
||||
"hasPositivePrompt": true
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Metadata Source Values
|
||||
|
||||
The `metadata_source` field indicates which provider last updated the metadata:
|
||||
|
||||
| Value | Source |
|
||||
|-------|--------|
|
||||
| `"civitai_api"` | Civitai API |
|
||||
| `"civarchive"` | CivArchive API |
|
||||
| `"archive_db"` | Metadata Archive Database |
|
||||
| `null` | No external source (user-defined only) |
|
||||
|
||||
---
|
||||
|
||||
## Auto-Update Behavior
|
||||
|
||||
### Fields Updated During Scanning
|
||||
|
||||
These fields are automatically synchronized with the filesystem:
|
||||
|
||||
- `file_name` — Updated if actual filename differs
|
||||
- `file_path` — Normalized and updated if path changes
|
||||
- `preview_url` — Updated if preview file is moved/removed
|
||||
- `sha256` — Updated during hash calculation (when `hash_status="pending"`)
|
||||
- `hash_status` — Updated during hash calculation
|
||||
- `last_checked_at` — Timestamp of scan
|
||||
- `metadata_source` — Set based on metadata provider
|
||||
|
||||
### Fields Set Once (Immutable After Import)
|
||||
|
||||
These fields are set when the model is first imported/scanned and **never change** thereafter:
|
||||
|
||||
- `modified` — Import timestamp (used for "Date Added" sorting)
|
||||
- `size` — File size at time of import/download
|
||||
|
||||
### User-Editable Fields
|
||||
|
||||
These fields can be edited by users at any time through the Lora Manager UI or by manually editing the metadata.json file:
|
||||
|
||||
- `model_name` — Display name
|
||||
- `tags` — Model tags
|
||||
- `modelDescription` — Model description
|
||||
- `notes` — User notes
|
||||
- `favorite` — Favorite flag
|
||||
- `exclude` — Exclude from scanning (user can set `false`→`true`, currently no UI to revert)
|
||||
- `skip_metadata_refresh` — Skip during bulk refresh
|
||||
- `civitai.trainedWords` — Trigger words
|
||||
- `civitai.customImages` — Custom example images
|
||||
- `usage_tips` — Usage recommendations (LoRA only)
|
||||
|
||||
---
|
||||
|
||||
|
||||
## Field Reference by Behavior
|
||||
|
||||
### Required Fields (Must Always Exist)
|
||||
|
||||
- `file_name`
|
||||
- `model_name` (defaults to `file_name` if not provided)
|
||||
- `file_path`
|
||||
- `size`
|
||||
- `modified`
|
||||
- `sha256` (LoRA: always required; Checkpoint: may be empty when `hash_status="pending"`)
|
||||
|
||||
### Optional Fields with Defaults
|
||||
|
||||
| Field | Default |
|
||||
|-------|---------|
|
||||
| `base_model` | `"Unknown"` or `""` |
|
||||
| `preview_nsfw_level` | `0` |
|
||||
| `from_civitai` | `true` |
|
||||
| `civitai` | `{}` |
|
||||
| `tags` | `[]` |
|
||||
| `modelDescription` | `""` |
|
||||
| `notes` | `""` |
|
||||
| `civitai_deleted` | `false` |
|
||||
| `favorite` | `false` |
|
||||
| `exclude` | `false` |
|
||||
| `db_checked` | `false` |
|
||||
| `skip_metadata_refresh` | `false` |
|
||||
| `metadata_source` | `null` |
|
||||
| `last_checked_at` | `0` |
|
||||
| `hash_status` | `"completed"` |
|
||||
| `usage_tips` | `"{}"` (LoRA only) |
|
||||
| `model_type` | `"checkpoint"` or `"embedding"` (not present in LoRA models) |
|
||||
|
||||
---
|
||||
|
||||
## Version History
|
||||
|
||||
| Version | Date | Changes |
|
||||
|---------|------|---------|
|
||||
| 1.0 | 2026-03 | Initial schema documentation |
|
||||
|
||||
---
|
||||
|
||||
## See Also
|
||||
|
||||
- [JSON Schema Definition](../.specs/metadata.schema.json) — Formal JSON Schema for validation
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
635
locales/de.json
635
locales/de.json
File diff suppressed because it is too large
Load Diff
3801
locales/en.json
3801
locales/en.json
File diff suppressed because it is too large
Load Diff
635
locales/es.json
635
locales/es.json
File diff suppressed because it is too large
Load Diff
635
locales/fr.json
635
locales/fr.json
File diff suppressed because it is too large
Load Diff
635
locales/he.json
635
locales/he.json
File diff suppressed because it is too large
Load Diff
635
locales/ja.json
635
locales/ja.json
File diff suppressed because it is too large
Load Diff
635
locales/ko.json
635
locales/ko.json
File diff suppressed because it is too large
Load Diff
635
locales/ru.json
635
locales/ru.json
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
3
package-lock.json
generated
3
package-lock.json
generated
@@ -114,7 +114,6 @@
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
@@ -138,7 +137,6 @@
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
}
|
||||
@@ -1613,7 +1611,6 @@
|
||||
"integrity": "sha512-MyL55p3Ut3cXbeBEG7Hcv0mVM8pp8PBNWxRqchZnSfAiES1v1mRnMeFfaHWIPULpwsYfvO+ZmMZz5tGCnjzDUQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"cssstyle": "^4.0.1",
|
||||
"data-urls": "^5.0.0",
|
||||
|
||||
225
py/agent_cli/__init__.py
Normal file
225
py/agent_cli/__init__.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Agent CLI — thin in-process wrappers around LoRA Manager internal services.
|
||||
|
||||
All functions are simple Python async functions that delegate to the
|
||||
appropriate internal service. They use **relative imports** within the
|
||||
``py`` package, so ``sys.modules`` caching works normally and there is no
|
||||
risk of double import or circular dependencies.
|
||||
|
||||
Usage (in-process, primary)::
|
||||
|
||||
from py.agent_cli import list_base_models, read_metadata
|
||||
|
||||
models = await list_base_models()
|
||||
meta = await read_metadata("/path/to/model.safetensors")
|
||||
|
||||
Usage (subprocess, debugging / external)::
|
||||
|
||||
python -m py.agent_cli base-models list
|
||||
python -m py.agent_cli metadata read /path/to/model.safetensors
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _find_scanner_for_model(
|
||||
model_path: str,
|
||||
) -> tuple[object, object] | tuple[None, None]:
|
||||
"""Find the (scanner, cache_entry) responsible for *model_path*.
|
||||
|
||||
Iterates all known scanner types and returns the first one whose cache
|
||||
contains the given path. Returns ``(None, None)`` when no scanner
|
||||
claims the model.
|
||||
"""
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
normalized = os.path.normpath(model_path)
|
||||
for getter_name in (
|
||||
"get_lora_scanner",
|
||||
"get_checkpoint_scanner",
|
||||
"get_embedding_scanner",
|
||||
):
|
||||
getter = getattr(ServiceRegistry, getter_name, None)
|
||||
if getter is None:
|
||||
continue
|
||||
try:
|
||||
scanner = await getter()
|
||||
if scanner is None:
|
||||
continue
|
||||
cache = await scanner.get_cached_data()
|
||||
for entry in cache.raw_data:
|
||||
if os.path.normpath(entry.get("file_path", "")) == normalized:
|
||||
return scanner, entry
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"Scanner %s check failed for %s: %s",
|
||||
getter_name,
|
||||
model_path,
|
||||
exc,
|
||||
)
|
||||
return None, None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def list_base_models(limit: int = 0) -> List[str]:
|
||||
"""Return deduplicated base model names from all model caches.
|
||||
|
||||
The result is ordered by frequency (most common first). Pass
|
||||
*limit* = 0 (default) for all models.
|
||||
"""
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
counts: Dict[str, int] = {}
|
||||
for getter_name in (
|
||||
"get_lora_scanner",
|
||||
"get_checkpoint_scanner",
|
||||
"get_embedding_scanner",
|
||||
):
|
||||
getter = getattr(ServiceRegistry, getter_name, None)
|
||||
if getter is None:
|
||||
continue
|
||||
try:
|
||||
scanner = await getter()
|
||||
if scanner is None:
|
||||
continue
|
||||
cache = await scanner.get_cached_data()
|
||||
for entry in cache.raw_data:
|
||||
bm = entry.get("base_model")
|
||||
if bm:
|
||||
counts[bm] = counts.get(bm, 0) + 1
|
||||
except Exception as exc:
|
||||
logger.debug("list_base_models scanner %s error: %s", getter_name, exc)
|
||||
|
||||
sorted_names = [name for name, _ in sorted(counts.items(), key=lambda x: -x[1])]
|
||||
if limit > 0:
|
||||
return sorted_names[:limit]
|
||||
return sorted_names
|
||||
|
||||
|
||||
async def read_metadata(model_path: str) -> Dict[str, Any]:
|
||||
"""Load the full metadata payload for *model_path* from disk.
|
||||
|
||||
Returns an empty dict when the metadata file does not exist or cannot
|
||||
be parsed — never raises.
|
||||
"""
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
|
||||
try:
|
||||
return await MetadataManager.load_metadata_payload(model_path) or {}
|
||||
except Exception as exc:
|
||||
logger.warning("read_metadata failed for %s: %s", model_path, exc)
|
||||
return {}
|
||||
|
||||
|
||||
async def apply_metadata_updates(
|
||||
model_path: str,
|
||||
updates: Dict[str, Any],
|
||||
) -> List[str]:
|
||||
"""Merge *updates* into the model's on-disk metadata and persist.
|
||||
|
||||
Returns the list of field names that actually changed.
|
||||
"""
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
|
||||
metadata = await read_metadata(model_path)
|
||||
updated_fields: List[str] = []
|
||||
for key, value in updates.items():
|
||||
old = metadata.get(key)
|
||||
if old != value:
|
||||
metadata[key] = value
|
||||
updated_fields.append(key)
|
||||
if updated_fields:
|
||||
await MetadataManager.save_metadata(model_path, metadata)
|
||||
return updated_fields
|
||||
|
||||
|
||||
async def download_preview(
|
||||
model_path: str,
|
||||
url: str,
|
||||
*,
|
||||
target_width: int = 480,
|
||||
quality: int = 85,
|
||||
) -> bool:
|
||||
"""Download a preview image from *url*, optimise to .webp, and save it.
|
||||
|
||||
The output file is placed alongside the model file with a ``.webp``
|
||||
extension. Returns ``True`` on success.
|
||||
"""
|
||||
from ..services.downloader import get_downloader
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
|
||||
if not url or not url.strip():
|
||||
return False
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(model_path))[0]
|
||||
preview_dir = os.path.dirname(model_path)
|
||||
output_path = os.path.join(preview_dir, base_name + ".webp")
|
||||
|
||||
downloader = await get_downloader()
|
||||
|
||||
# Try in-memory download + optimise first
|
||||
success, content, _headers = await downloader.download_to_memory(
|
||||
url, use_auth=False,
|
||||
)
|
||||
if success and content:
|
||||
try:
|
||||
optimized_data, _ = ExifUtils.optimize_image(
|
||||
image_data=content,
|
||||
target_width=target_width,
|
||||
format="webp",
|
||||
quality=quality,
|
||||
preserve_metadata=False,
|
||||
)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(optimized_data)
|
||||
logger.info("Preview downloaded and optimised for %s", model_path)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning("Preview optimisation failed, saving raw: %s", exc)
|
||||
# Fall through to raw save
|
||||
|
||||
# Fallback: download directly to file
|
||||
try:
|
||||
ok, _ = await downloader.download_file(url, output_path, use_auth=False)
|
||||
if ok:
|
||||
logger.info("Preview downloaded (fallback) for %s", model_path)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning("Preview fallback download failed for %s: %s", model_path, exc)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def refresh_cache(model_path: str) -> bool:
|
||||
"""Invalidate and reload the scanner cache entry for *model_path*.
|
||||
|
||||
Returns ``True`` when the model was found and the cache was refreshed.
|
||||
"""
|
||||
scanner, entry = await _find_scanner_for_model(model_path)
|
||||
if scanner is None:
|
||||
logger.warning("refresh_cache: no scanner found for %s", model_path)
|
||||
return False
|
||||
try:
|
||||
metadata = await read_metadata(model_path)
|
||||
if not metadata:
|
||||
logger.warning("refresh_cache: no metadata for %s", model_path)
|
||||
return False
|
||||
await scanner.update_single_model_cache(model_path, model_path, metadata)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning("refresh_cache failed for %s: %s", model_path, exc)
|
||||
return False
|
||||
118
py/agent_cli/__main__.py
Normal file
118
py/agent_cli/__main__.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Subprocess entry point for AgentCLI (debugging / external use).
|
||||
|
||||
Usage::
|
||||
|
||||
python -m py.agent_cli base-models list [--limit N]
|
||||
python -m py.agent_cli metadata read <path>
|
||||
python -m py.agent_cli metadata update <path> --json '{...}'
|
||||
python -m py.agent_cli preview download <path> --url <url>
|
||||
python -m py.agent_cli cache refresh <path>
|
||||
|
||||
NOTE: This is an **optional** convenience wrapper. The primary consumer of
|
||||
AgentCLI is the :mod:`AgentService` (in-process). This entry point exists
|
||||
for manual debugging and future integration with subprocess-based agent
|
||||
frameworks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(prog="lmcli", description="LoRA Manager Agent CLI")
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
# base-models list
|
||||
base_models = sub.add_parser("base-models", aliases=["bm"])
|
||||
base_models_cmds = base_models.add_subparsers(dest="subcommand", required=True)
|
||||
base_models_list = base_models_cmds.add_parser("list")
|
||||
base_models_list.add_argument(
|
||||
"--limit", type=int, default=0, help="Max number of models (0 = all)"
|
||||
)
|
||||
|
||||
# metadata read
|
||||
meta = sub.add_parser("metadata", aliases=["md"])
|
||||
meta_cmds = meta.add_subparsers(dest="subcommand", required=True)
|
||||
meta_read = meta_cmds.add_parser("read")
|
||||
meta_read.add_argument("path", type=str, help="Model file path")
|
||||
|
||||
# metadata update
|
||||
meta_update = meta_cmds.add_parser("update")
|
||||
meta_update.add_argument("path", type=str, help="Model file path")
|
||||
meta_update.add_argument(
|
||||
"--json",
|
||||
type=str,
|
||||
required=True,
|
||||
help='JSON object of fields to update, e.g. \'{"base_model": "SDXL 1.0"}\'',
|
||||
)
|
||||
|
||||
# preview download
|
||||
prev = sub.add_parser("preview", aliases=["pv"])
|
||||
prev_cmds = prev.add_subparsers(dest="subcommand", required=True)
|
||||
prev_dl = prev_cmds.add_parser("download")
|
||||
prev_dl.add_argument("path", type=str, help="Model file path")
|
||||
prev_dl.add_argument("--url", type=str, required=True, help="Preview image URL")
|
||||
|
||||
# cache refresh
|
||||
cache = sub.add_parser("cache")
|
||||
cache_cmds = cache.add_subparsers(dest="subcommand", required=True)
|
||||
cache_refresh = cache_cmds.add_parser("refresh")
|
||||
cache_refresh.add_argument("path", type=str, help="Model file path")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
async def _run(args: argparse.Namespace) -> Any:
|
||||
from . import ( # lazy import so startup is fast
|
||||
list_base_models,
|
||||
read_metadata,
|
||||
apply_metadata_updates,
|
||||
download_preview,
|
||||
refresh_cache,
|
||||
)
|
||||
|
||||
cmd = args.command
|
||||
sub = args.subcommand
|
||||
|
||||
if cmd in ("base-models", "bm") and sub == "list":
|
||||
return await list_base_models(limit=args.limit)
|
||||
|
||||
if cmd in ("metadata", "md") and sub == "read":
|
||||
return await read_metadata(args.path)
|
||||
|
||||
if cmd in ("metadata", "md") and sub == "update":
|
||||
updates: Dict[str, Any] = json.loads(args.json)
|
||||
return await apply_metadata_updates(args.path, updates)
|
||||
|
||||
if cmd in ("preview", "pv") and sub == "download":
|
||||
return await download_preview(args.path, args.url)
|
||||
|
||||
if cmd == "cache" and sub == "refresh":
|
||||
return await refresh_cache(args.path)
|
||||
|
||||
raise ValueError(f"Unknown command: {cmd} {sub}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = _build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
result = asyncio.run(_run(args))
|
||||
# Always print as JSON so callers can parse reliably
|
||||
if isinstance(result, list):
|
||||
for item in result:
|
||||
print(item)
|
||||
elif isinstance(result, dict):
|
||||
json.dump(result, sys.stdout, ensure_ascii=False, indent=2)
|
||||
print()
|
||||
else:
|
||||
print(json.dumps(result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
430
py/config.py
430
py/config.py
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import platform
|
||||
import posixpath
|
||||
import threading
|
||||
from pathlib import Path
|
||||
import folder_paths # type: ignore
|
||||
@@ -7,6 +8,8 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple
|
||||
import logging
|
||||
import json
|
||||
import urllib.parse
|
||||
import sys as _sys
|
||||
import types as _types
|
||||
import time
|
||||
|
||||
from .utils.cache_paths import CacheType, get_cache_file_path, get_legacy_cache_paths
|
||||
@@ -25,6 +28,67 @@ standalone_mode = (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_root_identity(path: str) -> str:
|
||||
"""Normalize a root path for comparisons across slash styles."""
|
||||
|
||||
normalized = posixpath.normpath(path.strip().replace("\\", "/"))
|
||||
if len(normalized) >= 2 and normalized[1] == ":":
|
||||
return normalized.lower()
|
||||
return normalized
|
||||
|
||||
|
||||
def _resolve_valid_default_root(
|
||||
current: str, primary_paths: List[str], allowed_paths: List[str], name: str
|
||||
) -> str:
|
||||
"""Return a valid default root from the current primary/extra path set."""
|
||||
|
||||
valid_paths = [path for path in primary_paths if isinstance(path, str) and path.strip()]
|
||||
fallback_paths: List[str] = []
|
||||
seen: Set[str] = set()
|
||||
for path in allowed_paths:
|
||||
if not isinstance(path, str):
|
||||
continue
|
||||
stripped = path.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
identity = _normalize_root_identity(stripped)
|
||||
if identity in seen:
|
||||
continue
|
||||
seen.add(identity)
|
||||
fallback_paths.append(stripped)
|
||||
|
||||
allowed = {_normalize_root_identity(path) for path in fallback_paths}
|
||||
|
||||
if current and _normalize_root_identity(current) in allowed:
|
||||
return current
|
||||
|
||||
if not valid_paths:
|
||||
if not fallback_paths:
|
||||
return ""
|
||||
if current:
|
||||
logger.info(
|
||||
"Repaired stale %s from '%s' to '%s' because it is not present in primary or extra roots",
|
||||
name,
|
||||
current,
|
||||
fallback_paths[0],
|
||||
)
|
||||
else:
|
||||
logger.info("Auto-setting %s to '%s'", name, fallback_paths[0])
|
||||
return fallback_paths[0]
|
||||
|
||||
if current:
|
||||
logger.info(
|
||||
"Repaired stale %s from '%s' to '%s' because it is not present in primary or extra roots",
|
||||
name,
|
||||
current,
|
||||
valid_paths[0],
|
||||
)
|
||||
else:
|
||||
logger.info("Auto-setting %s to '%s'", name, valid_paths[0])
|
||||
|
||||
return valid_paths[0]
|
||||
|
||||
|
||||
def _normalize_folder_paths_for_comparison(
|
||||
folder_paths: Mapping[str, Iterable[str]],
|
||||
) -> Dict[str, Set[str]]:
|
||||
@@ -109,6 +173,13 @@ class Config:
|
||||
self.extra_checkpoints_roots: List[str] = []
|
||||
self.extra_unet_roots: List[str] = []
|
||||
self.extra_embeddings_roots: List[str] = []
|
||||
self.recipes_path: str = ""
|
||||
|
||||
# Load extra folder paths from active library settings before symlink scan
|
||||
# so both primary and extra paths are discovered in a single pass.
|
||||
if not standalone_mode:
|
||||
self._load_extra_paths_from_settings()
|
||||
|
||||
# Scan symbolic links during initialization
|
||||
self._initialize_symlink_mappings()
|
||||
|
||||
@@ -116,6 +187,96 @@ class Config:
|
||||
# Save the paths to settings.json when running in ComfyUI mode
|
||||
self.save_folder_paths_to_settings()
|
||||
|
||||
def _load_extra_paths_from_settings(self) -> None:
|
||||
"""Read extra folder paths from the active library and apply them.
|
||||
|
||||
Called during ``Config.__init__`` before the symlink scan so both primary and
|
||||
extra paths are discovered in a single pass. Mirrors the extra-path
|
||||
portion of ``_apply_library_paths`` without replacing the primary roots
|
||||
that were already resolved from ComfyUI's ``folder_paths``.
|
||||
"""
|
||||
try:
|
||||
from .services.settings_manager import get_settings_manager
|
||||
|
||||
settings_manager = get_settings_manager()
|
||||
library_name = settings_manager.get_active_library_name()
|
||||
libraries = settings_manager.get_libraries()
|
||||
|
||||
if not library_name or library_name not in libraries:
|
||||
return
|
||||
|
||||
library_config = libraries[library_name]
|
||||
if not isinstance(library_config, dict):
|
||||
return
|
||||
|
||||
extra_folder_paths = library_config.get("extra_folder_paths")
|
||||
if not isinstance(extra_folder_paths, dict):
|
||||
return
|
||||
|
||||
extra_lora = extra_folder_paths.get("loras", []) or []
|
||||
extra_checkpoint = extra_folder_paths.get("checkpoints", []) or []
|
||||
extra_unet = extra_folder_paths.get("unet", []) or []
|
||||
extra_embedding = extra_folder_paths.get("embeddings", []) or []
|
||||
|
||||
if not any([extra_lora, extra_checkpoint, extra_unet, extra_embedding]):
|
||||
return
|
||||
|
||||
filtered_extra_lora = self._filter_overlapping_extra_lora_paths(
|
||||
self.loras_roots, extra_lora
|
||||
)
|
||||
self.extra_loras_roots = self._prepare_lora_paths(filtered_extra_lora)
|
||||
(
|
||||
_,
|
||||
self.extra_checkpoints_roots,
|
||||
self.extra_unet_roots,
|
||||
) = self._prepare_checkpoint_paths(extra_checkpoint, extra_unet)
|
||||
self.extra_embeddings_roots = self._prepare_embedding_paths(
|
||||
extra_embedding
|
||||
)
|
||||
|
||||
recipes_path = library_config.get("recipes_path", "")
|
||||
if isinstance(recipes_path, str) and recipes_path:
|
||||
self.recipes_path = recipes_path
|
||||
|
||||
if self.extra_loras_roots:
|
||||
logger.info(
|
||||
"Found extra LoRA roots:"
|
||||
+ "\n - "
|
||||
+ "\n - ".join(self.extra_loras_roots)
|
||||
)
|
||||
if self.extra_checkpoints_roots:
|
||||
logger.info(
|
||||
"Found extra checkpoint roots:"
|
||||
+ "\n - "
|
||||
+ "\n - ".join(self.extra_checkpoints_roots)
|
||||
)
|
||||
if self.extra_unet_roots:
|
||||
logger.info(
|
||||
"Found extra diffusion model roots:"
|
||||
+ "\n - "
|
||||
+ "\n - ".join(self.extra_unet_roots)
|
||||
)
|
||||
if self.extra_embeddings_roots:
|
||||
logger.info(
|
||||
"Found extra embedding roots:"
|
||||
+ "\n - "
|
||||
+ "\n - ".join(self.extra_embeddings_roots)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Applied library settings for '%s' with extra paths: loras=%s, "
|
||||
"checkpoints=%s, embeddings=%s",
|
||||
library_name,
|
||||
extra_lora,
|
||||
extra_checkpoint,
|
||||
extra_embedding,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"Could not load extra paths from library settings: %s", exc
|
||||
)
|
||||
|
||||
def save_folder_paths_to_settings(self):
|
||||
"""Persist ComfyUI-derived folder paths to the multi-library settings."""
|
||||
try:
|
||||
@@ -197,44 +358,79 @@ class Config:
|
||||
"Failed to rename legacy 'default' library: %s", rename_error
|
||||
)
|
||||
|
||||
default_lora_root = comfy_library.get("default_lora_root", "")
|
||||
if not default_lora_root and len(self.loras_roots) == 1:
|
||||
default_lora_root = self.loras_roots[0]
|
||||
default_lora_root = _resolve_valid_default_root(
|
||||
comfy_library.get("default_lora_root", ""),
|
||||
list(self.loras_roots or []),
|
||||
list(self.loras_roots or [])
|
||||
+ list(comfy_library.get("extra_folder_paths", {}).get("loras", []) or []),
|
||||
"default_lora_root",
|
||||
)
|
||||
|
||||
default_checkpoint_root = comfy_library.get("default_checkpoint_root", "")
|
||||
if (
|
||||
not default_checkpoint_root
|
||||
and self.checkpoints_roots
|
||||
and len(self.checkpoints_roots) == 1
|
||||
):
|
||||
default_checkpoint_root = self.checkpoints_roots[0]
|
||||
default_checkpoint_root = _resolve_valid_default_root(
|
||||
comfy_library.get("default_checkpoint_root", ""),
|
||||
list(self.checkpoints_roots or []),
|
||||
list(self.checkpoints_roots or [])
|
||||
+ list(comfy_library.get("extra_folder_paths", {}).get("checkpoints", []) or []),
|
||||
"default_checkpoint_root",
|
||||
)
|
||||
|
||||
default_embedding_root = comfy_library.get("default_embedding_root", "")
|
||||
if (
|
||||
not default_embedding_root
|
||||
and self.embeddings_roots
|
||||
and len(self.embeddings_roots) == 1
|
||||
):
|
||||
default_embedding_root = self.embeddings_roots[0]
|
||||
default_embedding_root = _resolve_valid_default_root(
|
||||
comfy_library.get("default_embedding_root", ""),
|
||||
list(self.embeddings_roots or []),
|
||||
list(self.embeddings_roots or [])
|
||||
+ list(comfy_library.get("extra_folder_paths", {}).get("embeddings", []) or []),
|
||||
"default_embedding_root",
|
||||
)
|
||||
|
||||
metadata = dict(comfy_library.get("metadata", {}))
|
||||
metadata.setdefault("display_name", "ComfyUI")
|
||||
metadata["source"] = "comfyui"
|
||||
extra_folder_paths = {}
|
||||
if isinstance(comfy_library, Mapping):
|
||||
existing_extra_paths = comfy_library.get("extra_folder_paths", {})
|
||||
if isinstance(existing_extra_paths, Mapping):
|
||||
extra_folder_paths = {
|
||||
key: list(value) if isinstance(value, list) else []
|
||||
for key, value in existing_extra_paths.items()
|
||||
}
|
||||
|
||||
active_library_name = settings_service.get_active_library_name()
|
||||
should_activate = (
|
||||
active_library_name == "comfyui"
|
||||
or self._should_activate_comfy_library(libraries, libraries_changed)
|
||||
)
|
||||
|
||||
settings_service.upsert_library(
|
||||
"comfyui",
|
||||
folder_paths=target_folder_paths,
|
||||
extra_folder_paths=extra_folder_paths,
|
||||
default_lora_root=default_lora_root,
|
||||
default_checkpoint_root=default_checkpoint_root,
|
||||
default_embedding_root=default_embedding_root,
|
||||
metadata=metadata,
|
||||
activate=True,
|
||||
activate=should_activate,
|
||||
)
|
||||
|
||||
logger.info("Updated 'comfyui' library with current folder paths")
|
||||
if should_activate:
|
||||
logger.info("Updated 'comfyui' library with current folder paths")
|
||||
else:
|
||||
logger.info(
|
||||
"Updated 'comfyui' library with current folder paths without activating it"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save folder paths: {e}")
|
||||
|
||||
def _should_activate_comfy_library(
|
||||
self, libraries: Mapping[str, Any], libraries_changed: bool
|
||||
) -> bool:
|
||||
"""Return whether startup sync should make the ComfyUI library active."""
|
||||
|
||||
if libraries_changed:
|
||||
return True
|
||||
if not libraries:
|
||||
return True
|
||||
return "comfyui" in libraries and len(libraries) == 1
|
||||
|
||||
def _is_link(self, path: str) -> bool:
|
||||
try:
|
||||
if os.path.islink(path):
|
||||
@@ -629,6 +825,8 @@ class Config:
|
||||
preview_roots.update(self._expand_preview_root(root))
|
||||
for root in self.extra_embeddings_roots or []:
|
||||
preview_roots.update(self._expand_preview_root(root))
|
||||
if self.recipes_path:
|
||||
preview_roots.update(self._expand_preview_root(self.recipes_path))
|
||||
|
||||
for target, link in self._path_mappings.items():
|
||||
preview_roots.update(self._expand_preview_root(target))
|
||||
@@ -705,9 +903,131 @@ class Config:
|
||||
|
||||
return unique_paths
|
||||
|
||||
@staticmethod
|
||||
def _normalize_path_for_comparison(
|
||||
path: str, *, resolve_realpath: bool = False
|
||||
) -> str:
|
||||
"""Normalize a path for equality checks across platforms."""
|
||||
candidate = os.path.realpath(path) if resolve_realpath else path
|
||||
return os.path.normcase(os.path.normpath(candidate)).replace(os.sep, "/")
|
||||
|
||||
def _filter_overlapping_extra_lora_paths(
|
||||
self,
|
||||
primary_paths: Iterable[str],
|
||||
extra_paths: Iterable[str],
|
||||
) -> List[str]:
|
||||
"""Drop extra LoRA paths that resolve to the same physical location as primary roots."""
|
||||
|
||||
primary_map = {
|
||||
self._normalize_path_for_comparison(path, resolve_realpath=True): path
|
||||
for path in primary_paths
|
||||
if isinstance(path, str) and path.strip() and os.path.exists(path)
|
||||
}
|
||||
primary_symlink_map = self._collect_first_level_symlink_targets(primary_paths)
|
||||
filtered: List[str] = []
|
||||
|
||||
for original_path in extra_paths:
|
||||
if not isinstance(original_path, str):
|
||||
continue
|
||||
|
||||
stripped = original_path.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
if not os.path.exists(stripped):
|
||||
continue
|
||||
|
||||
real_path = self._normalize_path_for_comparison(
|
||||
stripped,
|
||||
resolve_realpath=True,
|
||||
)
|
||||
normalized_path = os.path.normpath(stripped).replace(os.sep, "/")
|
||||
primary_path = primary_map.get(real_path)
|
||||
if primary_path:
|
||||
# Config loading should stay tolerant of existing invalid state and warn.
|
||||
logger.warning(
|
||||
"Detected the same LoRA folder in both ComfyUI model paths and "
|
||||
"LoRA Manager Extra Folder Paths. This can cause duplicate items or "
|
||||
"other unexpected behavior, and it usually means the path setup is "
|
||||
"not doing what you intended. LoRA Manager will keep the ComfyUI "
|
||||
"path and ignore this Extra Folder Paths entry: '%s'. Please review "
|
||||
"your path settings and remove the duplicate entry.",
|
||||
normalized_path,
|
||||
)
|
||||
continue
|
||||
|
||||
symlink_path = primary_symlink_map.get(real_path)
|
||||
if symlink_path:
|
||||
# Config loading should stay tolerant of existing invalid state and warn.
|
||||
logger.warning(
|
||||
"Detected the same LoRA folder in both ComfyUI model paths and "
|
||||
"LoRA Manager Extra Folder Paths. This can cause duplicate items or "
|
||||
"other unexpected behavior, and it usually means the path setup is "
|
||||
"not doing what you intended. LoRA Manager will keep the ComfyUI "
|
||||
"path and ignore this Extra Folder Paths entry: '%s'. Please review "
|
||||
"your path settings and remove the duplicate entry.",
|
||||
normalized_path,
|
||||
)
|
||||
continue
|
||||
|
||||
filtered.append(stripped)
|
||||
|
||||
return filtered
|
||||
|
||||
def _collect_first_level_symlink_targets(
|
||||
self, roots: Iterable[str]
|
||||
) -> Dict[str, str]:
|
||||
"""Return real-path -> link-path mappings for first-level symlinks under the given roots."""
|
||||
|
||||
targets: Dict[str, str] = {}
|
||||
for root in roots:
|
||||
if not isinstance(root, str):
|
||||
continue
|
||||
stripped_root = root.strip()
|
||||
if not stripped_root or not os.path.isdir(stripped_root):
|
||||
continue
|
||||
|
||||
try:
|
||||
with os.scandir(stripped_root) as iterator:
|
||||
for entry in iterator:
|
||||
try:
|
||||
if not self._entry_is_symlink(entry):
|
||||
continue
|
||||
target_path = os.path.realpath(entry.path)
|
||||
if not os.path.isdir(target_path):
|
||||
continue
|
||||
|
||||
normalized_target = self._normalize_path_for_comparison(
|
||||
target_path,
|
||||
resolve_realpath=True,
|
||||
)
|
||||
normalized_link = os.path.normpath(entry.path).replace(
|
||||
os.sep, "/"
|
||||
)
|
||||
targets.setdefault(normalized_target, normalized_link)
|
||||
except Exception as inner_exc:
|
||||
logger.debug(
|
||||
"Error collecting LoRA symlink target for %s: %s",
|
||||
entry.path,
|
||||
inner_exc,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"Error scanning first-level LoRA symlinks in %s: %s",
|
||||
stripped_root,
|
||||
exc,
|
||||
)
|
||||
|
||||
return targets
|
||||
|
||||
def _prepare_checkpoint_paths(
|
||||
self, checkpoint_paths: Iterable[str], unet_paths: Iterable[str]
|
||||
) -> List[str]:
|
||||
) -> Tuple[List[str], List[str], List[str]]:
|
||||
"""Prepare checkpoint paths and return (all_roots, checkpoint_roots, unet_roots).
|
||||
|
||||
Returns:
|
||||
Tuple of (all_unique_paths, checkpoint_only_paths, unet_only_paths)
|
||||
This method does NOT modify instance variables - callers must set them.
|
||||
"""
|
||||
checkpoint_map = self._dedupe_existing_paths(checkpoint_paths)
|
||||
unet_map = self._dedupe_existing_paths(unet_paths)
|
||||
|
||||
@@ -737,8 +1057,8 @@ class Config:
|
||||
|
||||
checkpoint_values = set(checkpoint_map.values())
|
||||
unet_values = set(unet_map.values())
|
||||
self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_values]
|
||||
self.unet_roots = [p for p in unique_paths if p in unet_values]
|
||||
checkpoint_roots = [p for p in unique_paths if p in checkpoint_values]
|
||||
unet_roots = [p for p in unique_paths if p in unet_values]
|
||||
|
||||
for original_path in unique_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(
|
||||
@@ -747,7 +1067,7 @@ class Config:
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return unique_paths
|
||||
return unique_paths, checkpoint_roots, unet_roots
|
||||
|
||||
def _prepare_embedding_paths(self, raw_paths: Iterable[str]) -> List[str]:
|
||||
path_map = self._dedupe_existing_paths(raw_paths)
|
||||
@@ -766,9 +1086,11 @@ class Config:
|
||||
self,
|
||||
folder_paths: Mapping[str, Iterable[str]],
|
||||
extra_folder_paths: Optional[Mapping[str, Iterable[str]]] = None,
|
||||
recipes_path: str = "",
|
||||
) -> None:
|
||||
self._path_mappings.clear()
|
||||
self._preview_root_paths = set()
|
||||
self.recipes_path = recipes_path if isinstance(recipes_path, str) else ""
|
||||
|
||||
lora_paths = folder_paths.get("loras", []) or []
|
||||
checkpoint_paths = folder_paths.get("checkpoints", []) or []
|
||||
@@ -776,9 +1098,11 @@ class Config:
|
||||
embedding_paths = folder_paths.get("embeddings", []) or []
|
||||
|
||||
self.loras_roots = self._prepare_lora_paths(lora_paths)
|
||||
self.base_models_roots = self._prepare_checkpoint_paths(
|
||||
checkpoint_paths, unet_paths
|
||||
)
|
||||
(
|
||||
self.base_models_roots,
|
||||
self.checkpoints_roots,
|
||||
self.unet_roots,
|
||||
) = self._prepare_checkpoint_paths(checkpoint_paths, unet_paths)
|
||||
self.embeddings_roots = self._prepare_embedding_paths(embedding_paths)
|
||||
|
||||
# Process extra paths (only for LoRA Manager, not shared with ComfyUI)
|
||||
@@ -788,19 +1112,16 @@ class Config:
|
||||
extra_unet_paths = extra_paths.get("unet", []) or []
|
||||
extra_embedding_paths = extra_paths.get("embeddings", []) or []
|
||||
|
||||
self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths)
|
||||
# Save main paths before processing extra paths ( _prepare_checkpoint_paths overwrites them)
|
||||
saved_checkpoints_roots = self.checkpoints_roots
|
||||
saved_unet_roots = self.unet_roots
|
||||
self.extra_checkpoints_roots = self._prepare_checkpoint_paths(
|
||||
extra_checkpoint_paths, extra_unet_paths
|
||||
filtered_extra_lora_paths = self._filter_overlapping_extra_lora_paths(
|
||||
self.loras_roots,
|
||||
extra_lora_paths,
|
||||
)
|
||||
self.extra_unet_roots = (
|
||||
self.unet_roots if self.unet_roots is not None else []
|
||||
) # unet_roots was set by _prepare_checkpoint_paths
|
||||
# Restore main paths
|
||||
self.checkpoints_roots = saved_checkpoints_roots
|
||||
self.unet_roots = saved_unet_roots
|
||||
self.extra_loras_roots = self._prepare_lora_paths(filtered_extra_lora_paths)
|
||||
(
|
||||
_,
|
||||
self.extra_checkpoints_roots,
|
||||
self.extra_unet_roots,
|
||||
) = self._prepare_checkpoint_paths(extra_checkpoint_paths, extra_unet_paths)
|
||||
self.extra_embeddings_roots = self._prepare_embedding_paths(
|
||||
extra_embedding_paths
|
||||
)
|
||||
@@ -857,9 +1178,11 @@ class Config:
|
||||
try:
|
||||
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
raw_unet_paths = folder_paths.get_folder_paths("unet")
|
||||
unique_paths = self._prepare_checkpoint_paths(
|
||||
raw_checkpoint_paths, raw_unet_paths
|
||||
)
|
||||
(
|
||||
unique_paths,
|
||||
self.checkpoints_roots,
|
||||
self.unet_roots,
|
||||
) = self._prepare_checkpoint_paths(raw_checkpoint_paths, raw_unet_paths)
|
||||
|
||||
logger.info(
|
||||
"Found checkpoint roots:"
|
||||
@@ -1023,7 +1346,12 @@ class Config:
|
||||
if not isinstance(extra_folder_paths, Mapping):
|
||||
extra_folder_paths = None
|
||||
|
||||
self._apply_library_paths(folder_paths, extra_folder_paths)
|
||||
recipes_path = (
|
||||
str(library_config.get("recipes_path", ""))
|
||||
if isinstance(library_config, Mapping)
|
||||
else ""
|
||||
)
|
||||
self._apply_library_paths(folder_paths, extra_folder_paths, recipes_path)
|
||||
|
||||
logger.info(
|
||||
"Applied library settings with %d lora roots (%d extra), %d checkpoint roots (%d extra), and %d embedding roots (%d extra)",
|
||||
@@ -1054,4 +1382,20 @@ class Config:
|
||||
|
||||
|
||||
# Global config instance
|
||||
config = Config()
|
||||
# NOTE: Guard against re-import. When ServiceRegistry.get_lora_scanner() triggers
|
||||
# a fresh import of lora_scanner → config, we must NOT re-execute Config.__init__()
|
||||
# (which re-scans all roots, re-registers libraries, etc.).
|
||||
#
|
||||
# Strategy: store the config instance in a dedicated sentinel module
|
||||
# ('_lm_config_cache') that is NEVER removed from sys.modules (its key does
|
||||
# NOT start with 'py.'), so it survives re-imports of py.* modules.
|
||||
_CONFIG_SENTINEL = "_lm_config_cache"
|
||||
if _CONFIG_SENTINEL in _sys.modules:
|
||||
# Re-import: reuse the existing singleton from the sentinel.
|
||||
config: Config = _sys.modules[_CONFIG_SENTINEL].config # type: ignore[valid-type]
|
||||
else:
|
||||
config: Config = Config()
|
||||
# Register the sentinel so re-imports of py.config find us.
|
||||
_sentinel_mod = _types.ModuleType(_CONFIG_SENTINEL)
|
||||
_sentinel_mod.config = config
|
||||
_sys.modules[_CONFIG_SENTINEL] = _sentinel_mod
|
||||
|
||||
@@ -33,6 +33,7 @@ from .utils.example_images_migration import ExampleImagesMigration
|
||||
from .services.websocket_manager import ws_manager
|
||||
from .services.example_images_cleanup_service import ExampleImagesCleanupService
|
||||
from .middleware.csp_middleware import relax_csp_for_remote_media
|
||||
from .middleware.error_middleware import api_json_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -76,6 +77,11 @@ class LoraManager:
|
||||
"""Initialize and register all routes using the new refactored architecture"""
|
||||
app = PromptServer.instance.app
|
||||
|
||||
# Register JSON error middleware for /api/* routes as the outermost
|
||||
# middleware so it catches errors from all other middlewares.
|
||||
if api_json_error not in app.middlewares:
|
||||
app.middlewares.insert(0, api_json_error)
|
||||
|
||||
if relax_csp_for_remote_media not in app.middlewares:
|
||||
# Ensure CSP relaxer executes after ComfyUI's block_external_middleware so it can
|
||||
# see and extend the restrictive header instead of being overwritten by it.
|
||||
@@ -184,45 +190,17 @@ class LoraManager:
|
||||
async def _initialize_services(cls):
|
||||
"""Initialize all services using the ServiceRegistry"""
|
||||
try:
|
||||
# Apply library settings to load extra folder paths before scanning
|
||||
# Only apply if extra paths haven't been loaded yet (preserves test mocks)
|
||||
try:
|
||||
from .services.settings_manager import get_settings_manager
|
||||
|
||||
settings_manager = get_settings_manager()
|
||||
library_name = settings_manager.get_active_library_name()
|
||||
libraries = settings_manager.get_libraries()
|
||||
if library_name and library_name in libraries:
|
||||
library_config = libraries[library_name]
|
||||
# Only apply settings if extra paths are not already configured
|
||||
# This preserves values set by tests via monkeypatch
|
||||
extra_paths = library_config.get("extra_folder_paths", {})
|
||||
has_extra_paths = (
|
||||
config.extra_loras_roots
|
||||
or config.extra_checkpoints_roots
|
||||
or config.extra_unet_roots
|
||||
or config.extra_embeddings_roots
|
||||
)
|
||||
if not has_extra_paths and any(extra_paths.values()):
|
||||
config.apply_library_settings(library_config)
|
||||
logger.info(
|
||||
"Applied library settings for '%s' with extra paths: loras=%s, checkpoints=%s, embeddings=%s",
|
||||
library_name,
|
||||
extra_paths.get("loras", []),
|
||||
extra_paths.get("checkpoints", []),
|
||||
extra_paths.get("embeddings", []),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to apply library settings during initialization: %s", exc
|
||||
)
|
||||
|
||||
# Initialize CivitaiClient first to ensure it's ready for other services
|
||||
await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Register DownloadManager with ServiceRegistry
|
||||
await ServiceRegistry.get_download_manager()
|
||||
|
||||
# Initialize DownloadQueueService for persistent queue/history
|
||||
await ServiceRegistry.get_download_queue_service()
|
||||
|
||||
await ServiceRegistry.get_backup_service()
|
||||
|
||||
from .services.metadata_service import initialize_metadata_providers
|
||||
|
||||
await initialize_metadata_providers()
|
||||
@@ -458,5 +436,21 @@ class LoraManager:
|
||||
try:
|
||||
logger.info("LoRA Manager: Cleaning up services")
|
||||
|
||||
# Cancel any in-flight scanner initialization tasks so thread-pool
|
||||
# workers (e.g. _initialize_cache_sync) can break out of their loops
|
||||
# when the server shuts down (e.g. Ctrl+C on WSL).
|
||||
for name in ("lora_scanner", "checkpoint_scanner", "embedding_scanner"):
|
||||
scanner = ServiceRegistry.get_service_sync(name)
|
||||
if scanner is not None and hasattr(scanner, "cancel_task"):
|
||||
scanner.cancel_task()
|
||||
logger.debug("LoRA Manager: Cancelled %s", name)
|
||||
|
||||
# Close shared aiohttp sessions to avoid "Unclosed client session" warnings
|
||||
try:
|
||||
from py.routes.handlers.hf_handlers import close_hf_api_session
|
||||
await close_hf_api_session()
|
||||
except Exception as exc:
|
||||
logger.debug("Error closing HF API session: %s", exc)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during cleanup: {e}", exc_info=True)
|
||||
|
||||
@@ -5,9 +5,10 @@ MODELS = "models"
|
||||
PROMPTS = "prompts"
|
||||
SAMPLING = "sampling"
|
||||
LORAS = "loras"
|
||||
EMBEDDINGS = "embeddings"
|
||||
SIZE = "size"
|
||||
IMAGES = "images"
|
||||
IS_SAMPLER = "is_sampler" # New constant to mark sampler nodes
|
||||
|
||||
# Complete list of categories to track
|
||||
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES]
|
||||
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, EMBEDDINGS, SIZE, IMAGES]
|
||||
|
||||
@@ -148,10 +148,13 @@ class MetadataHook:
|
||||
"""Install hooks for asynchronous execution model"""
|
||||
# Store the original _async_map_node_over_list function
|
||||
original_map_node_over_list = getattr(execution, map_node_func_name)
|
||||
|
||||
# Wrapped async function, compatible with both stable and nightly
|
||||
async def async_map_node_over_list_with_metadata(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, *args, **kwargs):
|
||||
hidden_inputs = kwargs.get('hidden_inputs', None)
|
||||
|
||||
# Wrapped async function - signature must exactly match _async_map_node_over_list
|
||||
async def async_map_node_over_list_with_metadata(
|
||||
prompt_id, unique_id, obj, input_data_all, func,
|
||||
allow_interrupt=False, execution_block_cb=None,
|
||||
pre_execute_cb=None, v3_data=None
|
||||
):
|
||||
# Only collect metadata when calling the main function of nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
@@ -163,13 +166,13 @@ class MetadataHook:
|
||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||
except Exception as e:
|
||||
logger.error(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
|
||||
# Call original function with all args/kwargs
|
||||
|
||||
# Call original function with exact parameters
|
||||
results = await original_map_node_over_list(
|
||||
prompt_id, unique_id, obj, input_data_all, func,
|
||||
allow_interrupt, execution_block_cb, pre_execute_cb, *args, **kwargs
|
||||
allow_interrupt, execution_block_cb, pre_execute_cb, v3_data=v3_data
|
||||
)
|
||||
|
||||
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
registry = MetadataRegistry()
|
||||
@@ -180,28 +183,28 @@ class MetadataHook:
|
||||
registry.update_node_execution(node_id, class_type, results)
|
||||
except Exception as e:
|
||||
logger.error(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# Also hook the execute function to track the current prompt_id
|
||||
original_execute = execution.execute
|
||||
|
||||
|
||||
async def async_execute_with_prompt_tracking(*args, **kwargs):
|
||||
if len(args) >= 7: # Check if we have enough arguments
|
||||
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||
registry = MetadataRegistry()
|
||||
|
||||
|
||||
# Start collection if this is a new prompt
|
||||
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||
registry.start_collection(prompt_id)
|
||||
|
||||
|
||||
# Store the dynprompt reference for node lookups
|
||||
if hasattr(prompt, 'original_prompt'):
|
||||
registry.set_current_prompt(prompt)
|
||||
|
||||
|
||||
# Execute the original function
|
||||
return await original_execute(*args, **kwargs)
|
||||
|
||||
|
||||
# Replace the functions with async versions
|
||||
setattr(execution, map_node_func_name, async_map_node_over_list_with_metadata)
|
||||
execution.execute = async_execute_with_prompt_tracking
|
||||
|
||||
@@ -352,50 +352,101 @@ class MetadataProcessor:
|
||||
|
||||
# Check if we have stored conditioning objects for this sampler
|
||||
if sampler_id in metadata.get(PROMPTS, {}) and (
|
||||
"pos_conditioning" in metadata[PROMPTS][sampler_id] or
|
||||
"neg_conditioning" in metadata[PROMPTS][sampler_id]):
|
||||
|
||||
"pos_conditioning" in metadata[PROMPTS][sampler_id] or
|
||||
"neg_conditioning" in metadata[PROMPTS][sampler_id]
|
||||
):
|
||||
pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning")
|
||||
neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning")
|
||||
|
||||
# Helper function to recursively find prompt text for a conditioning object
|
||||
def find_prompt_text_for_conditioning(conditioning_obj, is_positive=True):
|
||||
|
||||
def extend_unique(target, values):
|
||||
for value in values:
|
||||
if value and value not in target:
|
||||
target.append(value)
|
||||
|
||||
# Helper function to recursively find prompt texts for a conditioning object.
|
||||
# Transform nodes can map one output conditioning to multiple source conditionings.
|
||||
def find_prompt_texts_for_conditioning(
|
||||
conditioning_obj, is_positive=True, visited=None
|
||||
):
|
||||
if conditioning_obj is None:
|
||||
return ""
|
||||
|
||||
return []
|
||||
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
conditioning_id = id(conditioning_obj)
|
||||
if conditioning_id in visited:
|
||||
return []
|
||||
visited.add(conditioning_id)
|
||||
|
||||
prompt_texts = []
|
||||
|
||||
# Try to match conditioning objects with those stored by extractors
|
||||
for prompt_node_id, prompt_data in metadata[PROMPTS].items():
|
||||
# For nodes with single conditioning output
|
||||
if "conditioning" in prompt_data:
|
||||
if id(prompt_data["conditioning"]) == id(conditioning_obj):
|
||||
return prompt_data.get("text", "")
|
||||
|
||||
# For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader)
|
||||
if is_positive and "positive_encoded" in prompt_data:
|
||||
if id(prompt_data["positive_encoded"]) == id(conditioning_obj):
|
||||
if "positive_text" in prompt_data:
|
||||
return prompt_data["positive_text"]
|
||||
else:
|
||||
orig_conditioning = prompt_data.get("orig_pos_cond", None)
|
||||
if orig_conditioning is not None:
|
||||
# Recursively find the prompt text for the original conditioning
|
||||
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=True)
|
||||
|
||||
if not is_positive and "negative_encoded" in prompt_data:
|
||||
if id(prompt_data["negative_encoded"]) == id(conditioning_obj):
|
||||
if "negative_text" in prompt_data:
|
||||
return prompt_data["negative_text"]
|
||||
else:
|
||||
orig_conditioning = prompt_data.get("orig_neg_cond", None)
|
||||
if orig_conditioning is not None:
|
||||
# Recursively find the prompt text for the original conditioning
|
||||
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=False)
|
||||
|
||||
return ""
|
||||
|
||||
if not isinstance(prompt_data, dict):
|
||||
continue
|
||||
|
||||
# For CLIP text nodes with a single conditioning output.
|
||||
if id(prompt_data.get("conditioning")) == conditioning_id:
|
||||
text = prompt_data.get("text", "")
|
||||
if text:
|
||||
extend_unique(prompt_texts, [text])
|
||||
|
||||
# Generic provenance for passthrough/transform/combine nodes.
|
||||
for source in prompt_data.get("conditioning_sources", []):
|
||||
if id(source.get("output")) != conditioning_id:
|
||||
continue
|
||||
for input_conditioning in source.get("inputs", []):
|
||||
extend_unique(
|
||||
prompt_texts,
|
||||
find_prompt_texts_for_conditioning(
|
||||
input_conditioning, is_positive, visited
|
||||
),
|
||||
)
|
||||
|
||||
# For nodes with separate pos_conditioning and neg_conditioning outputs
|
||||
# like TSC_EfficientLoader and existing ControlNet-style metadata.
|
||||
if (
|
||||
is_positive
|
||||
and id(prompt_data.get("positive_encoded")) == conditioning_id
|
||||
):
|
||||
if prompt_data.get("positive_text"):
|
||||
extend_unique(prompt_texts, [prompt_data["positive_text"]])
|
||||
else:
|
||||
extend_unique(
|
||||
prompt_texts,
|
||||
find_prompt_texts_for_conditioning(
|
||||
prompt_data.get("orig_pos_cond"),
|
||||
is_positive=True,
|
||||
visited=visited,
|
||||
),
|
||||
)
|
||||
|
||||
if (
|
||||
not is_positive
|
||||
and id(prompt_data.get("negative_encoded")) == conditioning_id
|
||||
):
|
||||
if prompt_data.get("negative_text"):
|
||||
extend_unique(prompt_texts, [prompt_data["negative_text"]])
|
||||
else:
|
||||
extend_unique(
|
||||
prompt_texts,
|
||||
find_prompt_texts_for_conditioning(
|
||||
prompt_data.get("orig_neg_cond"),
|
||||
is_positive=False,
|
||||
visited=visited,
|
||||
),
|
||||
)
|
||||
|
||||
return prompt_texts
|
||||
|
||||
# Find prompt texts using the helper function
|
||||
result["prompt"] = find_prompt_text_for_conditioning(pos_conditioning, is_positive=True)
|
||||
result["negative_prompt"] = find_prompt_text_for_conditioning(neg_conditioning, is_positive=False)
|
||||
result["prompt"] = ", ".join(
|
||||
find_prompt_texts_for_conditioning(pos_conditioning, is_positive=True)
|
||||
)
|
||||
result["negative_prompt"] = ", ".join(
|
||||
find_prompt_texts_for_conditioning(neg_conditioning, is_positive=False)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -509,8 +560,14 @@ class MetadataProcessor:
|
||||
|
||||
params["loras"] = " ".join(lora_parts)
|
||||
|
||||
# Set default clip_skip value
|
||||
params["clip_skip"] = "1" # Common default
|
||||
# Extract clip_skip from any SAMPLING node that provides it
|
||||
for sampler_info in metadata.get(SAMPLING, {}).values():
|
||||
clip_skip = sampler_info.get("parameters", {}).get("clip_skip")
|
||||
if clip_skip is not None:
|
||||
params["clip_skip"] = clip_skip
|
||||
break
|
||||
if params["clip_skip"] is None:
|
||||
params["clip_skip"] = "1"
|
||||
|
||||
return params
|
||||
|
||||
@@ -595,6 +652,15 @@ class MetadataProcessor:
|
||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||
else:
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
|
||||
# Generic guider nodes often expose separate positive/negative inputs.
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", max_depth=10)
|
||||
if not positive_node_id:
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
|
||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", max_depth=10)
|
||||
if not negative_node_id:
|
||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
|
||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES, IS_SAMPLER
|
||||
|
||||
@@ -142,6 +144,118 @@ class TSCCheckpointLoaderExtractor(NodeMetadataExtractor):
|
||||
metadata[PROMPTS][node_id]["positive_encoded"] = positive_conditioning
|
||||
metadata[PROMPTS][node_id]["negative_encoded"] = negative_conditioning
|
||||
|
||||
|
||||
class EasyComfyLoaderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
if "ckpt_name" in inputs:
|
||||
_store_checkpoint_metadata(metadata, node_id, inputs["ckpt_name"])
|
||||
|
||||
# Only extract from optional_lora_stack — skip the single lora_name to
|
||||
# avoid double-counting LoRAs that come through the LORA_STACK path.
|
||||
active_loras = []
|
||||
optional_lora_stack = inputs.get("optional_lora_stack")
|
||||
if optional_lora_stack is not None and isinstance(optional_lora_stack, (list, tuple)):
|
||||
for item in optional_lora_stack:
|
||||
if isinstance(item, (list, tuple)) and len(item) >= 2:
|
||||
lora_path = item[0]
|
||||
model_strength = item[1]
|
||||
lora_name = os.path.splitext(os.path.basename(lora_path))[0]
|
||||
active_loras.append({
|
||||
"name": lora_name,
|
||||
"strength": model_strength
|
||||
})
|
||||
|
||||
if active_loras:
|
||||
metadata[LORAS][node_id] = {
|
||||
"lora_list": active_loras,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
positive_text = inputs.get("positive", "")
|
||||
negative_text = inputs.get("negative", "")
|
||||
|
||||
if positive_text or negative_text:
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
metadata[PROMPTS][node_id]["positive_text"] = positive_text
|
||||
metadata[PROMPTS][node_id]["negative_text"] = negative_text
|
||||
|
||||
if "clip_skip" in inputs:
|
||||
clip_skip = inputs["clip_skip"]
|
||||
if node_id not in metadata[SAMPLING]:
|
||||
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
|
||||
metadata[SAMPLING][node_id]["parameters"]["clip_skip"] = clip_skip
|
||||
|
||||
width = inputs.get("empty_latent_width")
|
||||
height = inputs.get("empty_latent_height")
|
||||
if width is not None and height is not None:
|
||||
if SIZE not in metadata:
|
||||
metadata[SIZE] = {}
|
||||
metadata[SIZE][node_id] = {
|
||||
"width": int(width),
|
||||
"height": int(height),
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
# outputs: [(pipe_dict, model, vae), ...]
|
||||
if not outputs or not isinstance(outputs, list) or len(outputs) == 0:
|
||||
return
|
||||
first_output = outputs[0]
|
||||
if not isinstance(first_output, tuple) or len(first_output) < 1:
|
||||
return
|
||||
pipe = first_output[0]
|
||||
if not isinstance(pipe, dict):
|
||||
return
|
||||
|
||||
positive_conditioning = pipe.get("positive")
|
||||
negative_conditioning = pipe.get("negative")
|
||||
|
||||
if positive_conditioning is not None or negative_conditioning is not None:
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
if positive_conditioning is not None:
|
||||
metadata[PROMPTS][node_id]["positive_encoded"] = positive_conditioning
|
||||
if negative_conditioning is not None:
|
||||
metadata[PROMPTS][node_id]["negative_encoded"] = negative_conditioning
|
||||
|
||||
|
||||
class EasyPreSamplingExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
sampling_params = {}
|
||||
for key in ("steps", "cfg", "sampler_name", "scheduler", "denoise", "seed"):
|
||||
if key in inputs:
|
||||
sampling_params[key] = inputs[key]
|
||||
|
||||
metadata[SAMPLING][node_id] = {
|
||||
"parameters": sampling_params,
|
||||
"node_id": node_id,
|
||||
IS_SAMPLER: True
|
||||
}
|
||||
|
||||
|
||||
class EasySeedExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "seed" not in inputs:
|
||||
return
|
||||
|
||||
metadata[SAMPLING][node_id] = {
|
||||
"parameters": {"seed": inputs["seed"]},
|
||||
"node_id": node_id,
|
||||
IS_SAMPLER: False
|
||||
}
|
||||
|
||||
|
||||
class CLIPTextEncodeExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
@@ -161,6 +275,251 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor):
|
||||
conditioning = outputs[0][0]
|
||||
metadata[PROMPTS][node_id]["conditioning"] = conditioning
|
||||
|
||||
|
||||
class MyOriginalWaifuTextExtractor(NodeMetadataExtractor):
|
||||
"""Extractor for ComfyUI-MyOriginalWaifu TextProvider nodes."""
|
||||
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
positive_text = inputs.get("positive", "")
|
||||
negative_text = inputs.get("negative", "")
|
||||
|
||||
if positive_text or negative_text:
|
||||
metadata[PROMPTS][node_id] = {
|
||||
"positive_text": positive_text,
|
||||
"negative_text": negative_text,
|
||||
"node_id": node_id,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
output_tuple = _first_output_tuple(outputs)
|
||||
if not output_tuple or len(output_tuple) < 2:
|
||||
return
|
||||
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
prompt_metadata["positive_text"] = output_tuple[0]
|
||||
prompt_metadata["negative_text"] = output_tuple[1]
|
||||
|
||||
|
||||
class MyOriginalWaifuClipExtractor(NodeMetadataExtractor):
|
||||
"""Extractor for ComfyUI-MyOriginalWaifu ClipProvider nodes."""
|
||||
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
positive_text = inputs.get("positive", "")
|
||||
negative_text = inputs.get("negative", "")
|
||||
|
||||
if positive_text or negative_text:
|
||||
metadata[PROMPTS][node_id] = {
|
||||
"positive_text": positive_text,
|
||||
"negative_text": negative_text,
|
||||
"node_id": node_id,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
output_tuple = _first_output_tuple(outputs)
|
||||
if not output_tuple or len(output_tuple) < 2:
|
||||
return
|
||||
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
prompt_metadata["positive_encoded"] = output_tuple[0]
|
||||
prompt_metadata["negative_encoded"] = output_tuple[1]
|
||||
|
||||
|
||||
def _ensure_prompt_metadata(metadata, node_id):
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
return metadata[PROMPTS][node_id]
|
||||
|
||||
|
||||
def _first_output_tuple(outputs):
|
||||
if not outputs or not isinstance(outputs, list) or len(outputs) == 0:
|
||||
return None
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, tuple):
|
||||
return first_output
|
||||
return None
|
||||
|
||||
|
||||
def _record_conditioning_source(
|
||||
metadata, node_id, output_conditioning, input_conditionings
|
||||
):
|
||||
if output_conditioning is None:
|
||||
return
|
||||
|
||||
sources = [
|
||||
conditioning for conditioning in input_conditionings if conditioning is not None
|
||||
]
|
||||
if not sources:
|
||||
return
|
||||
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
prompt_metadata.setdefault("conditioning_sources", []).append(
|
||||
{
|
||||
"output": output_conditioning,
|
||||
"inputs": sources,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_variable_name(inputs):
|
||||
for key in ("key", "name", "variable_name", "tag", "text"):
|
||||
value = inputs.get(key)
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _get_node_variable_name(metadata, node_id, inputs):
|
||||
variable_name = _get_variable_name(inputs)
|
||||
if variable_name:
|
||||
return variable_name
|
||||
|
||||
prompt = metadata.get("current_prompt")
|
||||
original_prompt = getattr(prompt, "original_prompt", None)
|
||||
if not original_prompt or node_id not in original_prompt:
|
||||
return None
|
||||
|
||||
node_data = original_prompt[node_id]
|
||||
variable_name = _get_variable_name(node_data.get("inputs", {}))
|
||||
if variable_name:
|
||||
return variable_name
|
||||
|
||||
widgets_values = node_data.get("widgets_values", [])
|
||||
if widgets_values and isinstance(widgets_values[0], str):
|
||||
return widgets_values[0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ControlNetApplyAdvancedExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
if inputs.get("positive") is not None:
|
||||
prompt_metadata["orig_pos_cond"] = inputs["positive"]
|
||||
if inputs.get("negative") is not None:
|
||||
prompt_metadata["orig_neg_cond"] = inputs["negative"]
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
output_tuple = _first_output_tuple(outputs)
|
||||
if not output_tuple:
|
||||
return
|
||||
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
positive_input = prompt_metadata.get("orig_pos_cond")
|
||||
negative_input = prompt_metadata.get("orig_neg_cond")
|
||||
|
||||
if len(output_tuple) >= 1:
|
||||
prompt_metadata["positive_encoded"] = output_tuple[0]
|
||||
_record_conditioning_source(
|
||||
metadata, node_id, output_tuple[0], [positive_input]
|
||||
)
|
||||
if len(output_tuple) >= 2:
|
||||
prompt_metadata["negative_encoded"] = output_tuple[1]
|
||||
_record_conditioning_source(
|
||||
metadata, node_id, output_tuple[1], [negative_input]
|
||||
)
|
||||
|
||||
|
||||
class ConditioningCombineExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
input_conditionings = []
|
||||
for input_name in inputs:
|
||||
if (
|
||||
input_name.startswith("conditioning")
|
||||
and inputs[input_name] is not None
|
||||
):
|
||||
input_conditionings.append(inputs[input_name])
|
||||
|
||||
if input_conditionings:
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
prompt_metadata["orig_conditionings"] = input_conditionings
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
output_tuple = _first_output_tuple(outputs)
|
||||
if not output_tuple or len(output_tuple) < 1:
|
||||
return
|
||||
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
output_conditioning = output_tuple[0]
|
||||
prompt_metadata["conditioning"] = output_conditioning
|
||||
_record_conditioning_source(
|
||||
metadata,
|
||||
node_id,
|
||||
output_conditioning,
|
||||
prompt_metadata.get("orig_conditionings", []),
|
||||
)
|
||||
|
||||
|
||||
class SetNodeExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
variable_name = _get_node_variable_name(metadata, node_id, inputs)
|
||||
conditioning = inputs.get("CONDITIONING")
|
||||
if conditioning is None:
|
||||
conditioning = inputs.get("conditioning")
|
||||
if conditioning is None:
|
||||
return
|
||||
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
prompt_metadata["conditioning"] = conditioning
|
||||
if variable_name:
|
||||
prompt_metadata["variable_name"] = variable_name
|
||||
metadata[PROMPTS].setdefault("__conditioning_variables__", {})[
|
||||
variable_name
|
||||
] = conditioning
|
||||
|
||||
|
||||
class GetNodeExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
variable_name = _get_node_variable_name(metadata, node_id, inputs or {})
|
||||
if variable_name:
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
prompt_metadata["variable_name"] = variable_name
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
output_tuple = _first_output_tuple(outputs)
|
||||
if not output_tuple or len(output_tuple) < 1:
|
||||
return
|
||||
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
output_conditioning = output_tuple[0]
|
||||
prompt_metadata["conditioning"] = output_conditioning
|
||||
|
||||
variable_name = prompt_metadata.get("variable_name")
|
||||
if not variable_name:
|
||||
return
|
||||
|
||||
input_conditioning = metadata[PROMPTS].get("__conditioning_variables__", {}).get(
|
||||
variable_name
|
||||
)
|
||||
_record_conditioning_source(
|
||||
metadata, node_id, output_conditioning, [input_conditioning]
|
||||
)
|
||||
|
||||
# Base Sampler Extractor to reduce code redundancy
|
||||
class BaseSamplerExtractor(NodeMetadataExtractor):
|
||||
"""Base extractor for sampler nodes with common functionality"""
|
||||
@@ -427,6 +786,75 @@ class ImageSizeExtractor(NodeMetadataExtractor):
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class RgthreePowerLoraLoaderExtractor(NodeMetadataExtractor):
|
||||
"""Extract LoRA metadata from rgthree Power Lora Loader.
|
||||
|
||||
The node passes LoRAs as dynamic kwargs: LORA_1, LORA_2, ... each containing
|
||||
{'on': bool, 'lora': filename, 'strength': float, 'strengthTwo': float}.
|
||||
"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
active_loras = []
|
||||
for key, value in inputs.items():
|
||||
if not key.upper().startswith('LORA_'):
|
||||
continue
|
||||
if not isinstance(value, dict):
|
||||
continue
|
||||
if not value.get('on') or not value.get('lora'):
|
||||
continue
|
||||
lora_name = os.path.splitext(os.path.basename(value['lora']))[0]
|
||||
active_loras.append({
|
||||
"name": lora_name,
|
||||
"strength": round(float(value.get('strength', 1.0)), 2)
|
||||
})
|
||||
|
||||
if active_loras:
|
||||
metadata[LORAS][node_id] = {
|
||||
"lora_list": active_loras,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
|
||||
class TensorRTLoaderExtractor(NodeMetadataExtractor):
|
||||
"""Extract checkpoint metadata from TensorRT Loader.
|
||||
|
||||
extract() parses the engine filename from 'unet_name' as a best-effort
|
||||
fallback (strips profile suffix after '_$' and counter suffix).
|
||||
|
||||
update() checks if the output MODEL has attachments["source_model"]
|
||||
set by the node (NubeBuster fork) and overrides with the real name.
|
||||
Vanilla TRT doesn't set this — the filename parse stands.
|
||||
"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "unet_name" not in inputs:
|
||||
return
|
||||
unet_name = inputs.get("unet_name")
|
||||
# Strip path and extension, then drop the $_profile suffix
|
||||
model_name = os.path.splitext(os.path.basename(unet_name))[0]
|
||||
if "_$" in model_name:
|
||||
model_name = model_name[:model_name.index("_$")]
|
||||
# Strip counter suffix (e.g. _00001_) left by ComfyUI's save path
|
||||
model_name = re.sub(r'_\d+_?$', '', model_name)
|
||||
_store_checkpoint_metadata(metadata, node_id, model_name)
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
if not outputs or not isinstance(outputs, list) or len(outputs) == 0:
|
||||
return
|
||||
first_output = outputs[0]
|
||||
if not isinstance(first_output, tuple) or len(first_output) < 1:
|
||||
return
|
||||
model = first_output[0]
|
||||
# NubeBuster fork sets attachments["source_model"] on the ModelPatcher
|
||||
source_model = getattr(model, 'attachments', {}).get("source_model")
|
||||
if source_model:
|
||||
_store_checkpoint_metadata(metadata, node_id, source_model)
|
||||
|
||||
|
||||
class LoraLoaderManagerExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
@@ -473,6 +901,55 @@ class LoraLoaderManagerExtractor(NodeMetadataExtractor):
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class LoraTextLoaderManagerExtractor(NodeMetadataExtractor):
|
||||
"""Extract LoRA metadata from LoraTextLoaderLM (LoRA Text Loader).
|
||||
|
||||
The node accepts a `lora_syntax` STRING containing <lora:name:strength> tags
|
||||
(same format as the ComfyUI prompt), plus an optional `lora_stack`.
|
||||
This extractor parses the syntax string using the same regex as the node.
|
||||
"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
active_loras = []
|
||||
|
||||
# Process lora_stack if available (optional input)
|
||||
if "lora_stack" in inputs:
|
||||
lora_stack = inputs.get("lora_stack", [])
|
||||
for item in lora_stack:
|
||||
# lora_stack entries are (path, model_strength, clip_strength) tuples
|
||||
if isinstance(item, (list, tuple)) and len(item) >= 2:
|
||||
lora_path = item[0]
|
||||
model_strength = item[1]
|
||||
lora_name = os.path.splitext(os.path.basename(lora_path))[0]
|
||||
active_loras.append({
|
||||
"name": lora_name,
|
||||
"strength": round(float(model_strength), 2)
|
||||
})
|
||||
|
||||
# Process lora_syntax string input
|
||||
if "lora_syntax" in inputs:
|
||||
lora_syntax = inputs.get("lora_syntax", "")
|
||||
if lora_syntax and isinstance(lora_syntax, str):
|
||||
pattern = r"<lora:([^:>]+):([^:>]+)(?::([^:>]+))?>"
|
||||
matches = re.findall(pattern, lora_syntax, re.IGNORECASE)
|
||||
for match in matches:
|
||||
lora_name = match[0]
|
||||
model_strength = float(match[1])
|
||||
active_loras.append({
|
||||
"name": lora_name,
|
||||
"strength": round(model_strength, 2)
|
||||
})
|
||||
|
||||
if active_loras:
|
||||
metadata[LORAS][node_id] = {
|
||||
"lora_list": active_loras,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
|
||||
class FluxGuidanceExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
@@ -577,8 +1054,6 @@ class SamplerCustomAdvancedExtractor(BaseSamplerExtractor):
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
import json
|
||||
|
||||
class CLIPTextEncodeFluxExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
@@ -699,9 +1174,12 @@ NODE_EXTRACTORS = {
|
||||
"KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect
|
||||
"BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler
|
||||
"AlignYourStepsScheduler": BasicSchedulerExtractor, # Add AlignYourStepsScheduler
|
||||
# ComfyUI-Easy-Use pre-sampling / seed
|
||||
"samplerSettings": EasyPreSamplingExtractor, # easy preSampling
|
||||
"easySeed": EasySeedExtractor, # easy seed
|
||||
# Loaders
|
||||
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
|
||||
"comfyLoader": CheckpointLoaderExtractor, # easy comfyLoader
|
||||
"comfyLoader": EasyComfyLoaderExtractor, # ComfyUI-Easy-Use easy comfyLoader
|
||||
"CheckpointLoaderSimpleWithImages": CheckpointLoaderExtractor, # CheckpointLoader|pysssss
|
||||
"TSC_EfficientLoader": TSCCheckpointLoaderExtractor, # Efficient Nodes
|
||||
"NunchakuFluxDiTLoader": NunchakuFluxDiTLoaderExtractor, # ComfyUI-Nunchaku
|
||||
@@ -711,12 +1189,18 @@ NODE_EXTRACTORS = {
|
||||
"GGUFLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
|
||||
"DiffusionModelLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
|
||||
"CheckpointLoaderKJ": CheckpointLoaderExtractor, # KJNodes
|
||||
"CheckpointLoaderLM": CheckpointLoaderExtractor, # LoRA Manager
|
||||
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
"UNETLoaderLM": UNETLoaderExtractor, # LoRA Manager
|
||||
"LoraLoader": LoraLoaderExtractor,
|
||||
"LoraLoaderLM": LoraLoaderManagerExtractor,
|
||||
"LoraTextLoaderLM": LoraTextLoaderManagerExtractor,
|
||||
"RgthreePowerLoraLoader": RgthreePowerLoraLoaderExtractor,
|
||||
"TensorRTLoader": TensorRTLoaderExtractor,
|
||||
# Conditioning
|
||||
"CLIPTextEncode": CLIPTextEncodeExtractor,
|
||||
"CLIPTextEncodeAttentionBias": CLIPTextEncodeExtractor, # From https://github.com/silveroxides/ComfyUI_PromptAttention
|
||||
"PromptLM": CLIPTextEncodeExtractor,
|
||||
"CLIPTextEncodeFlux": CLIPTextEncodeFluxExtractor, # Add CLIPTextEncodeFlux
|
||||
"WAS_Text_to_Conditioning": CLIPTextEncodeExtractor,
|
||||
@@ -724,6 +1208,12 @@ NODE_EXTRACTORS = {
|
||||
"smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes
|
||||
"CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack
|
||||
"PCTextEncode": CLIPTextEncodeExtractor, # From https://github.com/asagi4/comfyui-prompt-control
|
||||
"TextProvider": MyOriginalWaifuTextExtractor, # ComfyUI-MyOriginalWaifu
|
||||
"ClipProvider": MyOriginalWaifuClipExtractor, # ComfyUI-MyOriginalWaifu
|
||||
"ControlNetApplyAdvanced": ControlNetApplyAdvancedExtractor,
|
||||
"ConditioningCombine": ConditioningCombineExtractor,
|
||||
"SetNode": SetNodeExtractor,
|
||||
"GetNode": GetNodeExtractor,
|
||||
# Latent
|
||||
"EmptyLatentImage": ImageSizeExtractor,
|
||||
# Flux
|
||||
|
||||
@@ -16,6 +16,8 @@ IMG_EXTENSIONS = (
|
||||
".tif",
|
||||
".tiff",
|
||||
".webp",
|
||||
".avif",
|
||||
".jxl",
|
||||
".mp4"
|
||||
)
|
||||
|
||||
|
||||
@@ -4,15 +4,21 @@ from typing import Awaitable, Callable, Dict, List
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
# Use wildcard for CivitAI to support their CDN subdomains (e.g., image-b2.civitai.com)
|
||||
# Security note: This is acceptable because:
|
||||
# 1. CSP img-src only controls image/video loading, not script execution
|
||||
# 2. All *.civitai.com subdomains are controlled by Civitai
|
||||
# 3. Explicit domain list would require constant updates as Civitai adds CDN nodes
|
||||
REMOTE_MEDIA_SOURCES = (
|
||||
"https://image.civitai.com",
|
||||
"https://*.civitai.com",
|
||||
"https://img.genur.art",
|
||||
)
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def relax_csp_for_remote_media(
|
||||
request: web.Request, handler: Callable[[web.Request], Awaitable[web.StreamResponse]]
|
||||
request: web.Request,
|
||||
handler: Callable[[web.Request], Awaitable[web.StreamResponse]],
|
||||
) -> web.StreamResponse:
|
||||
"""Allow LoRA Manager media previews to load from trusted remote domains.
|
||||
|
||||
@@ -43,7 +49,9 @@ async def relax_csp_for_remote_media(
|
||||
directive_order.append(name)
|
||||
directives[name] = values
|
||||
|
||||
def merge_sources(name: str, sources: List[str], defaults: List[str] | None = None) -> None:
|
||||
def merge_sources(
|
||||
name: str, sources: List[str], defaults: List[str] | None = None
|
||||
) -> None:
|
||||
existing = directives.get(name, list(defaults or []))
|
||||
|
||||
for source in sources:
|
||||
|
||||
71
py/middleware/error_middleware.py
Normal file
71
py/middleware/error_middleware.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""JSON error middleware for API routes.
|
||||
|
||||
Ensures all responses to /api/* requests return valid JSON that the
|
||||
browser-extension frontend can JSON.parse() without crashing, even when
|
||||
the route does not exist (404) or the handler raises an exception (500).
|
||||
|
||||
Extension consumers call response.json() unconditionally — an HTML error
|
||||
page causes ``SyntaxError: unexpected end of data`` that leaks into the
|
||||
popup UI as a toast notification.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def api_json_error(
|
||||
request: web.Request,
|
||||
handler: Callable[[web.Request], Awaitable[web.Response]],
|
||||
) -> web.Response:
|
||||
"""Return JSON ``{"success": false, "error": "..."}`` for API errors.
|
||||
|
||||
Only intercepts paths starting with ``/api/`` — all other routes
|
||||
(frontend pages, static files, WebSocket upgrades) pass through
|
||||
unchanged.
|
||||
"""
|
||||
if not request.path.startswith("/api/"):
|
||||
return await handler(request)
|
||||
|
||||
try:
|
||||
response = await handler(request)
|
||||
return response
|
||||
except web.HTTPException as exc:
|
||||
# Let redirects (301, 302, 307, 308) propagate — they are not errors.
|
||||
if exc.status < 400:
|
||||
raise
|
||||
|
||||
logger.warning(
|
||||
"API %s %s returned HTTP %d: %s",
|
||||
request.method,
|
||||
request.path,
|
||||
exc.status,
|
||||
exc.reason,
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{"success": False, "error": f"{exc.status}: {exc.reason}"},
|
||||
status=exc.status,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"API %s %s raised unhandled exception: %s",
|
||||
request.method,
|
||||
request.path,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"500: Internal Server Error ({type(exc).__name__})",
|
||||
},
|
||||
status=500,
|
||||
)
|
||||
118
py/nodes/checkpoint_loader.py
Normal file
118
py/nodes/checkpoint_loader.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
import comfy.sd # type: ignore
|
||||
import folder_paths # type: ignore
|
||||
from ..utils.utils import get_checkpoint_info_absolute, _format_model_name_for_comfyui
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CheckpointLoaderLM:
|
||||
"""Checkpoint Loader with support for extra folder paths
|
||||
|
||||
Loads checkpoints from both standard ComfyUI folders and LoRA Manager's
|
||||
extra folder paths, providing a unified interface for checkpoint loading.
|
||||
"""
|
||||
|
||||
NAME = "Checkpoint Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
# Get list of checkpoint names from scanner (includes extra folder paths)
|
||||
checkpoint_names = s._get_checkpoint_names()
|
||||
return {
|
||||
"required": {
|
||||
"ckpt_name": (
|
||||
checkpoint_names,
|
||||
{"tooltip": "The name of the checkpoint (model) to load."},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
RETURN_NAMES = ("MODEL", "CLIP", "VAE")
|
||||
OUTPUT_TOOLTIPS = (
|
||||
"The model used for denoising latents.",
|
||||
"The CLIP model used for encoding text prompts.",
|
||||
"The VAE model used for encoding and decoding images to and from latent space.",
|
||||
)
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
@classmethod
|
||||
def _get_checkpoint_names(cls) -> List[str]:
|
||||
"""Get list of checkpoint names from scanner cache in ComfyUI format (relative path with extension)"""
|
||||
try:
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
import asyncio
|
||||
|
||||
async def _get_names():
|
||||
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
# Get all model roots for calculating relative paths
|
||||
model_roots = scanner.get_model_roots()
|
||||
|
||||
# Filter only checkpoint type (not diffusion_model) and format names
|
||||
names = []
|
||||
for item in cache.raw_data:
|
||||
if item.get("sub_type") == "checkpoint":
|
||||
file_path = item.get("file_path", "")
|
||||
if file_path:
|
||||
# Format using relative path with OS-native separator
|
||||
formatted_name = _format_model_name_for_comfyui(
|
||||
file_path, model_roots
|
||||
)
|
||||
if formatted_name:
|
||||
names.append(formatted_name)
|
||||
|
||||
return sorted(names)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(_get_names())
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
return future.result()
|
||||
except RuntimeError:
|
||||
return asyncio.run(_get_names())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting checkpoint names: {e}")
|
||||
return []
|
||||
|
||||
def load_checkpoint(self, ckpt_name: str) -> Tuple:
|
||||
"""Load a checkpoint by name, supporting extra folder paths
|
||||
|
||||
Args:
|
||||
ckpt_name: The name of the checkpoint to load (relative path with extension)
|
||||
|
||||
Returns:
|
||||
Tuple of (MODEL, CLIP, VAE)
|
||||
"""
|
||||
# Get absolute path from cache using ComfyUI-style name
|
||||
ckpt_path, metadata = get_checkpoint_info_absolute(ckpt_name)
|
||||
|
||||
if metadata is None:
|
||||
raise FileNotFoundError(
|
||||
f"Checkpoint '{ckpt_name}' not found in LoRA Manager cache. "
|
||||
"Make sure the checkpoint is indexed and try again."
|
||||
)
|
||||
|
||||
# Load regular checkpoint using ComfyUI's API
|
||||
logger.info(f"Loading checkpoint from: {ckpt_path}")
|
||||
out = comfy.sd.load_checkpoint_guess_config(
|
||||
ckpt_path,
|
||||
output_vae=True,
|
||||
output_clip=True,
|
||||
embedding_directory=folder_paths.get_folder_paths("embeddings"),
|
||||
)
|
||||
return out[:3]
|
||||
161
py/nodes/gguf_import_helper.py
Normal file
161
py/nodes/gguf_import_helper.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Helper module to safely import ComfyUI-GGUF modules.
|
||||
|
||||
This module provides a robust way to import ComfyUI-GGUF functionality
|
||||
regardless of how ComfyUI loaded it.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import importlib.util
|
||||
import logging
|
||||
from typing import Optional, Tuple, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_gguf_path() -> str:
|
||||
"""Get the path to ComfyUI-GGUF based on this file's location.
|
||||
|
||||
Since ComfyUI-Lora-Manager and ComfyUI-GGUF are both in custom_nodes/,
|
||||
we can derive the GGUF path from our own location.
|
||||
"""
|
||||
# This file is at: custom_nodes/ComfyUI-Lora-Manager/py/nodes/gguf_import_helper.py
|
||||
# ComfyUI-GGUF is at: custom_nodes/ComfyUI-GGUF
|
||||
current_file = os.path.abspath(__file__)
|
||||
# Go up 4 levels: nodes -> py -> ComfyUI-Lora-Manager -> custom_nodes
|
||||
custom_nodes_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(current_file)))
|
||||
)
|
||||
return os.path.join(custom_nodes_dir, "ComfyUI-GGUF")
|
||||
|
||||
|
||||
def _find_gguf_module() -> Optional[Any]:
|
||||
"""Find ComfyUI-GGUF module in sys.modules.
|
||||
|
||||
ComfyUI registers modules using the full path with dots replaced by _x_.
|
||||
"""
|
||||
gguf_path = _get_gguf_path()
|
||||
sys_module_name = gguf_path.replace(".", "_x_")
|
||||
|
||||
logger.debug(f"[GGUF Import] Looking for module '{sys_module_name}' in sys.modules")
|
||||
if sys_module_name in sys.modules:
|
||||
logger.info(f"[GGUF Import] Found module: '{sys_module_name}'")
|
||||
return sys.modules[sys_module_name]
|
||||
|
||||
logger.debug(f"[GGUF Import] Module not found: '{sys_module_name}'")
|
||||
return None
|
||||
|
||||
|
||||
def _load_gguf_modules_directly() -> Optional[Any]:
|
||||
"""Load ComfyUI-GGUF modules directly from file paths."""
|
||||
gguf_path = _get_gguf_path()
|
||||
|
||||
logger.info(f"[GGUF Import] Direct Load: Attempting to load from '{gguf_path}'")
|
||||
|
||||
if not os.path.exists(gguf_path):
|
||||
logger.warning(f"[GGUF Import] Path does not exist: {gguf_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
namespace = "ComfyUI_GGUF_Dynamic"
|
||||
init_path = os.path.join(gguf_path, "__init__.py")
|
||||
|
||||
if not os.path.exists(init_path):
|
||||
logger.warning(f"[GGUF Import] __init__.py not found at '{init_path}'")
|
||||
return None
|
||||
|
||||
logger.debug(f"[GGUF Import] Loading from '{init_path}'")
|
||||
spec = importlib.util.spec_from_file_location(namespace, init_path)
|
||||
if not spec or not spec.loader:
|
||||
logger.error(f"[GGUF Import] Failed to create spec for '{init_path}'")
|
||||
return None
|
||||
|
||||
package = importlib.util.module_from_spec(spec)
|
||||
package.__path__ = [gguf_path]
|
||||
sys.modules[namespace] = package
|
||||
spec.loader.exec_module(package)
|
||||
logger.debug(f"[GGUF Import] Loaded main package '{namespace}'")
|
||||
|
||||
# Load submodules
|
||||
loaded = []
|
||||
for submod_name in ["loader", "ops", "nodes"]:
|
||||
submod_path = os.path.join(gguf_path, f"{submod_name}.py")
|
||||
if os.path.exists(submod_path):
|
||||
submod_spec = importlib.util.spec_from_file_location(
|
||||
f"{namespace}.{submod_name}", submod_path
|
||||
)
|
||||
if submod_spec and submod_spec.loader:
|
||||
submod = importlib.util.module_from_spec(submod_spec)
|
||||
submod.__package__ = namespace
|
||||
sys.modules[f"{namespace}.{submod_name}"] = submod
|
||||
submod_spec.loader.exec_module(submod)
|
||||
setattr(package, submod_name, submod)
|
||||
loaded.append(submod_name)
|
||||
logger.debug(f"[GGUF Import] Loaded submodule '{submod_name}'")
|
||||
|
||||
logger.info(f"[GGUF Import] Direct Load success: {loaded}")
|
||||
return package
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GGUF Import] Direct Load failed: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def get_gguf_modules() -> Tuple[Any, Any, Any]:
|
||||
"""Get ComfyUI-GGUF modules (loader, ops, nodes).
|
||||
|
||||
Returns:
|
||||
Tuple of (loader_module, ops_module, nodes_module)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If ComfyUI-GGUF cannot be found or loaded.
|
||||
"""
|
||||
logger.debug("[GGUF Import] Starting module search...")
|
||||
|
||||
# Try to find already loaded module first
|
||||
gguf_module = _find_gguf_module()
|
||||
|
||||
if gguf_module is None:
|
||||
logger.info("[GGUF Import] Not found in sys.modules, trying direct load...")
|
||||
gguf_module = _load_gguf_modules_directly()
|
||||
|
||||
if gguf_module is None:
|
||||
raise RuntimeError(
|
||||
"ComfyUI-GGUF is not installed. "
|
||||
"Please install from https://github.com/city96/ComfyUI-GGUF"
|
||||
)
|
||||
|
||||
# Extract submodules
|
||||
loader = getattr(gguf_module, "loader", None)
|
||||
ops = getattr(gguf_module, "ops", None)
|
||||
nodes = getattr(gguf_module, "nodes", None)
|
||||
|
||||
if loader is None or ops is None or nodes is None:
|
||||
missing = [
|
||||
name
|
||||
for name, mod in [("loader", loader), ("ops", ops), ("nodes", nodes)]
|
||||
if mod is None
|
||||
]
|
||||
raise RuntimeError(f"ComfyUI-GGUF missing submodules: {missing}")
|
||||
|
||||
logger.debug("[GGUF Import] All modules loaded successfully")
|
||||
return loader, ops, nodes
|
||||
|
||||
|
||||
def get_gguf_sd_loader():
|
||||
"""Get the gguf_sd_loader function from ComfyUI-GGUF."""
|
||||
loader, _, _ = get_gguf_modules()
|
||||
return getattr(loader, "gguf_sd_loader")
|
||||
|
||||
|
||||
def get_ggml_ops():
|
||||
"""Get the GGMLOps class from ComfyUI-GGUF."""
|
||||
_, ops, _ = get_gguf_modules()
|
||||
return getattr(ops, "GGMLOps")
|
||||
|
||||
|
||||
def get_gguf_model_patcher():
|
||||
"""Get the GGUFModelPatcher class from ComfyUI-GGUF."""
|
||||
_, _, nodes = get_gguf_modules()
|
||||
return getattr(nodes, "GGUFModelPatcher")
|
||||
@@ -8,6 +8,7 @@ and tracks the cycle progress which persists across workflow save/load.
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from ..utils.utils import get_lora_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -54,8 +55,14 @@ class LoraCyclerLM:
|
||||
current_index = cycler_config.get("current_index", 1) # 1-based
|
||||
model_strength = float(cycler_config.get("model_strength", 1.0))
|
||||
clip_strength = float(cycler_config.get("clip_strength", 1.0))
|
||||
use_same_clip_strength = cycler_config.get("use_same_clip_strength", True)
|
||||
use_preset_strength = cycler_config.get("use_preset_strength", False)
|
||||
preset_strength_scale = float(cycler_config.get("preset_strength_scale", 1.0))
|
||||
sort_by = "filename"
|
||||
|
||||
# Include "no lora" option
|
||||
include_no_lora = cycler_config.get("include_no_lora", False)
|
||||
|
||||
# Dual-index mechanism for batch queue synchronization
|
||||
execution_index = cycler_config.get("execution_index") # Can be None
|
||||
# next_index_from_config = cycler_config.get("next_index") # Not used on backend
|
||||
@@ -71,7 +78,10 @@ class LoraCyclerLM:
|
||||
|
||||
total_count = len(lora_list)
|
||||
|
||||
if total_count == 0:
|
||||
# Calculate effective total count (includes no lora option if enabled)
|
||||
effective_total_count = total_count + 1 if include_no_lora else total_count
|
||||
|
||||
if total_count == 0 and not include_no_lora:
|
||||
logger.warning("[LoraCyclerLM] No LoRAs available in pool")
|
||||
return {
|
||||
"result": ([],),
|
||||
@@ -93,42 +103,99 @@ class LoraCyclerLM:
|
||||
else:
|
||||
actual_index = current_index
|
||||
|
||||
# Clamp index to valid range (1-based)
|
||||
clamped_index = max(1, min(actual_index, total_count))
|
||||
# Clamp index to valid range (1-based, includes no lora if enabled)
|
||||
clamped_index = max(1, min(actual_index, effective_total_count))
|
||||
|
||||
# Get LoRA at current index (convert to 0-based for list access)
|
||||
current_lora = lora_list[clamped_index - 1]
|
||||
# Check if current index is the "no lora" option (last position when include_no_lora is True)
|
||||
is_no_lora = include_no_lora and clamped_index == effective_total_count
|
||||
|
||||
# Build LORA_STACK with single LoRA
|
||||
lora_path, _ = get_lora_info(current_lora["file_name"])
|
||||
if not lora_path:
|
||||
logger.warning(
|
||||
f"[LoraCyclerLM] Could not find path for LoRA: {current_lora['file_name']}"
|
||||
)
|
||||
if is_no_lora:
|
||||
# "No LoRA" option - return empty stack
|
||||
lora_stack = []
|
||||
current_lora_name = "No LoRA"
|
||||
current_lora_filename = "No LoRA"
|
||||
else:
|
||||
# Normalize path separators
|
||||
lora_path = lora_path.replace("/", os.sep)
|
||||
lora_stack = [(lora_path, model_strength, clip_strength)]
|
||||
# Get LoRA at current index (convert to 0-based for list access)
|
||||
current_lora = lora_list[clamped_index - 1]
|
||||
current_lora_name = current_lora["file_name"]
|
||||
current_lora_filename = current_lora["file_name"]
|
||||
|
||||
# Build LORA_STACK with single LoRA
|
||||
if current_lora["file_name"] == "None":
|
||||
lora_path = None
|
||||
else:
|
||||
lora_path, _ = get_lora_info(current_lora["file_name"])
|
||||
|
||||
if not lora_path:
|
||||
if current_lora["file_name"] != "None":
|
||||
logger.warning(
|
||||
f"[LoraCyclerLM] Could not find path for LoRA: {current_lora['file_name']}"
|
||||
)
|
||||
lora_stack = []
|
||||
else:
|
||||
# Normalize path separators
|
||||
lora_path = lora_path.replace("/", os.sep)
|
||||
|
||||
if use_preset_strength:
|
||||
lora_metadata = await lora_service.get_lora_metadata_by_filename(
|
||||
current_lora["file_name"]
|
||||
)
|
||||
if lora_metadata:
|
||||
recommended_strength = (
|
||||
lora_service.get_recommended_strength_from_lora_data(
|
||||
lora_metadata
|
||||
)
|
||||
)
|
||||
if recommended_strength is not None:
|
||||
model_strength = round(
|
||||
recommended_strength * preset_strength_scale, 2
|
||||
)
|
||||
|
||||
if use_same_clip_strength:
|
||||
clip_strength = model_strength
|
||||
else:
|
||||
recommended_clip_strength = (
|
||||
lora_service.get_recommended_clip_strength_from_lora_data(
|
||||
lora_metadata
|
||||
)
|
||||
)
|
||||
if recommended_clip_strength is not None:
|
||||
clip_strength = round(
|
||||
recommended_clip_strength * preset_strength_scale, 2
|
||||
)
|
||||
elif use_same_clip_strength:
|
||||
clip_strength = model_strength
|
||||
elif use_same_clip_strength:
|
||||
clip_strength = model_strength
|
||||
|
||||
lora_stack = [(lora_path, model_strength, clip_strength)]
|
||||
|
||||
# Calculate next index (wrap to 1 if at end)
|
||||
next_index = clamped_index + 1
|
||||
if next_index > total_count:
|
||||
if next_index > effective_total_count:
|
||||
next_index = 1
|
||||
|
||||
# Get next LoRA for UI display (what will be used next generation)
|
||||
next_lora = lora_list[next_index - 1]
|
||||
next_display_name = next_lora["file_name"]
|
||||
is_next_no_lora = include_no_lora and next_index == effective_total_count
|
||||
if is_next_no_lora:
|
||||
next_display_name = "No LoRA"
|
||||
next_lora_filename = "No LoRA"
|
||||
else:
|
||||
next_lora = lora_list[next_index - 1]
|
||||
next_display_name = next_lora["file_name"]
|
||||
next_lora_filename = next_lora["file_name"]
|
||||
|
||||
return {
|
||||
"result": (lora_stack,),
|
||||
"ui": {
|
||||
"current_index": [clamped_index],
|
||||
"next_index": [next_index],
|
||||
"total_count": [total_count],
|
||||
"current_lora_name": [current_lora["file_name"]],
|
||||
"current_lora_filename": [current_lora["file_name"]],
|
||||
"total_count": [
|
||||
total_count
|
||||
], # Return actual LoRA count, not effective_total_count
|
||||
"current_lora_name": [current_lora_name],
|
||||
"current_lora_filename": [current_lora_filename],
|
||||
"next_lora_name": [next_display_name],
|
||||
"next_lora_filename": [next_lora["file_name"]],
|
||||
"next_lora_filename": [next_lora_filename],
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,22 +1,139 @@
|
||||
import importlib
|
||||
import logging
|
||||
import re
|
||||
import comfy.utils # type: ignore
|
||||
import comfy.sd # type: ignore
|
||||
|
||||
import comfy.sd # type: ignore
|
||||
import comfy.utils # type: ignore
|
||||
|
||||
from ..utils.utils import get_lora_info_absolute
|
||||
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora
|
||||
from .utils import (
|
||||
FlexibleOptionalInputType,
|
||||
any_type,
|
||||
apply_lora_syntax_format,
|
||||
detect_nunchaku_model_kind,
|
||||
extract_lora_name,
|
||||
get_loras_list,
|
||||
nunchaku_load_lora,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_nunchaku_load_qwen_loras():
|
||||
try:
|
||||
module = importlib.import_module(".nunchaku_qwen", __package__)
|
||||
except ImportError as exc:
|
||||
raise RuntimeError(
|
||||
"Qwen-Image LoRA loading requires the ComfyUI runtime with its torch dependency available."
|
||||
) from exc
|
||||
return module.nunchaku_load_qwen_loras
|
||||
|
||||
|
||||
def _collect_stack_entries(lora_stack):
|
||||
entries = []
|
||||
if not lora_stack:
|
||||
return entries
|
||||
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
absolute_lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
||||
entries.append({
|
||||
"name": lora_name,
|
||||
"absolute_path": absolute_lora_path,
|
||||
"input_path": lora_path,
|
||||
"model_strength": float(model_strength),
|
||||
"clip_strength": float(clip_strength),
|
||||
"trigger_words": trigger_words,
|
||||
})
|
||||
return entries
|
||||
|
||||
|
||||
def _collect_widget_entries(kwargs):
|
||||
entries = []
|
||||
for lora in get_loras_list(kwargs):
|
||||
if not lora.get("active", False):
|
||||
continue
|
||||
lora_name = apply_lora_syntax_format(lora["name"])
|
||||
model_strength = float(lora["strength"])
|
||||
clip_strength = float(lora.get("clipStrength", model_strength))
|
||||
lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
||||
entries.append({
|
||||
"name": lora_name,
|
||||
"absolute_path": lora_path,
|
||||
"input_path": lora_path,
|
||||
"model_strength": model_strength,
|
||||
"clip_strength": clip_strength,
|
||||
"trigger_words": trigger_words,
|
||||
})
|
||||
return entries
|
||||
|
||||
|
||||
def _format_loaded_loras(loaded_loras):
|
||||
formatted_loras = []
|
||||
for item in loaded_loras:
|
||||
if item["include_clip_strength"]:
|
||||
formatted_loras.append(
|
||||
f"<lora:{item['name']}:{item['model_strength']}:{item['clip_strength']}>"
|
||||
)
|
||||
else:
|
||||
formatted_loras.append(f"<lora:{item['name']}:{item['model_strength']}>")
|
||||
return " ".join(formatted_loras)
|
||||
|
||||
|
||||
def _apply_entries(model, clip, lora_entries, nunchaku_model_kind):
|
||||
loaded_loras = []
|
||||
all_trigger_words = []
|
||||
|
||||
if nunchaku_model_kind == "qwen_image":
|
||||
nunchaku_load_qwen_loras = _get_nunchaku_load_qwen_loras()
|
||||
qwen_lora_configs = []
|
||||
for entry in lora_entries:
|
||||
qwen_lora_configs.append((entry["absolute_path"], entry["model_strength"]))
|
||||
loaded_loras.append({
|
||||
"name": entry["name"],
|
||||
"model_strength": entry["model_strength"],
|
||||
"clip_strength": entry["model_strength"],
|
||||
"include_clip_strength": False,
|
||||
})
|
||||
all_trigger_words.extend(entry["trigger_words"])
|
||||
if qwen_lora_configs:
|
||||
model = nunchaku_load_qwen_loras(model, qwen_lora_configs)
|
||||
return model, clip, loaded_loras, all_trigger_words
|
||||
|
||||
for entry in lora_entries:
|
||||
if nunchaku_model_kind == "flux":
|
||||
model = nunchaku_load_lora(model, entry["input_path"], entry["model_strength"])
|
||||
else:
|
||||
lora = comfy.utils.load_torch_file(entry["absolute_path"], safe_load=True)
|
||||
model, clip = comfy.sd.load_lora_for_models(
|
||||
model,
|
||||
clip,
|
||||
lora,
|
||||
entry["model_strength"],
|
||||
entry["clip_strength"],
|
||||
)
|
||||
|
||||
include_clip_strength = nunchaku_model_kind is None and abs(entry["model_strength"] - entry["clip_strength"]) > 0.001
|
||||
loaded_loras.append({
|
||||
"name": entry["name"],
|
||||
"model_strength": entry["model_strength"],
|
||||
"clip_strength": entry["clip_strength"],
|
||||
"include_clip_strength": include_clip_strength,
|
||||
})
|
||||
all_trigger_words.extend(entry["trigger_words"])
|
||||
|
||||
return model, clip, loaded_loras, all_trigger_words
|
||||
|
||||
|
||||
class LoraLoaderLM:
|
||||
NAME = "Lora Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
# "clip": ("CLIP",),
|
||||
"text": ("AUTOCOMPLETE_TEXT_LORAS", {
|
||||
"placeholder": "Search LoRAs to add...",
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
@@ -28,114 +145,30 @@ class LoraLoaderLM:
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING")
|
||||
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
|
||||
FUNCTION = "load_loras"
|
||||
|
||||
|
||||
def load_loras(self, model, text, **kwargs):
|
||||
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
|
||||
loaded_loras = []
|
||||
all_trigger_words = []
|
||||
|
||||
clip = kwargs.get('clip', None)
|
||||
lora_stack = kwargs.get('lora_stack', None)
|
||||
|
||||
# Check if model is a Nunchaku Flux model - simplified approach
|
||||
is_nunchaku_model = False
|
||||
|
||||
try:
|
||||
model_wrapper = model.model.diffusion_model
|
||||
# Check if model is a Nunchaku Flux model using only class name
|
||||
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
||||
is_nunchaku_model = True
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
except (AttributeError, TypeError):
|
||||
# Not a model with the expected structure
|
||||
pass
|
||||
|
||||
# First process lora_stack if available
|
||||
if lora_stack:
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Extract lora name and convert to absolute path
|
||||
# lora_stack stores relative paths, but load_torch_file needs absolute paths
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
absolute_lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
||||
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# Use our custom function for Flux models
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged for Nunchaku models
|
||||
else:
|
||||
# Use lower-level API to load LoRA directly without folder_paths validation
|
||||
lora = comfy.utils.load_torch_file(absolute_lora_path, safe_load=True)
|
||||
model, clip = comfy.sd.load_lora_for_models(model, clip, lora, model_strength, clip_strength)
|
||||
|
||||
all_trigger_words.extend(trigger_words)
|
||||
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Then process loras from kwargs with support for both old and new formats
|
||||
loras_list = get_loras_list(kwargs)
|
||||
for lora in loras_list:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
|
||||
lora_name = lora['name']
|
||||
model_strength = float(lora['strength'])
|
||||
# Get clip strength - use model strength as default if not specified
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
||||
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# For Nunchaku models, use our custom function
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged
|
||||
else:
|
||||
# Use lower-level API to load LoRA directly without folder_paths validation
|
||||
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
||||
model, clip = comfy.sd.load_lora_for_models(model, clip, lora, model_strength, clip_strength)
|
||||
|
||||
# Include clip strength in output if different from model strength and not a Nunchaku model
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# use ',, ' to separate trigger words for group mode
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Format loaded_loras with support for both formats
|
||||
formatted_loras = []
|
||||
for item in loaded_loras:
|
||||
parts = item.split(":")
|
||||
lora_name = parts[0]
|
||||
strength_parts = parts[1].strip().split(",")
|
||||
|
||||
if len(strength_parts) > 1:
|
||||
# Different model and clip strengths
|
||||
model_str = strength_parts[0].strip()
|
||||
clip_str = strength_parts[1].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}:{clip_str}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
model_str = strength_parts[0].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}>")
|
||||
|
||||
formatted_loras_text = " ".join(formatted_loras)
|
||||
del text
|
||||
clip = kwargs.get("clip", None)
|
||||
lora_entries = _collect_stack_entries(kwargs.get("lora_stack", None))
|
||||
lora_entries.extend(_collect_widget_entries(kwargs))
|
||||
|
||||
nunchaku_model_kind = detect_nunchaku_model_kind(model)
|
||||
if nunchaku_model_kind == "flux":
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
elif nunchaku_model_kind == "qwen_image":
|
||||
logger.info("Detected Nunchaku Qwen-Image model")
|
||||
|
||||
model, clip, loaded_loras, all_trigger_words = _apply_entries(model, clip, lora_entries, nunchaku_model_kind)
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
formatted_loras_text = _format_loaded_loras(loaded_loras)
|
||||
return (model, clip, trigger_words_text, formatted_loras_text)
|
||||
|
||||
|
||||
class LoraTextLoaderLM:
|
||||
NAME = "LoRA Text Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
@@ -143,131 +176,55 @@ class LoraTextLoaderLM:
|
||||
"model": ("MODEL",),
|
||||
"lora_syntax": ("STRING", {
|
||||
"forceInput": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation"
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"clip": ("CLIP",),
|
||||
"lora_stack": ("LORA_STACK",),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING")
|
||||
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
|
||||
FUNCTION = "load_loras_from_text"
|
||||
|
||||
|
||||
def parse_lora_syntax(self, text):
|
||||
"""Parse LoRA syntax from text input."""
|
||||
# Pattern to match <lora:name:strength> or <lora:name:model_strength:clip_strength>
|
||||
pattern = r'<lora:([^:>]+):([^:>]+)(?::([^:>]+))?>'
|
||||
pattern = r"<lora:([^:>]+):([^:>]+)(?::([^:>]+))?>"
|
||||
matches = re.findall(pattern, text, re.IGNORECASE)
|
||||
|
||||
|
||||
loras = []
|
||||
for match in matches:
|
||||
lora_name = match[0]
|
||||
model_strength = float(match[1])
|
||||
clip_strength = float(match[2]) if match[2] else model_strength
|
||||
|
||||
loras.append({
|
||||
'name': lora_name,
|
||||
'model_strength': model_strength,
|
||||
'clip_strength': clip_strength
|
||||
"name": match[0],
|
||||
"model_strength": model_strength,
|
||||
"clip_strength": float(match[2]) if match[2] else model_strength,
|
||||
})
|
||||
|
||||
return loras
|
||||
|
||||
|
||||
def load_loras_from_text(self, model, lora_syntax, clip=None, lora_stack=None):
|
||||
"""Load LoRAs based on text syntax input."""
|
||||
loaded_loras = []
|
||||
all_trigger_words = []
|
||||
|
||||
# Check if model is a Nunchaku Flux model - simplified approach
|
||||
is_nunchaku_model = False
|
||||
|
||||
try:
|
||||
model_wrapper = model.model.diffusion_model
|
||||
# Check if model is a Nunchaku Flux model using only class name
|
||||
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
||||
is_nunchaku_model = True
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
except (AttributeError, TypeError):
|
||||
# Not a model with the expected structure
|
||||
pass
|
||||
|
||||
# First process lora_stack if available
|
||||
if lora_stack:
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Extract lora name and convert to absolute path
|
||||
# lora_stack stores relative paths, but load_torch_file needs absolute paths
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
absolute_lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
||||
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# Use our custom function for Flux models
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged for Nunchaku models
|
||||
else:
|
||||
# Use lower-level API to load LoRA directly without folder_paths validation
|
||||
lora = comfy.utils.load_torch_file(absolute_lora_path, safe_load=True)
|
||||
model, clip = comfy.sd.load_lora_for_models(model, clip, lora, model_strength, clip_strength)
|
||||
|
||||
all_trigger_words.extend(trigger_words)
|
||||
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Parse and process LoRAs from text syntax
|
||||
parsed_loras = self.parse_lora_syntax(lora_syntax)
|
||||
for lora in parsed_loras:
|
||||
lora_name = lora['name']
|
||||
model_strength = lora['model_strength']
|
||||
clip_strength = lora['clip_strength']
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
||||
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# For Nunchaku models, use our custom function
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged
|
||||
else:
|
||||
# Use lower-level API to load LoRA directly without folder_paths validation
|
||||
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
||||
model, clip = comfy.sd.load_lora_for_models(model, clip, lora, model_strength, clip_strength)
|
||||
|
||||
# Include clip strength in output if different from model strength and not a Nunchaku model
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# use ',, ' to separate trigger words for group mode
|
||||
lora_entries = _collect_stack_entries(lora_stack)
|
||||
for lora in self.parse_lora_syntax(lora_syntax):
|
||||
lora_path, trigger_words = get_lora_info_absolute(lora["name"])
|
||||
lora_entries.append({
|
||||
"name": lora["name"],
|
||||
"absolute_path": lora_path,
|
||||
"input_path": lora_path,
|
||||
"model_strength": lora["model_strength"],
|
||||
"clip_strength": lora["clip_strength"],
|
||||
"trigger_words": trigger_words,
|
||||
})
|
||||
|
||||
nunchaku_model_kind = detect_nunchaku_model_kind(model)
|
||||
if nunchaku_model_kind == "flux":
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
elif nunchaku_model_kind == "qwen_image":
|
||||
logger.info("Detected Nunchaku Qwen-Image model")
|
||||
|
||||
model, clip, loaded_loras, all_trigger_words = _apply_entries(model, clip, lora_entries, nunchaku_model_kind)
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Format loaded_loras with support for both formats
|
||||
formatted_loras = []
|
||||
for item in loaded_loras:
|
||||
parts = item.split(":")
|
||||
lora_name = parts[0].strip()
|
||||
strength_parts = parts[1].strip().split(",")
|
||||
|
||||
if len(strength_parts) > 1:
|
||||
# Different model and clip strengths
|
||||
model_str = strength_parts[0].strip()
|
||||
clip_str = strength_parts[1].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}:{clip_str}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
model_str = strength_parts[0].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}>")
|
||||
|
||||
formatted_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (model, clip, trigger_words_text, formatted_loras_text)
|
||||
formatted_loras_text = _format_loaded_loras(loaded_loras)
|
||||
return (model, clip, trigger_words_text, formatted_loras_text)
|
||||
|
||||
@@ -82,6 +82,7 @@ class LoraPoolLM:
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"favoritesOnly": False,
|
||||
"license": {"noCreditRequired": False, "allowSelling": False},
|
||||
"namePatterns": {"include": [], "exclude": [], "useRegex": False},
|
||||
},
|
||||
"preview": {"matchCount": 0, "lastUpdated": 0},
|
||||
}
|
||||
|
||||
@@ -7,10 +7,8 @@ and tracks the last used combination for reuse.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
import os
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import extract_lora_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
26
py/nodes/lora_stack_combiner.py
Normal file
26
py/nodes/lora_stack_combiner.py
Normal file
@@ -0,0 +1,26 @@
|
||||
class LoraStackCombinerLM:
|
||||
NAME = "Lora Stack Combiner (LoraManager)"
|
||||
CATEGORY = "Lora Manager/stackers"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"lora_stack_a": ("LORA_STACK",),
|
||||
"lora_stack_b": ("LORA_STACK",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LORA_STACK",)
|
||||
RETURN_NAMES = ("LORA_STACK",)
|
||||
FUNCTION = "combine_stacks"
|
||||
|
||||
def combine_stacks(self, lora_stack_a, lora_stack_b):
|
||||
combined_stack = []
|
||||
|
||||
if lora_stack_a:
|
||||
combined_stack.extend(lora_stack_a)
|
||||
if lora_stack_b:
|
||||
combined_stack.extend(lora_stack_b)
|
||||
|
||||
return (combined_stack,)
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list
|
||||
from .utils import FlexibleOptionalInputType, any_type, apply_lora_syntax_format, extract_lora_name, get_loras_list
|
||||
|
||||
import logging
|
||||
|
||||
@@ -48,7 +48,7 @@ class LoraStackerLM:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
|
||||
lora_name = lora['name']
|
||||
lora_name = apply_lora_syntax_format(lora['name'])
|
||||
model_strength = float(lora['strength'])
|
||||
# Get clip strength - use model strength as default if not specified
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
570
py/nodes/nunchaku_qwen.py
Normal file
570
py/nodes/nunchaku_qwen.py
Normal file
@@ -0,0 +1,570 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Qwen-Image LoRA support for Nunchaku models.
|
||||
|
||||
Portions of the LoRA mapping/application logic in this file are adapted from
|
||||
ComfyUI-QwenImageLoraLoader by GitHub user ussoewwin:
|
||||
https://github.com/ussoewwin/ComfyUI-QwenImageLoraLoader
|
||||
|
||||
The upstream project is licensed under Apache License 2.0.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import comfy.utils # type: ignore
|
||||
import folder_paths # type: ignore
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors import safe_open
|
||||
|
||||
from nunchaku.lora.flux.nunchaku_converter import (
|
||||
pack_lowrank_weight,
|
||||
unpack_lowrank_weight,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KEY_MAPPING = [
|
||||
(re.compile(r"^(layers)[._](\d+)[._]attention[._]to[._]([qkv])$"), r"\1.\2.attention.to_qkv", "qkv", lambda m: m.group(3).upper()),
|
||||
(re.compile(r"^(layers)[._](\d+)[._]feed_forward[._](w1|w3)$"), r"\1.\2.feed_forward.net.0.proj", "glu", lambda m: m.group(3)),
|
||||
(re.compile(r"^(layers)[._](\d+)[._]feed_forward[._]w2$"), r"\1.\2.feed_forward.net.2", "regular", None),
|
||||
(re.compile(r"^(layers)[._](\d+)[._](.*)$"), r"\1.\2.\3", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._]to[._]([qkv])$"), r"\1.\2.attn.to_qkv", "qkv", lambda m: m.group(3).upper()),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._](q|k|v)[._]proj$"), r"\1.\2.attn.to_qkv", "qkv", lambda m: m.group(3).upper()),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._]add[._](q|k|v)[._]proj$"), r"\1.\2.attn.add_qkv_proj", "add_qkv", lambda m: m.group(3).upper()),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]out[._]proj[._]context$"), r"\1.\2.attn.to_add_out", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]out[._]proj$"), r"\1.\2.attn.to_out.0", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]attn[._]to[._]out$"), r"\1.\2.attn.to_out.0", "regular", None),
|
||||
(re.compile(r"^(single_transformer_blocks)[._](\d+)[._]attn[._]to[._]([qkv])$"), r"\1.\2.attn.to_qkv", "qkv", lambda m: m.group(3).upper()),
|
||||
(re.compile(r"^(single_transformer_blocks)[._](\d+)[._]attn[._]to[._]out$"), r"\1.\2.attn.to_out", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]ff[._]net[._]0(?:[._]proj)?$"), r"\1.\2.mlp_fc1", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]ff[._]net[._]2$"), r"\1.\2.mlp_fc2", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]ff_context[._]net[._]0(?:[._]proj)?$"), r"\1.\2.mlp_context_fc1", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]ff_context[._]net[._]2$"), r"\1.\2.mlp_context_fc2", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._](img_mlp)[._](net)[._](0)[._](proj)$"), r"\1.\2.\3.\4.\5.\6", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._](img_mlp)[._](net)[._](2)$"), r"\1.\2.\3.\4.\5", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._](txt_mlp)[._](net)[._](0)[._](proj)$"), r"\1.\2.\3.\4.\5.\6", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._](txt_mlp)[._](net)[._](2)$"), r"\1.\2.\3.\4.\5", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._](img_mod)[._](1)$"), r"\1.\2.\3.\4", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._](txt_mod)[._](1)$"), r"\1.\2.\3.\4", "regular", None),
|
||||
(re.compile(r"^(single_transformer_blocks)[._](\d+)[._]proj[._]out$"), r"\1.\2.proj_out", "single_proj_out", None),
|
||||
(re.compile(r"^(single_transformer_blocks)[._](\d+)[._]proj[._]mlp$"), r"\1.\2.mlp_fc1", "regular", None),
|
||||
(re.compile(r"^(single_transformer_blocks)[._](\d+)[._]norm[._]linear$"), r"\1.\2.norm.linear", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]norm1[._]linear$"), r"\1.\2.norm1.linear", "regular", None),
|
||||
(re.compile(r"^(transformer_blocks)[._](\d+)[._]norm1_context[._]linear$"), r"\1.\2.norm1_context.linear", "regular", None),
|
||||
(re.compile(r"^(img_in)$"), r"\1", "regular", None),
|
||||
(re.compile(r"^(txt_in)$"), r"\1", "regular", None),
|
||||
(re.compile(r"^(proj_out)$"), r"\1", "regular", None),
|
||||
(re.compile(r"^(norm_out)[._](linear)$"), r"\1.\2", "regular", None),
|
||||
(re.compile(r"^(time_text_embed)[._](timestep_embedder)[._](linear_1)$"), r"\1.\2.\3", "regular", None),
|
||||
(re.compile(r"^(time_text_embed)[._](timestep_embedder)[._](linear_2)$"), r"\1.\2.\3", "regular", None),
|
||||
]
|
||||
|
||||
_RE_LORA_SUFFIX = re.compile(r"\.(?P<tag>lora(?:[._](?:A|B|down|up)))(?:\.[^.]+)*\.weight$")
|
||||
_RE_ALPHA_SUFFIX = re.compile(r"\.(?:alpha|lora_alpha)(?:\.[^.]+)*$")
|
||||
|
||||
|
||||
def _rename_layer_underscore_layer_name(old_name: str) -> str:
|
||||
rules = [
|
||||
(r"_(\d+)_attn_to_out_(\d+)", r".\1.attn.to_out.\2"),
|
||||
(r"_(\d+)_img_mlp_net_(\d+)_proj", r".\1.img_mlp.net.\2.proj"),
|
||||
(r"_(\d+)_txt_mlp_net_(\d+)_proj", r".\1.txt_mlp.net.\2.proj"),
|
||||
(r"_(\d+)_img_mlp_net_(\d+)", r".\1.img_mlp.net.\2"),
|
||||
(r"_(\d+)_txt_mlp_net_(\d+)", r".\1.txt_mlp.net.\2"),
|
||||
(r"_(\d+)_img_mod_(\d+)", r".\1.img_mod.\2"),
|
||||
(r"_(\d+)_txt_mod_(\d+)", r".\1.txt_mod.\2"),
|
||||
(r"_(\d+)_attn_", r".\1.attn."),
|
||||
]
|
||||
new_name = old_name
|
||||
for pattern, replacement in rules:
|
||||
new_name = re.sub(pattern, replacement, new_name)
|
||||
return new_name
|
||||
|
||||
|
||||
def _is_indexable_module(module):
|
||||
return isinstance(module, (nn.ModuleList, nn.Sequential, list, tuple))
|
||||
|
||||
|
||||
def _get_module_by_name(model: nn.Module, name: str) -> Optional[nn.Module]:
|
||||
if not name:
|
||||
return model
|
||||
module = model
|
||||
for part in name.split("."):
|
||||
if not part:
|
||||
continue
|
||||
if hasattr(module, part):
|
||||
module = getattr(module, part)
|
||||
elif part.isdigit() and _is_indexable_module(module):
|
||||
try:
|
||||
module = module[int(part)]
|
||||
except (IndexError, TypeError):
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
return module
|
||||
|
||||
|
||||
def _resolve_module_name(model: nn.Module, name: str) -> Tuple[str, Optional[nn.Module]]:
|
||||
module = _get_module_by_name(model, name)
|
||||
if module is not None:
|
||||
return name, module
|
||||
|
||||
replacements = [
|
||||
(".attn.to_out.0", ".attn.to_out"),
|
||||
(".attention.to_qkv", ".attention.qkv"),
|
||||
(".attention.to_out.0", ".attention.out"),
|
||||
(".feed_forward.net.0.proj", ".feed_forward.w13"),
|
||||
(".feed_forward.net.2", ".feed_forward.w2"),
|
||||
(".ff.net.0.proj", ".mlp_fc1"),
|
||||
(".ff.net.2", ".mlp_fc2"),
|
||||
(".ff_context.net.0.proj", ".mlp_context_fc1"),
|
||||
(".ff_context.net.2", ".mlp_context_fc2"),
|
||||
]
|
||||
for src, dst in replacements:
|
||||
if src in name:
|
||||
alt = name.replace(src, dst)
|
||||
module = _get_module_by_name(model, alt)
|
||||
if module is not None:
|
||||
return alt, module
|
||||
return name, None
|
||||
|
||||
|
||||
def _classify_and_map_key(key: str) -> Optional[Tuple[str, str, Optional[str], str]]:
|
||||
normalized = key
|
||||
if normalized.startswith("transformer."):
|
||||
normalized = normalized[len("transformer."):]
|
||||
if normalized.startswith("diffusion_model."):
|
||||
normalized = normalized[len("diffusion_model."):]
|
||||
if normalized.startswith("lora_unet_"):
|
||||
normalized = _rename_layer_underscore_layer_name(normalized[len("lora_unet_"):])
|
||||
|
||||
match = _RE_LORA_SUFFIX.search(normalized)
|
||||
if match:
|
||||
tag = match.group("tag")
|
||||
base = normalized[:match.start()]
|
||||
ab = "A" if ("lora_A" in tag or tag.endswith(".A") or "down" in tag) else "B"
|
||||
else:
|
||||
match = _RE_ALPHA_SUFFIX.search(normalized)
|
||||
if not match:
|
||||
return None
|
||||
base = normalized[:match.start()]
|
||||
ab = "alpha"
|
||||
|
||||
for pattern, template, group, comp_fn in KEY_MAPPING:
|
||||
key_match = pattern.match(base)
|
||||
if key_match:
|
||||
return group, key_match.expand(template), comp_fn(key_match) if comp_fn else None, ab
|
||||
return None
|
||||
|
||||
|
||||
def _detect_lora_format(lora_state_dict: Dict[str, torch.Tensor]) -> bool:
|
||||
standard_patterns = (
|
||||
".lora_up.",
|
||||
".lora_down.",
|
||||
".lora_A.",
|
||||
".lora_B.",
|
||||
".lora.up.",
|
||||
".lora.down.",
|
||||
".lora.A.",
|
||||
".lora.B.",
|
||||
)
|
||||
return any(pattern in key for key in lora_state_dict for pattern in standard_patterns)
|
||||
|
||||
|
||||
def _load_lora_state_dict(path_or_dict: Union[str, Path, Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
|
||||
if isinstance(path_or_dict, dict):
|
||||
return path_or_dict
|
||||
path = Path(path_or_dict)
|
||||
if path.suffix == ".safetensors":
|
||||
state_dict: Dict[str, torch.Tensor] = {}
|
||||
with safe_open(path, framework="pt", device="cpu") as handle:
|
||||
for key in handle.keys():
|
||||
state_dict[key] = handle.get_tensor(key)
|
||||
return state_dict
|
||||
return comfy.utils.load_torch_file(str(path), safe_load=True)
|
||||
|
||||
|
||||
def _fuse_glu_lora(glu_weights: Dict[str, torch.Tensor]) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
if "w1_A" not in glu_weights or "w3_A" not in glu_weights:
|
||||
return None, None, None
|
||||
a_w1, b_w1 = glu_weights["w1_A"], glu_weights["w1_B"]
|
||||
a_w3, b_w3 = glu_weights["w3_A"], glu_weights["w3_B"]
|
||||
if a_w1.shape[1] != a_w3.shape[1]:
|
||||
return None, None, None
|
||||
a_fused = torch.cat([a_w1, a_w3], dim=0)
|
||||
out1, out3 = b_w1.shape[0], b_w3.shape[0]
|
||||
rank1, rank3 = b_w1.shape[1], b_w3.shape[1]
|
||||
b_fused = torch.zeros(out1 + out3, rank1 + rank3, dtype=b_w1.dtype, device=b_w1.device)
|
||||
b_fused[:out1, :rank1] = b_w1
|
||||
b_fused[out1:, rank1:] = b_w3
|
||||
return a_fused, b_fused, glu_weights.get("w1_alpha")
|
||||
|
||||
|
||||
def _fuse_qkv_lora(qkv_weights: Dict[str, torch.Tensor], model: Optional[nn.Module] = None, base_key: Optional[str] = None) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
required_keys = ["Q_A", "Q_B", "K_A", "K_B", "V_A", "V_B"]
|
||||
if not all(key in qkv_weights for key in required_keys):
|
||||
return None, None, None
|
||||
a_q, a_k, a_v = qkv_weights["Q_A"], qkv_weights["K_A"], qkv_weights["V_A"]
|
||||
b_q, b_k, b_v = qkv_weights["Q_B"], qkv_weights["K_B"], qkv_weights["V_B"]
|
||||
if not (a_q.shape == a_k.shape == a_v.shape):
|
||||
return None, None, None
|
||||
if not (b_q.shape[1] == b_k.shape[1] == b_v.shape[1]):
|
||||
return None, None, None
|
||||
|
||||
out_features = None
|
||||
if model is not None and base_key is not None:
|
||||
_, module = _resolve_module_name(model, base_key)
|
||||
out_features = getattr(module, "out_features", None) if module is not None else None
|
||||
|
||||
alpha_fused = None
|
||||
alpha_q = qkv_weights.get("Q_alpha")
|
||||
alpha_k = qkv_weights.get("K_alpha")
|
||||
alpha_v = qkv_weights.get("V_alpha")
|
||||
if alpha_q is not None and alpha_k is not None and alpha_v is not None and alpha_q.item() == alpha_k.item() == alpha_v.item():
|
||||
alpha_fused = alpha_q
|
||||
|
||||
a_fused = torch.cat([a_q, a_k, a_v], dim=0)
|
||||
rank = b_q.shape[1]
|
||||
out_q, out_k, out_v = b_q.shape[0], b_k.shape[0], b_v.shape[0]
|
||||
total_out = out_features if out_features is not None else out_q + out_k + out_v
|
||||
b_fused = torch.zeros(total_out, 3 * rank, dtype=b_q.dtype, device=b_q.device)
|
||||
b_fused[:out_q, :rank] = b_q
|
||||
b_fused[out_q:out_q + out_k, rank:2 * rank] = b_k
|
||||
b_fused[out_q + out_k:out_q + out_k + out_v, 2 * rank:] = b_v
|
||||
return a_fused, b_fused, alpha_fused
|
||||
|
||||
|
||||
def _handle_proj_out_split(lora_dict: Dict[str, Dict[str, torch.Tensor]], base_key: str, model: nn.Module) -> Tuple[Dict[str, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], List[str]]:
|
||||
result: Dict[str, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]] = {}
|
||||
consumed: List[str] = []
|
||||
match = re.search(r"single_transformer_blocks\.(\d+)", base_key)
|
||||
if not match or base_key not in lora_dict:
|
||||
return result, consumed
|
||||
block_idx = match.group(1)
|
||||
block = _get_module_by_name(model, f"single_transformer_blocks.{block_idx}")
|
||||
if block is None:
|
||||
return result, consumed
|
||||
a_full = lora_dict[base_key].get("A")
|
||||
b_full = lora_dict[base_key].get("B")
|
||||
alpha = lora_dict[base_key].get("alpha")
|
||||
attn_to_out = getattr(getattr(block, "attn", None), "to_out", None)
|
||||
mlp_fc2 = getattr(block, "mlp_fc2", None)
|
||||
if a_full is None or b_full is None or attn_to_out is None or mlp_fc2 is None:
|
||||
return result, consumed
|
||||
attn_in = getattr(attn_to_out, "in_features", None)
|
||||
mlp_in = getattr(mlp_fc2, "in_features", None)
|
||||
if attn_in is None or mlp_in is None or a_full.shape[1] != attn_in + mlp_in:
|
||||
return result, consumed
|
||||
result[f"single_transformer_blocks.{block_idx}.attn.to_out"] = (a_full[:, :attn_in], b_full.clone(), alpha)
|
||||
result[f"single_transformer_blocks.{block_idx}.mlp_fc2"] = (a_full[:, attn_in:], b_full.clone(), alpha)
|
||||
consumed.append(base_key)
|
||||
return result, consumed
|
||||
|
||||
|
||||
def _apply_lora_to_module(module: nn.Module, a_tensor: torch.Tensor, b_tensor: torch.Tensor, module_name: str, model: nn.Module) -> None:
|
||||
if not hasattr(module, "in_features") or not hasattr(module, "out_features"):
|
||||
raise ValueError(f"{module_name}: unsupported module without in/out features")
|
||||
if a_tensor.shape[1] != module.in_features or b_tensor.shape[0] != module.out_features:
|
||||
raise ValueError(f"{module_name}: LoRA shape mismatch")
|
||||
|
||||
if module.__class__.__name__ == "AWQW4A16Linear" and hasattr(module, "qweight"):
|
||||
if not hasattr(module, "_lora_original_forward"):
|
||||
module._lora_original_forward = module.forward
|
||||
if not hasattr(module, "_nunchaku_lora_bundle"):
|
||||
module._nunchaku_lora_bundle = []
|
||||
module._nunchaku_lora_bundle.append((a_tensor, b_tensor))
|
||||
|
||||
def _awq_lora_forward(x, *args, **kwargs):
|
||||
out = module._lora_original_forward(x, *args, **kwargs)
|
||||
x_flat = x.reshape(-1, module.in_features)
|
||||
for local_a, local_b in module._nunchaku_lora_bundle:
|
||||
local_a = local_a.to(device=out.device, dtype=out.dtype)
|
||||
local_b = local_b.to(device=out.device, dtype=out.dtype)
|
||||
lora_term = (x_flat @ local_a.transpose(0, 1)) @ local_b.transpose(0, 1)
|
||||
try:
|
||||
out = out + lora_term.reshape(out.shape)
|
||||
except Exception:
|
||||
pass
|
||||
return out
|
||||
|
||||
module.forward = _awq_lora_forward
|
||||
if not hasattr(model, "_lora_slots"):
|
||||
model._lora_slots = {}
|
||||
model._lora_slots[module_name] = {"type": "awq_w4a16"}
|
||||
return
|
||||
|
||||
if hasattr(module, "proj_down") and hasattr(module, "proj_up"):
|
||||
proj_down = unpack_lowrank_weight(module.proj_down.data, down=True)
|
||||
proj_up = unpack_lowrank_weight(module.proj_up.data, down=False)
|
||||
base_rank = proj_down.shape[0] if proj_down.shape[1] == module.in_features else proj_down.shape[1]
|
||||
if proj_down.shape[1] == module.in_features:
|
||||
updated_down = torch.cat([proj_down, a_tensor], dim=0)
|
||||
axis_down = 0
|
||||
else:
|
||||
updated_down = torch.cat([proj_down, a_tensor.T], dim=1)
|
||||
axis_down = 1
|
||||
updated_up = torch.cat([proj_up, b_tensor], dim=1)
|
||||
module.proj_down.data = pack_lowrank_weight(updated_down, down=True)
|
||||
module.proj_up.data = pack_lowrank_weight(updated_up, down=False)
|
||||
module.rank = base_rank + a_tensor.shape[0]
|
||||
if not hasattr(model, "_lora_slots"):
|
||||
model._lora_slots = {}
|
||||
model._lora_slots[module_name] = {
|
||||
"type": "nunchaku",
|
||||
"base_rank": base_rank,
|
||||
"axis_down": axis_down,
|
||||
}
|
||||
return
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
if not hasattr(model, "_lora_slots"):
|
||||
model._lora_slots = {}
|
||||
if module_name not in model._lora_slots:
|
||||
model._lora_slots[module_name] = {
|
||||
"type": "linear",
|
||||
"original_weight": module.weight.detach().cpu().clone(),
|
||||
}
|
||||
module.weight.data.add_((b_tensor @ a_tensor).to(dtype=module.weight.dtype, device=module.weight.device))
|
||||
return
|
||||
|
||||
raise ValueError(f"{module_name}: unsupported module type {type(module)}")
|
||||
|
||||
|
||||
def reset_lora_v2(model: nn.Module) -> None:
|
||||
slots = getattr(model, "_lora_slots", None)
|
||||
if not slots:
|
||||
return
|
||||
for name, info in list(slots.items()):
|
||||
module = _get_module_by_name(model, name)
|
||||
if module is None:
|
||||
continue
|
||||
module_type = info.get("type", "nunchaku")
|
||||
if module_type == "nunchaku":
|
||||
base_rank = info["base_rank"]
|
||||
proj_down = unpack_lowrank_weight(module.proj_down.data, down=True)
|
||||
proj_up = unpack_lowrank_weight(module.proj_up.data, down=False)
|
||||
if info.get("axis_down", 0) == 0:
|
||||
proj_down = proj_down[:base_rank, :].clone()
|
||||
else:
|
||||
proj_down = proj_down[:, :base_rank].clone()
|
||||
proj_up = proj_up[:, :base_rank].clone()
|
||||
module.proj_down.data = pack_lowrank_weight(proj_down, down=True)
|
||||
module.proj_up.data = pack_lowrank_weight(proj_up, down=False)
|
||||
module.rank = base_rank
|
||||
elif module_type == "linear" and "original_weight" in info:
|
||||
module.weight.data.copy_(info["original_weight"].to(device=module.weight.device, dtype=module.weight.dtype))
|
||||
elif module_type == "awq_w4a16":
|
||||
if hasattr(module, "_lora_original_forward"):
|
||||
module.forward = module._lora_original_forward
|
||||
for attr in ("_lora_original_forward", "_nunchaku_lora_bundle"):
|
||||
if hasattr(module, attr):
|
||||
delattr(module, attr)
|
||||
model._lora_slots = {}
|
||||
|
||||
|
||||
def compose_loras_v2(model: nn.Module, lora_configs: List[Tuple[Union[str, Path, Dict[str, torch.Tensor]], float]], apply_awq_mod: bool = True) -> bool:
|
||||
del apply_awq_mod # retained for interface compatibility
|
||||
reset_lora_v2(model)
|
||||
aggregated_weights: Dict[str, List[Dict[str, object]]] = defaultdict(list)
|
||||
saw_supported_format = False
|
||||
unresolved_targets = 0
|
||||
|
||||
for index, (path_or_dict, strength) in enumerate(lora_configs):
|
||||
if abs(strength) < 1e-5:
|
||||
continue
|
||||
lora_name = str(path_or_dict) if not isinstance(path_or_dict, dict) else f"lora_{index}"
|
||||
lora_state_dict = _load_lora_state_dict(path_or_dict)
|
||||
if not lora_state_dict or not _detect_lora_format(lora_state_dict):
|
||||
logger.warning("Skipping unsupported Qwen LoRA: %s", lora_name)
|
||||
continue
|
||||
saw_supported_format = True
|
||||
|
||||
grouped_weights: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict)
|
||||
for key, value in lora_state_dict.items():
|
||||
parsed = _classify_and_map_key(key)
|
||||
if parsed is None:
|
||||
continue
|
||||
group, base_key, component, ab = parsed
|
||||
if component and ab:
|
||||
grouped_weights[base_key][f"{component}_{ab}"] = value
|
||||
else:
|
||||
grouped_weights[base_key][ab] = value
|
||||
|
||||
processed_groups: Dict[str, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]] = {}
|
||||
handled: set[str] = set()
|
||||
for base_key, weights in grouped_weights.items():
|
||||
if base_key in handled:
|
||||
continue
|
||||
a_tensor = b_tensor = alpha = None
|
||||
if "qkv" in base_key or "add_qkv_proj" in base_key:
|
||||
a_tensor, b_tensor, alpha = _fuse_qkv_lora(weights, model=model, base_key=base_key)
|
||||
elif "w1_A" in weights or "w3_A" in weights:
|
||||
a_tensor, b_tensor, alpha = _fuse_glu_lora(weights)
|
||||
elif ".proj_out" in base_key and "single_transformer_blocks" in base_key:
|
||||
split_map, consumed = _handle_proj_out_split(grouped_weights, base_key, model)
|
||||
processed_groups.update(split_map)
|
||||
handled.update(consumed)
|
||||
continue
|
||||
else:
|
||||
a_tensor, b_tensor, alpha = weights.get("A"), weights.get("B"), weights.get("alpha")
|
||||
if a_tensor is not None and b_tensor is not None:
|
||||
processed_groups[base_key] = (a_tensor, b_tensor, alpha)
|
||||
|
||||
for module_name, (a_tensor, b_tensor, alpha) in processed_groups.items():
|
||||
aggregated_weights[module_name].append({
|
||||
"A": a_tensor,
|
||||
"B": b_tensor,
|
||||
"alpha": alpha,
|
||||
"strength": strength,
|
||||
})
|
||||
|
||||
for module_name, weight_list in aggregated_weights.items():
|
||||
resolved_name, module = _resolve_module_name(model, module_name)
|
||||
if module is None:
|
||||
logger.warning("Skipping unresolved Qwen LoRA target: %s", module_name)
|
||||
unresolved_targets += 1
|
||||
continue
|
||||
all_a = []
|
||||
all_b_scaled = []
|
||||
for item in weight_list:
|
||||
a_tensor = item["A"]
|
||||
b_tensor = item["B"]
|
||||
alpha = item["alpha"]
|
||||
strength = float(item["strength"])
|
||||
rank = a_tensor.shape[0]
|
||||
scale = strength * ((alpha / rank) if alpha is not None else 1.0)
|
||||
if module.__class__.__name__ == "AWQW4A16Linear" and hasattr(module, "qweight"):
|
||||
target_dtype = torch.float16
|
||||
target_device = module.qweight.device
|
||||
elif hasattr(module, "proj_down"):
|
||||
target_dtype = module.proj_down.dtype
|
||||
target_device = module.proj_down.device
|
||||
elif hasattr(module, "weight"):
|
||||
target_dtype = module.weight.dtype
|
||||
target_device = module.weight.device
|
||||
else:
|
||||
target_dtype = torch.float16
|
||||
target_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
all_a.append(a_tensor.to(dtype=target_dtype, device=target_device))
|
||||
all_b_scaled.append((b_tensor * scale).to(dtype=target_dtype, device=target_device))
|
||||
if not all_a:
|
||||
continue
|
||||
_apply_lora_to_module(module, torch.cat(all_a, dim=0), torch.cat(all_b_scaled, dim=1), resolved_name, model)
|
||||
|
||||
slot_count = len(getattr(model, "_lora_slots", {}) or {})
|
||||
logger.info(
|
||||
"Qwen LoRA composition finished: requested=%d supported=%s applied_targets=%d unresolved=%d",
|
||||
len(lora_configs),
|
||||
saw_supported_format,
|
||||
slot_count,
|
||||
unresolved_targets,
|
||||
)
|
||||
return saw_supported_format
|
||||
|
||||
|
||||
class ComfyQwenImageWrapperLM(nn.Module):
|
||||
def __init__(self, model: nn.Module, config=None, apply_awq_mod: bool = True):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.config = {} if config is None else config
|
||||
self.dtype = next(model.parameters()).dtype
|
||||
self.loras: List[Tuple[Union[str, Path, Dict[str, torch.Tensor]], float]] = []
|
||||
self._applied_loras: Optional[List[Tuple[Union[str, Path, Dict[str, torch.Tensor]], float]]] = None
|
||||
self.apply_awq_mod = apply_awq_mod
|
||||
|
||||
def __getattr__(self, name):
|
||||
try:
|
||||
inner = object.__getattribute__(self, "_modules").get("model")
|
||||
except (AttributeError, KeyError):
|
||||
inner = None
|
||||
if inner is None:
|
||||
raise AttributeError(f"{type(self).__name__!s} has no attribute {name}")
|
||||
if name == "model":
|
||||
return inner
|
||||
return getattr(inner, name)
|
||||
|
||||
def process_img(self, *args, **kwargs):
|
||||
return self.model.process_img(*args, **kwargs)
|
||||
|
||||
def _ensure_composed(self):
|
||||
if self._applied_loras != self.loras or (not self.loras and getattr(self.model, "_lora_slots", None)):
|
||||
is_supported_format = compose_loras_v2(self.model, self.loras, apply_awq_mod=self.apply_awq_mod)
|
||||
self._applied_loras = self.loras.copy()
|
||||
has_slots = bool(getattr(self.model, "_lora_slots", None))
|
||||
if self.loras and is_supported_format and not has_slots:
|
||||
logger.warning("Qwen LoRA compose produced 0 target modules. Resetting and retrying once.")
|
||||
reset_lora_v2(self.model)
|
||||
compose_loras_v2(self.model, self.loras, apply_awq_mod=self.apply_awq_mod)
|
||||
has_slots = bool(getattr(self.model, "_lora_slots", None))
|
||||
logger.info("Qwen LoRA retry result: applied_targets=%d", len(getattr(self.model, "_lora_slots", {}) or {}))
|
||||
|
||||
offload_manager = getattr(self.model, "offload_manager", None)
|
||||
if offload_manager is not None:
|
||||
offload_settings = {
|
||||
"num_blocks_on_gpu": getattr(offload_manager, "num_blocks_on_gpu", 1),
|
||||
"use_pin_memory": getattr(offload_manager, "use_pin_memory", False),
|
||||
}
|
||||
logger.info(
|
||||
"Rebuilding Qwen offload manager after LoRA compose: num_blocks_on_gpu=%s use_pin_memory=%s",
|
||||
offload_settings["num_blocks_on_gpu"],
|
||||
offload_settings["use_pin_memory"],
|
||||
)
|
||||
self.model.set_offload(False)
|
||||
self.model.set_offload(True, **offload_settings)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
self._ensure_composed()
|
||||
return self.model(*args, **kwargs)
|
||||
|
||||
|
||||
def _get_qwen_wrapper_and_transformer(model):
|
||||
model_wrapper = model.model.diffusion_model
|
||||
if hasattr(model_wrapper, "model") and hasattr(model_wrapper, "loras"):
|
||||
transformer = model_wrapper.model
|
||||
if transformer.__class__.__name__.endswith("NunchakuQwenImageTransformer2DModel"):
|
||||
return model_wrapper, transformer
|
||||
if model_wrapper.__class__.__name__.endswith("NunchakuQwenImageTransformer2DModel"):
|
||||
wrapped_model = ComfyQwenImageWrapperLM(model_wrapper, getattr(model_wrapper, "config", {}))
|
||||
model.model.diffusion_model = wrapped_model
|
||||
return wrapped_model, wrapped_model.model
|
||||
raise TypeError(f"This LoRA loader only works with Nunchaku Qwen Image models, but got {type(model_wrapper).__name__}.")
|
||||
|
||||
|
||||
def nunchaku_load_qwen_loras(model, lora_configs: List[Tuple[str, float]], apply_awq_mod: bool = True):
|
||||
model_wrapper, transformer = _get_qwen_wrapper_and_transformer(model)
|
||||
model_wrapper.apply_awq_mod = apply_awq_mod
|
||||
|
||||
saved_config = None
|
||||
if hasattr(model, "model") and hasattr(model.model, "model_config"):
|
||||
saved_config = model.model.model_config
|
||||
model.model.model_config = None
|
||||
|
||||
model_wrapper.model = None
|
||||
try:
|
||||
ret_model = copy.deepcopy(model)
|
||||
finally:
|
||||
if saved_config is not None:
|
||||
model.model.model_config = saved_config
|
||||
model_wrapper.model = transformer
|
||||
|
||||
ret_model_wrapper = ret_model.model.diffusion_model
|
||||
if saved_config is not None:
|
||||
ret_model.model.model_config = saved_config
|
||||
ret_model_wrapper.model = transformer
|
||||
ret_model_wrapper.apply_awq_mod = apply_awq_mod
|
||||
ret_model_wrapper.loras = list(getattr(model_wrapper, "loras", []))
|
||||
|
||||
for lora_name, lora_strength in lora_configs:
|
||||
lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name)
|
||||
if not lora_path or not os.path.isfile(lora_path):
|
||||
logger.warning("Skipping Qwen LoRA '%s' because it could not be found", lora_name)
|
||||
continue
|
||||
ret_model_wrapper.loras.append((lora_path, lora_strength))
|
||||
|
||||
return ret_model
|
||||
@@ -1,15 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
import inspect
|
||||
|
||||
from ..services.wildcard_service import (
|
||||
contains_dynamic_syntax,
|
||||
get_wildcard_service,
|
||||
is_trigger_words_input,
|
||||
)
|
||||
|
||||
class _AllContainer:
|
||||
"""Container that accepts any key for dynamic input validation."""
|
||||
|
||||
def __contains__(self, item):
|
||||
return True
|
||||
class _PromptOptionalInputs:
|
||||
"""Lookup that preserves explicit optional inputs and dynamic trigger slots."""
|
||||
|
||||
def __getitem__(self, key):
|
||||
return ("STRING", {"forceInput": True})
|
||||
def __init__(self, explicit_inputs: dict[str, tuple[str, dict[str, Any]]]) -> None:
|
||||
self._explicit_inputs = explicit_inputs
|
||||
|
||||
def __contains__(self, item: object) -> bool:
|
||||
if not isinstance(item, str):
|
||||
return False
|
||||
return item in self._explicit_inputs or is_trigger_words_input(item)
|
||||
|
||||
def __getitem__(self, key: str) -> tuple[str, dict[str, Any]]:
|
||||
if key in self._explicit_inputs:
|
||||
return self._explicit_inputs[key]
|
||||
if is_trigger_words_input(key):
|
||||
return (
|
||||
"STRING",
|
||||
{
|
||||
"forceInput": True,
|
||||
"tooltip": "Trigger words to prepend. Connect to add more inputs.",
|
||||
},
|
||||
)
|
||||
raise KeyError(key)
|
||||
|
||||
|
||||
class PromptLM:
|
||||
@@ -20,12 +43,19 @@ class PromptLM:
|
||||
DESCRIPTION = (
|
||||
"Encodes a text prompt using a CLIP model into an embedding that can be used "
|
||||
"to guide the diffusion model towards generating specific images. "
|
||||
"Supports dynamic trigger words inputs."
|
||||
"Supports dynamic trigger words inputs and runtime wildcard expansion."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
dyn_inputs = {
|
||||
optional_inputs: dict[str, tuple[str, dict[str, Any]]] = {
|
||||
"seed": (
|
||||
"INT",
|
||||
{
|
||||
"forceInput": True,
|
||||
"tooltip": "Optional seed for wildcard generation. Leave unconnected for non-deterministic wildcard expansion.",
|
||||
},
|
||||
),
|
||||
"trigger_words1": (
|
||||
"STRING",
|
||||
{
|
||||
@@ -35,10 +65,9 @@ class PromptLM:
|
||||
),
|
||||
}
|
||||
|
||||
# Bypass validation for dynamic inputs during graph execution
|
||||
stack = inspect.stack()
|
||||
if len(stack) > 2 and stack[2].function == "get_input_info":
|
||||
dyn_inputs = _AllContainer()
|
||||
optional_inputs = _PromptOptionalInputs(optional_inputs) # type: ignore[assignment]
|
||||
|
||||
return {
|
||||
"required": {
|
||||
@@ -46,8 +75,8 @@ class PromptLM:
|
||||
"AUTOCOMPLETE_TEXT_PROMPT,STRING",
|
||||
{
|
||||
"widgetType": "AUTOCOMPLETE_TEXT_PROMPT",
|
||||
"placeholder": "Enter prompt... /char, /artist for quick tag search",
|
||||
"tooltip": "The text to be encoded.",
|
||||
"placeholder": "Enter prompt... /character, /artist, /wildcard for quick search",
|
||||
"tooltip": "The text to be encoded. Wildcard references inserted with /wildcard are expanded at runtime.",
|
||||
},
|
||||
),
|
||||
"clip": (
|
||||
@@ -55,7 +84,7 @@ class PromptLM:
|
||||
{"tooltip": "The CLIP model used for encoding the text."},
|
||||
),
|
||||
},
|
||||
"optional": dyn_inputs,
|
||||
"optional": optional_inputs,
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "STRING")
|
||||
@@ -65,20 +94,39 @@ class PromptLM:
|
||||
)
|
||||
FUNCTION = "encode"
|
||||
|
||||
def encode(self, text: str, clip: Any, **kwargs):
|
||||
# Collect all trigger words from dynamic inputs
|
||||
@classmethod
|
||||
def IS_CHANGED(
|
||||
cls,
|
||||
text: str,
|
||||
clip: Any | None = None,
|
||||
seed: int | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
del clip, kwargs
|
||||
if contains_dynamic_syntax(text) and seed is None:
|
||||
return float("NaN")
|
||||
return False
|
||||
|
||||
def encode(
|
||||
self,
|
||||
text: str,
|
||||
clip: Any,
|
||||
seed: int | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
expanded_text = get_wildcard_service().expand_text(text, seed=seed)
|
||||
|
||||
trigger_words = []
|
||||
for key, value in kwargs.items():
|
||||
if key.startswith("trigger_words") and value:
|
||||
if is_trigger_words_input(key) and value:
|
||||
trigger_words.append(value)
|
||||
|
||||
# Build final prompt
|
||||
if trigger_words:
|
||||
prompt = ", ".join(trigger_words + [text])
|
||||
prompt = ", ".join(trigger_words + [expanded_text])
|
||||
else:
|
||||
prompt = text
|
||||
prompt = expanded_text
|
||||
|
||||
from nodes import CLIPTextEncode # type: ignore
|
||||
|
||||
conditioning = CLIPTextEncode().encode(clip, prompt)[0]
|
||||
return (conditioning, prompt)
|
||||
return (conditioning, prompt)
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
import numpy as np
|
||||
import folder_paths # type: ignore
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||
from ..metadata_collector import get_metadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.utils import calculate_recipe_fingerprint, sanitize_folder_name
|
||||
from PIL import Image, PngImagePlugin
|
||||
import piexif
|
||||
import logging
|
||||
@@ -72,6 +77,13 @@ class SaveImageLM:
|
||||
"tooltip": "Embeds the complete workflow data into the image metadata. Only works with PNG and WebP formats.",
|
||||
},
|
||||
),
|
||||
"save_with_metadata": (
|
||||
"BOOLEAN",
|
||||
{
|
||||
"default": True,
|
||||
"tooltip": "When enabled, embeds generation parameters into the saved image metadata. Disable to skip writing generation metadata.",
|
||||
},
|
||||
),
|
||||
"add_counter_to_filename": (
|
||||
"BOOLEAN",
|
||||
{
|
||||
@@ -79,6 +91,13 @@ class SaveImageLM:
|
||||
"tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images.",
|
||||
},
|
||||
),
|
||||
"save_as_recipe": (
|
||||
"BOOLEAN",
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Also saves each generated image as a LoRA Manager recipe.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"id": "UNIQUE_ID",
|
||||
@@ -279,7 +298,12 @@ class SaveImageLM:
|
||||
key = parts[0]
|
||||
|
||||
if key == "seed" and "seed" in metadata_dict:
|
||||
filename = filename.replace(segment, str(metadata_dict.get("seed", "")))
|
||||
seed_value = metadata_dict.get("seed")
|
||||
if seed_value is not None:
|
||||
filename = filename.replace(segment, str(seed_value))
|
||||
else:
|
||||
# Fallback if seed was not captured by metadata collector
|
||||
filename = filename.replace(segment, "0")
|
||||
elif key == "width" and "size" in metadata_dict:
|
||||
size = metadata_dict.get("size", "x")
|
||||
w = size.split("x")[0] if isinstance(size, str) else size[0]
|
||||
@@ -290,12 +314,14 @@ class SaveImageLM:
|
||||
filename = filename.replace(segment, str(h))
|
||||
elif key == "pprompt" and "prompt" in metadata_dict:
|
||||
prompt = metadata_dict.get("prompt", "").replace("\n", " ")
|
||||
prompt = sanitize_folder_name(prompt)
|
||||
if len(parts) >= 2:
|
||||
length = int(parts[1])
|
||||
prompt = prompt[:length]
|
||||
filename = filename.replace(segment, prompt.strip())
|
||||
elif key == "nprompt" and "negative_prompt" in metadata_dict:
|
||||
prompt = metadata_dict.get("negative_prompt", "").replace("\n", " ")
|
||||
prompt = sanitize_folder_name(prompt)
|
||||
if len(parts) >= 2:
|
||||
length = int(parts[1])
|
||||
prompt = prompt[:length]
|
||||
@@ -309,6 +335,7 @@ class SaveImageLM:
|
||||
model = "model_unavailable"
|
||||
else:
|
||||
model = os.path.splitext(os.path.basename(model_value))[0]
|
||||
model = sanitize_folder_name(model)
|
||||
if len(parts) >= 2:
|
||||
length = int(parts[1])
|
||||
model = model[:length]
|
||||
@@ -339,6 +366,203 @@ class SaveImageLM:
|
||||
|
||||
return filename
|
||||
|
||||
@staticmethod
|
||||
def _get_cached_model_by_name(scanner, name):
|
||||
cache = getattr(scanner, "_cache", None)
|
||||
if cache is None or not name:
|
||||
return None
|
||||
|
||||
candidates = [
|
||||
name,
|
||||
os.path.basename(name),
|
||||
os.path.splitext(os.path.basename(name))[0],
|
||||
]
|
||||
for model in getattr(cache, "raw_data", []):
|
||||
file_name = model.get("file_name")
|
||||
if file_name in candidates:
|
||||
return model
|
||||
return None
|
||||
|
||||
def _build_recipe_loras(self, recipe_scanner, lora_stack):
|
||||
lora_matches = re.findall(r"<lora:([^:]+):([^>]+)>", lora_stack or "")
|
||||
lora_scanner = getattr(recipe_scanner, "_lora_scanner", None)
|
||||
loras_data = []
|
||||
base_model_counts = {}
|
||||
|
||||
for name, strength in lora_matches:
|
||||
lora_info = self._get_cached_model_by_name(lora_scanner, name)
|
||||
civitai = (lora_info or {}).get("civitai") or {}
|
||||
civitai_model = civitai.get("model") or {}
|
||||
try:
|
||||
parsed_strength = float(strength)
|
||||
except (TypeError, ValueError):
|
||||
parsed_strength = 1.0
|
||||
|
||||
loras_data.append(
|
||||
{
|
||||
"file_name": name,
|
||||
"strength": parsed_strength,
|
||||
"hash": ((lora_info or {}).get("sha256") or "").lower(),
|
||||
"modelVersionId": civitai.get("id", 0),
|
||||
"modelName": civitai_model.get("name", name) if lora_info else "",
|
||||
"modelVersionName": civitai.get("name", "") if lora_info else "",
|
||||
"isDeleted": False,
|
||||
"exclude": False,
|
||||
}
|
||||
)
|
||||
|
||||
base_model = (lora_info or {}).get("base_model")
|
||||
if base_model:
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
|
||||
return lora_matches, loras_data, base_model_counts
|
||||
|
||||
def _build_recipe_checkpoint(self, recipe_scanner, checkpoint_raw):
|
||||
if not isinstance(checkpoint_raw, str) or not checkpoint_raw.strip():
|
||||
return None
|
||||
|
||||
checkpoint_name = checkpoint_raw.strip()
|
||||
file_name = os.path.splitext(os.path.basename(checkpoint_name))[0]
|
||||
checkpoint_scanner = getattr(recipe_scanner, "_checkpoint_scanner", None)
|
||||
checkpoint_info = self._get_cached_model_by_name(
|
||||
checkpoint_scanner, checkpoint_name
|
||||
)
|
||||
|
||||
if not checkpoint_info:
|
||||
return {
|
||||
"type": "checkpoint",
|
||||
"name": checkpoint_name,
|
||||
"file_name": file_name,
|
||||
"hash": self.get_checkpoint_hash(checkpoint_name) or "",
|
||||
}
|
||||
|
||||
civitai = checkpoint_info.get("civitai") or {}
|
||||
civitai_model = civitai.get("model") or {}
|
||||
file_path = checkpoint_info.get("file_path") or checkpoint_info.get("path") or ""
|
||||
cached_file_name = (
|
||||
checkpoint_info.get("file_name")
|
||||
or (os.path.splitext(os.path.basename(file_path))[0] if file_path else "")
|
||||
or file_name
|
||||
)
|
||||
|
||||
return {
|
||||
"type": "checkpoint",
|
||||
"modelId": civitai_model.get("id", 0),
|
||||
"modelVersionId": civitai.get("id", 0),
|
||||
"name": civitai_model.get("name")
|
||||
or checkpoint_info.get("model_name")
|
||||
or checkpoint_name,
|
||||
"version": civitai.get("name", ""),
|
||||
"hash": (
|
||||
checkpoint_info.get("sha256") or checkpoint_info.get("hash") or ""
|
||||
).lower(),
|
||||
"file_name": cached_file_name,
|
||||
"modelName": civitai_model.get("name", ""),
|
||||
"modelVersionName": civitai.get("name", ""),
|
||||
"baseModel": checkpoint_info.get("base_model")
|
||||
or civitai.get("baseModel", ""),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _derive_recipe_name(lora_matches):
|
||||
recipe_name_parts = [
|
||||
f"{name.strip()}-{float(strength):.2f}" for name, strength in lora_matches[:3]
|
||||
]
|
||||
return "_".join(recipe_name_parts) or "recipe"
|
||||
|
||||
@staticmethod
|
||||
def _sync_recipe_cache(recipe_scanner, recipe_data, json_path):
|
||||
cache = getattr(recipe_scanner, "_cache", None)
|
||||
if cache is not None:
|
||||
cache.raw_data.append(recipe_data)
|
||||
cache.sorted_by_name = sorted(
|
||||
cache.raw_data, key=lambda item: item.get("title", "").lower()
|
||||
)
|
||||
cache.sorted_by_date = sorted(
|
||||
cache.raw_data,
|
||||
key=lambda item: (
|
||||
item.get("modified", item.get("created_date", 0)),
|
||||
item.get("file_path", ""),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
recipe_scanner._update_folder_metadata(cache)
|
||||
recipe_scanner._update_fts_index_for_recipe(recipe_data, "add")
|
||||
|
||||
recipe_id = str(recipe_data.get("id", ""))
|
||||
if recipe_id:
|
||||
recipe_scanner._json_path_map[recipe_id] = json_path
|
||||
persistent_cache = getattr(recipe_scanner, "_persistent_cache", None)
|
||||
if persistent_cache:
|
||||
persistent_cache.update_recipe(recipe_data, json_path)
|
||||
|
||||
def _save_image_as_recipe(self, file_path, metadata_dict):
|
||||
if not metadata_dict:
|
||||
raise ValueError("No generation metadata found")
|
||||
|
||||
recipe_scanner = ServiceRegistry.get_service_sync("recipe_scanner")
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipes_dir = recipe_scanner.recipes_dir
|
||||
if not recipes_dir:
|
||||
raise RuntimeError("Recipes directory unavailable")
|
||||
os.makedirs(recipes_dir, exist_ok=True)
|
||||
|
||||
recipe_id = str(uuid.uuid4())
|
||||
optimized_image, extension = ExifUtils.optimize_image(
|
||||
image_data=file_path,
|
||||
target_width=CARD_PREVIEW_WIDTH,
|
||||
format="webp",
|
||||
quality=85,
|
||||
preserve_metadata=True,
|
||||
)
|
||||
image_path = os.path.normpath(os.path.join(recipes_dir, f"{recipe_id}{extension}"))
|
||||
with open(image_path, "wb") as file_obj:
|
||||
file_obj.write(optimized_image)
|
||||
|
||||
lora_stack = metadata_dict.get("loras", "")
|
||||
lora_matches, loras_data, base_model_counts = self._build_recipe_loras(
|
||||
recipe_scanner, lora_stack
|
||||
)
|
||||
checkpoint_entry = self._build_recipe_checkpoint(
|
||||
recipe_scanner, metadata_dict.get("checkpoint")
|
||||
)
|
||||
most_common_base_model = (
|
||||
max(base_model_counts.items(), key=lambda item: item[1])[0]
|
||||
if base_model_counts
|
||||
else ""
|
||||
)
|
||||
current_time = time.time()
|
||||
recipe_data = {
|
||||
"id": recipe_id,
|
||||
"file_path": image_path,
|
||||
"title": self._derive_recipe_name(lora_matches),
|
||||
"modified": current_time,
|
||||
"created_date": current_time,
|
||||
"base_model": most_common_base_model
|
||||
or (checkpoint_entry or {}).get("baseModel", ""),
|
||||
"loras": loras_data,
|
||||
"gen_params": {
|
||||
key: value
|
||||
for key, value in metadata_dict.items()
|
||||
if key not in ["checkpoint", "loras"]
|
||||
},
|
||||
"loras_stack": lora_stack,
|
||||
"fingerprint": calculate_recipe_fingerprint(loras_data),
|
||||
}
|
||||
if checkpoint_entry:
|
||||
recipe_data["checkpoint"] = checkpoint_entry
|
||||
|
||||
json_path = os.path.normpath(
|
||||
os.path.join(recipes_dir, f"{recipe_id}.recipe.json")
|
||||
)
|
||||
with open(json_path, "w", encoding="utf-8") as file_obj:
|
||||
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
||||
|
||||
ExifUtils.append_recipe_metadata(image_path, recipe_data)
|
||||
self._sync_recipe_cache(recipe_scanner, recipe_data, json_path)
|
||||
|
||||
def save_images(
|
||||
self,
|
||||
images,
|
||||
@@ -350,7 +574,9 @@ class SaveImageLM:
|
||||
lossless_webp=True,
|
||||
quality=100,
|
||||
embed_workflow=False,
|
||||
save_with_metadata=True,
|
||||
add_counter_to_filename=True,
|
||||
save_as_recipe=False,
|
||||
):
|
||||
"""Save images with metadata"""
|
||||
results = []
|
||||
@@ -382,7 +608,7 @@ class SaveImageLM:
|
||||
img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8))
|
||||
|
||||
# Generate filename with counter if needed
|
||||
base_filename = filename
|
||||
base_filename = filename.replace("%batch_num%", str(i))
|
||||
if add_counter_to_filename:
|
||||
# Use counter + i to ensure unique filenames for all images in batch
|
||||
current_counter = counter + i
|
||||
@@ -421,7 +647,7 @@ class SaveImageLM:
|
||||
try:
|
||||
if file_format == "png":
|
||||
assert pnginfo is not None
|
||||
if metadata:
|
||||
if save_with_metadata and metadata:
|
||||
pnginfo.add_text("parameters", metadata)
|
||||
if embed_workflow and extra_pnginfo is not None:
|
||||
workflow_json = json.dumps(extra_pnginfo["workflow"])
|
||||
@@ -430,7 +656,7 @@ class SaveImageLM:
|
||||
img.save(file_path, format="PNG", **save_kwargs)
|
||||
elif file_format == "jpeg":
|
||||
# For JPEG, use piexif
|
||||
if metadata:
|
||||
if save_with_metadata and metadata:
|
||||
try:
|
||||
exif_dict = {
|
||||
"Exif": {
|
||||
@@ -448,7 +674,7 @@ class SaveImageLM:
|
||||
# For WebP, use piexif for metadata
|
||||
exif_dict = {}
|
||||
|
||||
if metadata:
|
||||
if save_with_metadata and metadata:
|
||||
exif_dict["Exif"] = {
|
||||
piexif.ExifIFD.UserComment: b"UNICODE\0"
|
||||
+ metadata.encode("utf-16be")
|
||||
@@ -469,6 +695,14 @@ class SaveImageLM:
|
||||
|
||||
img.save(file_path, format="WEBP", **save_kwargs)
|
||||
|
||||
if save_as_recipe:
|
||||
try:
|
||||
self._save_image_as_recipe(file_path, metadata_dict)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to save image as recipe: %s", e, exc_info=True
|
||||
)
|
||||
|
||||
results.append(
|
||||
{"filename": file, "subfolder": subfolder, "type": self.type}
|
||||
)
|
||||
@@ -489,7 +723,9 @@ class SaveImageLM:
|
||||
lossless_webp=True,
|
||||
quality=100,
|
||||
embed_workflow=False,
|
||||
save_with_metadata=True,
|
||||
add_counter_to_filename=True,
|
||||
save_as_recipe=False,
|
||||
):
|
||||
"""Process and save image with metadata"""
|
||||
# Make sure the output directory exists
|
||||
@@ -516,7 +752,12 @@ class SaveImageLM:
|
||||
lossless_webp,
|
||||
quality,
|
||||
embed_workflow,
|
||||
save_with_metadata,
|
||||
add_counter_to_filename,
|
||||
save_as_recipe,
|
||||
)
|
||||
|
||||
return (images,)
|
||||
return {
|
||||
"result": (images,),
|
||||
"ui": {"images": results},
|
||||
}
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..services.wildcard_service import contains_dynamic_syntax, get_wildcard_service
|
||||
|
||||
|
||||
class TextLM:
|
||||
"""A simple text node with autocomplete support."""
|
||||
|
||||
NAME = "Text (LoraManager)"
|
||||
CATEGORY = "Lora Manager/utils"
|
||||
DESCRIPTION = (
|
||||
"A simple text input node with autocomplete support for tags and styles."
|
||||
"A simple text input node with autocomplete support for tags, styles, and wildcard expansion."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -15,8 +20,17 @@ class TextLM:
|
||||
"AUTOCOMPLETE_TEXT_PROMPT,STRING",
|
||||
{
|
||||
"widgetType": "AUTOCOMPLETE_TEXT_PROMPT",
|
||||
"placeholder": "Enter text... /char, /artist for quick tag search",
|
||||
"tooltip": "The text output.",
|
||||
"placeholder": "Enter text... /character, /artist, /wildcard for quick search",
|
||||
"tooltip": "The text output. Wildcard references inserted with /wildcard are expanded at runtime.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (
|
||||
"INT",
|
||||
{
|
||||
"forceInput": True,
|
||||
"tooltip": "Optional seed for wildcard generation. Leave unconnected for non-deterministic wildcard expansion.",
|
||||
},
|
||||
),
|
||||
},
|
||||
@@ -24,10 +38,14 @@ class TextLM:
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
RETURN_NAMES = ("STRING",)
|
||||
OUTPUT_TOOLTIPS = (
|
||||
"The text output.",
|
||||
)
|
||||
OUTPUT_TOOLTIPS = ("The text output.",)
|
||||
FUNCTION = "process"
|
||||
|
||||
def process(self, text: str):
|
||||
return (text,)
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, text: str, seed: int | None = None):
|
||||
if contains_dynamic_syntax(text) and seed is None:
|
||||
return float("NaN")
|
||||
return False
|
||||
|
||||
def process(self, text: str, seed: int | None = None):
|
||||
return (get_wildcard_service().expand_text(text, seed=seed),)
|
||||
|
||||
@@ -76,6 +76,9 @@ class TriggerWordToggleLM:
|
||||
# Filter out empty strings and return as set
|
||||
return set(word for word in words if word)
|
||||
|
||||
def _group_has_child_items(self, item):
|
||||
return isinstance(item, dict) and isinstance(item.get("items"), list)
|
||||
|
||||
def process_trigger_words(
|
||||
self,
|
||||
id,
|
||||
@@ -112,7 +115,11 @@ class TriggerWordToggleLM:
|
||||
|
||||
if isinstance(trigger_data, list):
|
||||
if group_mode:
|
||||
if allow_strength_adjustment:
|
||||
if any(self._group_has_child_items(item) for item in trigger_data):
|
||||
filtered_groups = self._process_group_items(
|
||||
trigger_data, allow_strength_adjustment
|
||||
)
|
||||
elif allow_strength_adjustment:
|
||||
parsed_items = [
|
||||
self._parse_trigger_item(
|
||||
item, allow_strength_adjustment
|
||||
@@ -174,6 +181,41 @@ class TriggerWordToggleLM:
|
||||
|
||||
return (filtered_triggers,)
|
||||
|
||||
def _process_group_items(self, trigger_data, allow_strength_adjustment):
|
||||
filtered_groups = []
|
||||
|
||||
for item in trigger_data:
|
||||
group = self._parse_trigger_item(item, allow_strength_adjustment)
|
||||
if not group["text"] or not group["active"]:
|
||||
continue
|
||||
|
||||
raw_items = item.get("items") if isinstance(item, dict) else None
|
||||
if isinstance(raw_items, list):
|
||||
active_items = []
|
||||
for raw_item in raw_items:
|
||||
child = self._parse_trigger_item(
|
||||
raw_item, allow_strength_adjustment=False
|
||||
)
|
||||
if child["text"] and child["active"]:
|
||||
active_items.append(child["text"])
|
||||
|
||||
if not active_items:
|
||||
continue
|
||||
|
||||
group_text = ", ".join(active_items)
|
||||
else:
|
||||
group_text = group["text"]
|
||||
|
||||
filtered_groups.append(
|
||||
self._format_word_output(
|
||||
group_text,
|
||||
group["strength"],
|
||||
allow_strength_adjustment,
|
||||
)
|
||||
)
|
||||
|
||||
return filtered_groups
|
||||
|
||||
def _parse_trigger_item(self, item, allow_strength_adjustment):
|
||||
text = (item.get("text") or "").strip()
|
||||
active = bool(item.get("active", False))
|
||||
|
||||
205
py/nodes/unet_loader.py
Normal file
205
py/nodes/unet_loader.py
Normal file
@@ -0,0 +1,205 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
import comfy.sd # type: ignore
|
||||
from ..utils.utils import get_checkpoint_info_absolute, _format_model_name_for_comfyui
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UNETLoaderLM:
|
||||
"""UNET Loader with support for extra folder paths
|
||||
|
||||
Loads diffusion models/UNets from both standard ComfyUI folders and LoRA Manager's
|
||||
extra folder paths, providing a unified interface for UNET loading.
|
||||
Supports both regular diffusion models and GGUF format models.
|
||||
"""
|
||||
|
||||
NAME = "Unet Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
# Get list of unet names from scanner (includes extra folder paths)
|
||||
unet_names = s._get_unet_names()
|
||||
return {
|
||||
"required": {
|
||||
"unet_name": (
|
||||
unet_names,
|
||||
{"tooltip": "The name of the diffusion model to load."},
|
||||
),
|
||||
"weight_dtype": (
|
||||
["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],
|
||||
{"tooltip": "The dtype to use for the model weights."},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
RETURN_NAMES = ("MODEL",)
|
||||
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",)
|
||||
FUNCTION = "load_unet"
|
||||
|
||||
@classmethod
|
||||
def _get_unet_names(cls) -> List[str]:
|
||||
"""Get list of diffusion model names from scanner cache in ComfyUI format (relative path with extension)"""
|
||||
try:
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
import asyncio
|
||||
|
||||
async def _get_names():
|
||||
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
# Get all model roots for calculating relative paths
|
||||
model_roots = scanner.get_model_roots()
|
||||
|
||||
# Filter only diffusion_model type and format names
|
||||
names = []
|
||||
for item in cache.raw_data:
|
||||
if item.get("sub_type") == "diffusion_model":
|
||||
file_path = item.get("file_path", "")
|
||||
if file_path:
|
||||
# Format using relative path with OS-native separator
|
||||
formatted_name = _format_model_name_for_comfyui(
|
||||
file_path, model_roots
|
||||
)
|
||||
if formatted_name:
|
||||
names.append(formatted_name)
|
||||
|
||||
return sorted(names)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(_get_names())
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
return future.result()
|
||||
except RuntimeError:
|
||||
return asyncio.run(_get_names())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting unet names: {e}")
|
||||
return []
|
||||
|
||||
def load_unet(self, unet_name: str, weight_dtype: str) -> Tuple:
|
||||
"""Load a diffusion model by name, supporting extra folder paths
|
||||
|
||||
Args:
|
||||
unet_name: The name of the diffusion model to load (relative path with extension)
|
||||
weight_dtype: The dtype to use for model weights
|
||||
|
||||
Returns:
|
||||
Tuple of (MODEL,)
|
||||
"""
|
||||
import torch
|
||||
|
||||
# Get absolute path from cache using ComfyUI-style name
|
||||
unet_path, metadata = get_checkpoint_info_absolute(unet_name)
|
||||
|
||||
if metadata is None:
|
||||
raise FileNotFoundError(
|
||||
f"Diffusion model '{unet_name}' not found in LoRA Manager cache. "
|
||||
"Make sure the model is indexed and try again."
|
||||
)
|
||||
|
||||
# Check if it's a GGUF model
|
||||
if unet_path.endswith(".gguf"):
|
||||
return self._load_gguf_unet(unet_path, unet_name, weight_dtype)
|
||||
|
||||
# Load regular diffusion model using ComfyUI's API
|
||||
logger.info(f"Loading diffusion model from: {unet_path}")
|
||||
|
||||
# Build model options based on weight_dtype
|
||||
model_options = {}
|
||||
if weight_dtype == "fp8_e4m3fn":
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
elif weight_dtype == "fp8_e4m3fn_fast":
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
model_options["fp8_optimizations"] = True
|
||||
elif weight_dtype == "fp8_e5m2":
|
||||
model_options["dtype"] = torch.float8_e5m2
|
||||
|
||||
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||
return (model,)
|
||||
|
||||
def _load_gguf_unet(
|
||||
self, unet_path: str, unet_name: str, weight_dtype: str
|
||||
) -> Tuple:
|
||||
"""Load a GGUF format diffusion model
|
||||
|
||||
Args:
|
||||
unet_path: Absolute path to the GGUF file
|
||||
unet_name: Name of the model for error messages
|
||||
weight_dtype: The dtype to use for model weights
|
||||
|
||||
Returns:
|
||||
Tuple of (MODEL,)
|
||||
"""
|
||||
import torch
|
||||
from .gguf_import_helper import get_gguf_modules
|
||||
|
||||
# Get ComfyUI-GGUF modules using helper (handles various import scenarios)
|
||||
try:
|
||||
loader_module, ops_module, nodes_module = get_gguf_modules()
|
||||
gguf_sd_loader = getattr(loader_module, "gguf_sd_loader")
|
||||
GGMLOps = getattr(ops_module, "GGMLOps")
|
||||
GGUFModelPatcher = getattr(nodes_module, "GGUFModelPatcher")
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(f"Cannot load GGUF model '{unet_name}'. {str(e)}")
|
||||
|
||||
logger.info(f"Loading GGUF diffusion model from: {unet_path}")
|
||||
|
||||
try:
|
||||
# Load GGUF state dict
|
||||
sd, extra = gguf_sd_loader(unet_path)
|
||||
|
||||
# Prepare kwargs for metadata if supported
|
||||
kwargs = {}
|
||||
import inspect
|
||||
|
||||
valid_params = inspect.signature(
|
||||
comfy.sd.load_diffusion_model_state_dict
|
||||
).parameters
|
||||
if "metadata" in valid_params:
|
||||
kwargs["metadata"] = extra.get("metadata", {})
|
||||
|
||||
# Setup custom operations with GGUF support
|
||||
ops = GGMLOps()
|
||||
|
||||
# Handle weight_dtype for GGUF models
|
||||
if weight_dtype in ("default", None):
|
||||
ops.Linear.dequant_dtype = None
|
||||
elif weight_dtype in ["target"]:
|
||||
ops.Linear.dequant_dtype = weight_dtype
|
||||
else:
|
||||
ops.Linear.dequant_dtype = getattr(torch, weight_dtype, None)
|
||||
|
||||
# Load the model
|
||||
model = comfy.sd.load_diffusion_model_state_dict(
|
||||
sd, model_options={"custom_operations": ops}, **kwargs
|
||||
)
|
||||
|
||||
if model is None:
|
||||
raise RuntimeError(
|
||||
f"Could not detect model type for GGUF diffusion model: {unet_path}"
|
||||
)
|
||||
|
||||
# Wrap with GGUFModelPatcher
|
||||
model = GGUFModelPatcher.clone(model)
|
||||
|
||||
return (model,)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading GGUF diffusion model '{unet_name}': {e}")
|
||||
raise RuntimeError(
|
||||
f"Failed to load GGUF diffusion model '{unet_name}': {str(e)}"
|
||||
)
|
||||
@@ -44,11 +44,29 @@ import folder_paths # type: ignore
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_lora_syntax_format():
|
||||
try:
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
return get_settings_manager().get("lora_syntax_format", "legacy")
|
||||
except Exception:
|
||||
return "legacy"
|
||||
|
||||
|
||||
def apply_lora_syntax_format(name):
|
||||
fmt = get_lora_syntax_format()
|
||||
if fmt == "legacy":
|
||||
return name.replace("\\", "/").rstrip("/").split("/")[-1]
|
||||
return name
|
||||
|
||||
|
||||
def extract_lora_name(lora_path):
|
||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||
# Get the basename without extension
|
||||
basename = os.path.basename(lora_path)
|
||||
return os.path.splitext(basename)[0]
|
||||
normalized = lora_path.replace("\\", "/")
|
||||
basename = os.path.basename(normalized)
|
||||
name_no_ext = os.path.splitext(basename)[0]
|
||||
dirname = os.path.dirname(normalized)
|
||||
if dirname and dirname not in (".", "/") and not normalized.startswith("/"):
|
||||
return apply_lora_syntax_format(f"{dirname}/{name_no_ext}")
|
||||
return apply_lora_syntax_format(name_no_ext)
|
||||
|
||||
|
||||
def get_loras_list(kwargs):
|
||||
@@ -158,3 +176,24 @@ def nunchaku_load_lora(model, lora_name, lora_strength):
|
||||
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
|
||||
|
||||
return ret_model
|
||||
|
||||
|
||||
def detect_nunchaku_model_kind(model):
|
||||
"""Return the supported Nunchaku model kind for a Comfy model, if any."""
|
||||
try:
|
||||
model_wrapper = model.model.diffusion_model
|
||||
except (AttributeError, TypeError):
|
||||
return None
|
||||
|
||||
wrapper_name = model_wrapper.__class__.__name__
|
||||
if wrapper_name == "ComfyFluxWrapper":
|
||||
return "flux"
|
||||
|
||||
inner_model = getattr(model_wrapper, "model", None)
|
||||
inner_name = inner_model.__class__.__name__ if inner_model is not None else ""
|
||||
if wrapper_name.endswith("NunchakuQwenImageTransformer2DModel"):
|
||||
return "qwen_image"
|
||||
if inner_name.endswith("NunchakuQwenImageTransformer2DModel"):
|
||||
return "qwen_image"
|
||||
|
||||
return None
|
||||
|
||||
@@ -1,10 +1,22 @@
|
||||
import folder_paths # type: ignore
|
||||
from ..utils.utils import get_lora_info
|
||||
import os
|
||||
from ..utils.utils import get_lora_info_absolute
|
||||
from ..config import config
|
||||
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _relpath_within_loras(abs_path):
|
||||
"""Return abs_path relative to the first matching lora root, or basename as fallback."""
|
||||
all_roots = list(config.loras_roots or []) + list(config.extra_loras_roots or [])
|
||||
for root in all_roots:
|
||||
try:
|
||||
return os.path.relpath(abs_path, root)
|
||||
except ValueError:
|
||||
continue
|
||||
return os.path.basename(abs_path)
|
||||
|
||||
class WanVideoLoraSelectLM:
|
||||
NAME = "WanVideo Lora Select (LoraManager)"
|
||||
CATEGORY = "Lora Manager/stackers"
|
||||
@@ -56,13 +68,13 @@ class WanVideoLoraSelectLM:
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
||||
|
||||
# Create lora item for WanVideo format
|
||||
lora_item = {
|
||||
"path": folder_paths.get_full_path("loras", lora_path),
|
||||
"path": lora_path,
|
||||
"strength": model_strength,
|
||||
"name": lora_path.split(".")[0],
|
||||
"name": os.path.splitext(_relpath_within_loras(lora_path))[0],
|
||||
"blocks": selected_blocks,
|
||||
"layer_filter": layer_filter,
|
||||
"low_mem_load": low_mem_load,
|
||||
|
||||
@@ -1,11 +1,23 @@
|
||||
import folder_paths # type: ignore
|
||||
from ..utils.utils import get_lora_info
|
||||
import os
|
||||
from ..utils.utils import get_lora_info_absolute
|
||||
from ..config import config
|
||||
from .utils import any_type
|
||||
import logging
|
||||
|
||||
# 初始化日志记录器
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _relpath_within_loras(abs_path):
|
||||
"""Return abs_path relative to the first matching lora root, or basename as fallback."""
|
||||
all_roots = list(config.loras_roots or []) + list(config.extra_loras_roots or [])
|
||||
for root in all_roots:
|
||||
try:
|
||||
return os.path.relpath(abs_path, root)
|
||||
except ValueError:
|
||||
continue
|
||||
return os.path.basename(abs_path)
|
||||
|
||||
# 定义新节点的类
|
||||
class WanVideoLoraTextSelectLM:
|
||||
# 节点在UI中显示的名称
|
||||
@@ -87,12 +99,12 @@ class WanVideoLoraTextSelectLM:
|
||||
else:
|
||||
continue
|
||||
|
||||
lora_path, trigger_words = get_lora_info(lora_name_raw)
|
||||
lora_path, trigger_words = get_lora_info_absolute(lora_name_raw)
|
||||
|
||||
lora_item = {
|
||||
"path": folder_paths.get_full_path("loras", lora_path),
|
||||
"path": lora_path,
|
||||
"strength": model_strength,
|
||||
"name": lora_path.split(".")[0],
|
||||
"name": os.path.splitext(_relpath_within_loras(lora_path))[0],
|
||||
"blocks": selected_blocks,
|
||||
"layer_filter": layer_filter,
|
||||
"low_mem_load": low_mem_load,
|
||||
|
||||
@@ -7,7 +7,7 @@ import re
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
from ..config import config
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
from ..utils.constants import VALID_LORA_TYPES, VALID_CHECKPOINT_SUB_TYPES
|
||||
from ..utils.civitai_utils import rewrite_preview_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -58,9 +58,52 @@ class RecipeMetadataParser(ABC):
|
||||
civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
||||
|
||||
if not civitai_info or error_msg == "Model not found":
|
||||
# Model not found or deleted
|
||||
lora_entry['isDeleted'] = True
|
||||
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'
|
||||
# CivitAI may fail to resolve a hash that is still being
|
||||
# computed (known CivitAI issue). Before marking as deleted,
|
||||
# try to reconcile with a local model that has the same
|
||||
# filename and matching AutoV3 hash.
|
||||
reconciled = False
|
||||
file_name = lora_entry.get("file_name")
|
||||
if file_name and recipe_scanner and hash_value:
|
||||
lora_scanner = getattr(recipe_scanner, "_lora_scanner", None)
|
||||
if lora_scanner:
|
||||
try:
|
||||
# Local import to avoid circular dependency:
|
||||
# base.py → file_utils → settings_manager → ...
|
||||
# → recipe_scanner → enrichment → base.py
|
||||
from ..utils.file_utils import calculate_autov3 # fmt: skip
|
||||
cache = await lora_scanner.get_cached_data()
|
||||
for item in getattr(cache, "raw_data", []):
|
||||
if item.get("file_name") == file_name:
|
||||
local_path = item.get("file_path")
|
||||
if local_path and os.path.exists(local_path):
|
||||
local_autov3 = calculate_autov3(local_path)
|
||||
if local_autov3 and local_autov3 == hash_value:
|
||||
lora_entry["existsLocally"] = True
|
||||
lora_entry["localPath"] = local_path
|
||||
lora_entry["hash"] = item.get("sha256", hash_value)
|
||||
if "preview_url" in item:
|
||||
lora_entry["thumbnailUrl"] = config.get_preview_static_url(item["preview_url"])
|
||||
civ = item.get("civitai") or {}
|
||||
if isinstance(civ, dict):
|
||||
if civ.get("id") is not None:
|
||||
lora_entry["id"] = civ["id"]
|
||||
if civ.get("modelId") is not None:
|
||||
lora_entry["modelId"] = civ["modelId"]
|
||||
if civ.get("name"):
|
||||
lora_entry["version"] = civ["name"]
|
||||
# model_name is the CivitAI model display
|
||||
# name stored directly in the cache column.
|
||||
cached_model_name = item.get("model_name")
|
||||
if cached_model_name:
|
||||
lora_entry["name"] = cached_model_name
|
||||
reconciled = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
if not reconciled:
|
||||
lora_entry['isDeleted'] = True
|
||||
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'
|
||||
return lora_entry
|
||||
|
||||
# Get model type and validate
|
||||
@@ -173,6 +216,20 @@ class RecipeMetadataParser(ABC):
|
||||
checkpoint['isDeleted'] = True
|
||||
return checkpoint
|
||||
|
||||
# Validate that the model type is actually a checkpoint.
|
||||
# Unlike populate_lora_from_civitai which has this check,
|
||||
# this function was missing type validation — allowing LoRA
|
||||
# version data to be saved as the recipe's checkpoint when the
|
||||
# wrong version ID was passed downstream (fixed in v2.7+).
|
||||
model_type = civitai_data.get('model', {}).get('type', '').lower()
|
||||
if model_type not in VALID_CHECKPOINT_SUB_TYPES:
|
||||
logger.warning(
|
||||
f"Cannot populate checkpoint: model version {civitai_data.get('id')} "
|
||||
f"has type '{model_type}', expected one of {VALID_CHECKPOINT_SUB_TYPES}. "
|
||||
f"Skipping checkpoint enrichment."
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
if 'model' in civitai_data and 'name' in civitai_data['model']:
|
||||
checkpoint['name'] = civitai_data['model']['name']
|
||||
|
||||
|
||||
@@ -13,4 +13,5 @@ GEN_PARAM_KEYS = [
|
||||
'seed',
|
||||
'size',
|
||||
'clip_skip',
|
||||
'denoising_strength',
|
||||
]
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import logging
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
from .merger import GenParamsMerger
|
||||
from .base import RecipeMetadataParser
|
||||
from ..services.metadata_service import get_default_metadata_provider
|
||||
from ..utils.civitai_utils import extract_civitai_image_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -16,54 +16,65 @@ class RecipeEnricher:
|
||||
async def enrich_recipe(
|
||||
recipe: Dict[str, Any],
|
||||
civitai_client: Any,
|
||||
request_params: Optional[Dict[str, Any]] = None
|
||||
request_params: Optional[Dict[str, Any]] = None,
|
||||
prefetched_civitai_meta_raw: Optional[Dict[str, Any]] = None,
|
||||
prefetched_model_version_id: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Enrich a recipe dictionary in-place with metadata from Civitai and embedded params.
|
||||
|
||||
|
||||
Args:
|
||||
recipe: The recipe dictionary to enrich. Must have 'gen_params' initialized.
|
||||
civitai_client: Authenticated Civitai client instance.
|
||||
request_params: (Optional) Parameters from a user request (e.g. import).
|
||||
|
||||
prefetched_civitai_meta_raw: (Optional) Pre-fetched raw meta from Civitai
|
||||
get_image_info, avoiding a duplicate API call.
|
||||
prefetched_model_version_id: (Optional) Pre-fetched model version ID.
|
||||
|
||||
Returns:
|
||||
bool: True if the recipe was modified, False otherwise.
|
||||
"""
|
||||
updated = False
|
||||
gen_params = recipe.get("gen_params", {})
|
||||
|
||||
# 1. Fetch Civitai Info if available
|
||||
|
||||
# 1. Obtain Civitai metadata
|
||||
civitai_meta = None
|
||||
model_version_id = None
|
||||
|
||||
source_url = recipe.get("source_url") or recipe.get("source_path", "")
|
||||
|
||||
# Check if it's a Civitai image URL
|
||||
image_id_match = re.search(r'civitai\.com/images/(\d+)', str(source_url))
|
||||
if image_id_match:
|
||||
image_id = image_id_match.group(1)
|
||||
try:
|
||||
image_info = await civitai_client.get_image_info(image_id)
|
||||
if image_info:
|
||||
# Handle nested meta often found in Civitai API responses
|
||||
raw_meta = image_info.get("meta")
|
||||
if isinstance(raw_meta, dict):
|
||||
if "meta" in raw_meta and isinstance(raw_meta["meta"], dict):
|
||||
civitai_meta = raw_meta["meta"]
|
||||
else:
|
||||
civitai_meta = raw_meta
|
||||
|
||||
model_version_id = image_info.get("modelVersionId")
|
||||
|
||||
# If not at top level, check resources in meta
|
||||
if not model_version_id and civitai_meta:
|
||||
resources = civitai_meta.get("civitaiResources", [])
|
||||
for res in resources:
|
||||
if res.get("type") == "checkpoint":
|
||||
model_version_id = res.get("modelVersionId")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch Civitai image info: {e}")
|
||||
model_version_id = prefetched_model_version_id
|
||||
|
||||
source_path = recipe.get("source_path", "")
|
||||
|
||||
if prefetched_civitai_meta_raw is not None:
|
||||
raw_meta = prefetched_civitai_meta_raw
|
||||
if isinstance(raw_meta, dict):
|
||||
if "meta" in raw_meta and isinstance(raw_meta["meta"], dict):
|
||||
civitai_meta = raw_meta["meta"]
|
||||
else:
|
||||
civitai_meta = raw_meta
|
||||
else:
|
||||
image_id = extract_civitai_image_id(str(source_path))
|
||||
if image_id:
|
||||
try:
|
||||
image_info = await civitai_client.get_image_info(
|
||||
image_id, source_url=str(source_path)
|
||||
)
|
||||
if image_info:
|
||||
raw_meta = image_info.get("meta")
|
||||
if isinstance(raw_meta, dict):
|
||||
if "meta" in raw_meta and isinstance(raw_meta["meta"], dict):
|
||||
civitai_meta = raw_meta["meta"]
|
||||
else:
|
||||
civitai_meta = raw_meta
|
||||
|
||||
model_version_id = image_info.get("modelVersionId")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch Civitai image info: {e}")
|
||||
|
||||
if not model_version_id and civitai_meta:
|
||||
resources = civitai_meta.get("civitaiResources", [])
|
||||
for res in resources:
|
||||
if res.get("type") == "checkpoint":
|
||||
model_version_id = res.get("modelVersionId")
|
||||
break
|
||||
|
||||
# 2. Merge Parameters
|
||||
# Priority: request_params > civitai_meta > embedded (existing gen_params)
|
||||
@@ -179,27 +190,42 @@ class RecipeEnricher:
|
||||
existing_cp = recipe.get("checkpoint")
|
||||
if existing_cp is None:
|
||||
existing_cp = {}
|
||||
|
||||
# Extract baseModel from raw civitai_info before populate_checkpoint_from_civitai
|
||||
# (populate may reject non-checkpoint types and lose this data)
|
||||
base_model_from_civitai: str = ""
|
||||
if isinstance(civitai_info, dict):
|
||||
base_model_from_civitai = civitai_info.get("baseModel", "") or ""
|
||||
elif isinstance(civitai_info, tuple) and len(civitai_info) > 0 and isinstance(civitai_info[0], dict):
|
||||
base_model_from_civitai = civitai_info[0].get("baseModel", "") or ""
|
||||
|
||||
checkpoint_data = await RecipeMetadataParser.populate_checkpoint_from_civitai(existing_cp, civitai_info)
|
||||
# 1. First, resolve base_model using full data before we format it away
|
||||
|
||||
# 1. Resolve base_model from checkpoint_data first, then fall back to raw civitai_info
|
||||
current_base_model = recipe.get("base_model")
|
||||
resolved_base_model = checkpoint_data.get("baseModel")
|
||||
resolved_base_model = checkpoint_data.get("baseModel") or base_model_from_civitai
|
||||
if resolved_base_model:
|
||||
# Update if empty OR if it matches our generic prefix but is less specific
|
||||
is_generic = not current_base_model or current_base_model.lower() in ["flux", "sdxl", "sd15"]
|
||||
if is_generic and resolved_base_model != current_base_model:
|
||||
recipe["base_model"] = resolved_base_model
|
||||
|
||||
# 2. Format according to requirements: type, modelId, modelVersionId, modelName, modelVersionName
|
||||
formatted_checkpoint = {
|
||||
"type": "checkpoint",
|
||||
"modelId": checkpoint_data.get("modelId"),
|
||||
"modelVersionId": checkpoint_data.get("id") or checkpoint_data.get("modelVersionId"),
|
||||
"modelName": checkpoint_data.get("name"), # In base.py, 'name' is populated from civitai_data['model']['name']
|
||||
"modelVersionName": checkpoint_data.get("version") # In base.py, 'version' is populated from civitai_data['name']
|
||||
}
|
||||
# Remove None values
|
||||
recipe["checkpoint"] = {k: v for k, v in formatted_checkpoint.items() if v is not None}
|
||||
|
||||
|
||||
# 2. Only format and save checkpoint if it has real data (not just type after type rejection)
|
||||
has_checkpoint_data = any([
|
||||
checkpoint_data.get("modelId"),
|
||||
checkpoint_data.get("id") or checkpoint_data.get("modelVersionId"),
|
||||
checkpoint_data.get("name"),
|
||||
checkpoint_data.get("version"),
|
||||
])
|
||||
if has_checkpoint_data:
|
||||
formatted_checkpoint = {
|
||||
"type": "checkpoint",
|
||||
"modelId": checkpoint_data.get("modelId"),
|
||||
"modelVersionId": checkpoint_data.get("id") or checkpoint_data.get("modelVersionId"),
|
||||
"modelName": checkpoint_data.get("name"),
|
||||
"modelVersionName": checkpoint_data.get("version"),
|
||||
}
|
||||
recipe["checkpoint"] = {k: v for k, v in formatted_checkpoint.items() if v is not None}
|
||||
|
||||
return True
|
||||
else:
|
||||
# Fallback to name extraction if we don't already have one
|
||||
|
||||
@@ -7,6 +7,7 @@ from .parsers import (
|
||||
MetaFormatParser,
|
||||
AutomaticMetadataParser,
|
||||
CivitaiApiMetadataParser,
|
||||
SuiImageParamsParser,
|
||||
)
|
||||
from .base import RecipeMetadataParser
|
||||
|
||||
@@ -55,6 +56,13 @@ class RecipeParserFactory:
|
||||
# If JSON parsing fails, move on to other parsers
|
||||
pass
|
||||
|
||||
# Try SuiImageParamsParser for SuiImage metadata format
|
||||
try:
|
||||
if SuiImageParamsParser().is_metadata_matching(metadata_str):
|
||||
return SuiImageParamsParser()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check other parsers that expect string input
|
||||
if RecipeFormatParser().is_metadata_matching(metadata_str):
|
||||
return RecipeFormatParser()
|
||||
|
||||
@@ -1,27 +1,33 @@
|
||||
from typing import Any, Dict, Optional
|
||||
import logging
|
||||
|
||||
from .constants import GEN_PARAM_KEYS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenParamsMerger:
|
||||
"""Utility to merge generation parameters from multiple sources with priority."""
|
||||
|
||||
ALLOWED_KEYS = set(GEN_PARAM_KEYS)
|
||||
|
||||
BLACKLISTED_KEYS = {
|
||||
"id", "url", "userId", "username", "createdAt", "updatedAt", "hash", "meta",
|
||||
"draft", "extra", "width", "height", "process", "quantity", "workflow",
|
||||
"baseModel", "resources", "disablePoi", "aspectRatio", "Created Date",
|
||||
"experimental", "civitaiResources", "civitai_resources", "Civitai resources",
|
||||
"modelVersionId", "modelId", "hashes", "Model", "Model hash", "checkpoint_hash",
|
||||
"checkpoint", "checksum", "model_checksum"
|
||||
"checkpoint", "checksum", "model_checksum", "raw_metadata",
|
||||
}
|
||||
|
||||
|
||||
NORMALIZATION_MAPPING = {
|
||||
# Civitai specific
|
||||
"cfg": "cfg_scale",
|
||||
"cfgScale": "cfg_scale",
|
||||
"clipSkip": "clip_skip",
|
||||
"negativePrompt": "negative_prompt",
|
||||
# Case variations
|
||||
"Sampler": "sampler",
|
||||
"sampler_name": "sampler",
|
||||
"scheduler": "sampler",
|
||||
"Steps": "steps",
|
||||
"Seed": "seed",
|
||||
"Size": "size",
|
||||
@@ -36,63 +42,40 @@ class GenParamsMerger:
|
||||
def merge(
|
||||
request_params: Optional[Dict[str, Any]] = None,
|
||||
civitai_meta: Optional[Dict[str, Any]] = None,
|
||||
embedded_metadata: Optional[Dict[str, Any]] = None
|
||||
embedded_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Merge generation parameters from three sources.
|
||||
|
||||
Priority: request_params > civitai_meta > embedded_metadata
|
||||
|
||||
Args:
|
||||
request_params: Params provided directly in the import request
|
||||
civitai_meta: Params from Civitai Image API 'meta' field
|
||||
embedded_metadata: Params extracted from image EXIF/embedded metadata
|
||||
|
||||
Returns:
|
||||
Merged parameters dictionary
|
||||
"""
|
||||
result = {}
|
||||
|
||||
# 1. Start with embedded metadata (lowest priority)
|
||||
Priority: request_params > civitai_meta > embedded_metadata
|
||||
"""
|
||||
result: Dict[str, Any] = {}
|
||||
|
||||
if embedded_metadata:
|
||||
# If it's a full recipe metadata, we use its gen_params
|
||||
if "gen_params" in embedded_metadata and isinstance(embedded_metadata["gen_params"], dict):
|
||||
if "gen_params" in embedded_metadata and isinstance(
|
||||
embedded_metadata["gen_params"], dict
|
||||
):
|
||||
GenParamsMerger._update_normalized(result, embedded_metadata["gen_params"])
|
||||
else:
|
||||
# Otherwise assume the dict itself contains gen_params
|
||||
GenParamsMerger._update_normalized(result, embedded_metadata)
|
||||
|
||||
# 2. Layer Civitai meta (medium priority)
|
||||
if civitai_meta:
|
||||
GenParamsMerger._update_normalized(result, civitai_meta)
|
||||
|
||||
# 3. Layer request params (highest priority)
|
||||
if request_params:
|
||||
GenParamsMerger._update_normalized(result, request_params)
|
||||
|
||||
# Filter out blacklisted keys and also the original camelCase keys if they were normalized
|
||||
final_result = {}
|
||||
for k, v in result.items():
|
||||
if k in GenParamsMerger.BLACKLISTED_KEYS:
|
||||
continue
|
||||
if k in GenParamsMerger.NORMALIZATION_MAPPING:
|
||||
continue
|
||||
final_result[k] = v
|
||||
|
||||
return final_result
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _update_normalized(target: Dict[str, Any], source: Dict[str, Any]) -> None:
|
||||
"""Update target dict with normalized keys from source."""
|
||||
for k, v in source.items():
|
||||
normalized_key = GenParamsMerger.NORMALIZATION_MAPPING.get(k, k)
|
||||
target[normalized_key] = v
|
||||
# Also keep the original key for now if it's not the same,
|
||||
# so we can filter at the end or avoid losing it if it wasn't supposed to be renamed?
|
||||
# Actually, if we rename it, we should probably NOT keep both in 'target'
|
||||
# because we want to filter them out at the end anyway.
|
||||
if normalized_key != k:
|
||||
# If we are overwriting an existing snake_case key with a camelCase one's value,
|
||||
# that's fine because of the priority order of calls to _update_normalized.
|
||||
pass
|
||||
target[k] = v
|
||||
"""Update target dict with normalized, persistence-safe keys from source."""
|
||||
for key, value in source.items():
|
||||
if key in GenParamsMerger.BLACKLISTED_KEYS:
|
||||
continue
|
||||
|
||||
normalized_key = GenParamsMerger.NORMALIZATION_MAPPING.get(key, key)
|
||||
if normalized_key not in GenParamsMerger.ALLOWED_KEYS:
|
||||
continue
|
||||
|
||||
target[normalized_key] = value
|
||||
|
||||
@@ -5,6 +5,7 @@ from .comfy import ComfyMetadataParser
|
||||
from .meta_format import MetaFormatParser
|
||||
from .automatic import AutomaticMetadataParser
|
||||
from .civitai_image import CivitaiApiMetadataParser
|
||||
from .sui_image_params import SuiImageParamsParser
|
||||
|
||||
__all__ = [
|
||||
'RecipeFormatParser',
|
||||
@@ -12,4 +13,5 @@ __all__ = [
|
||||
'MetaFormatParser',
|
||||
'AutomaticMetadataParser',
|
||||
'CivitaiApiMetadataParser',
|
||||
'SuiImageParamsParser',
|
||||
]
|
||||
|
||||
@@ -123,24 +123,39 @@ class AutomaticMetadataParser(RecipeMetadataParser):
|
||||
if model_hash_from_hashes:
|
||||
metadata["model_hash"] = model_hash_from_hashes
|
||||
|
||||
# Extract Lora hashes in alternative format
|
||||
# Extract Lora hashes in alternative format.
|
||||
# Run unconditionally (not just as fallback) so that
|
||||
# non-empty hashes from Lora hashes fill in the gaps left
|
||||
# by empty values in the Hashes JSON dict. Some WebUI
|
||||
# builds write real hash values only to Lora hashes and
|
||||
# leave the Hashes JSON values empty.
|
||||
lora_hashes_match = re.search(self.LORA_HASHES_REGEX, params_section)
|
||||
if not hashes_match and lora_hashes_match:
|
||||
if lora_hashes_match:
|
||||
try:
|
||||
lora_hashes_str = lora_hashes_match.group(1)
|
||||
lora_hash_entries = lora_hashes_str.split(', ')
|
||||
|
||||
# Initialize hashes dict if it doesn't exist
|
||||
if "hashes" not in metadata:
|
||||
metadata["hashes"] = {}
|
||||
|
||||
|
||||
# Parse each lora hash entry (format: "name: hash")
|
||||
for entry in lora_hash_entries:
|
||||
if ': ' in entry:
|
||||
lora_name, lora_hash = entry.split(': ', 1)
|
||||
# Add as lora type in the same format as regular hashes
|
||||
metadata["hashes"][f"lora:{lora_name}"] = lora_hash.strip()
|
||||
|
||||
lora_hash = lora_hash.strip()
|
||||
if not lora_hash:
|
||||
# Skip entries without a hash value
|
||||
continue
|
||||
# Initialize hashes dict if it doesn't exist
|
||||
if "hashes" not in metadata:
|
||||
metadata["hashes"] = {}
|
||||
# Add as lora type in the same format as
|
||||
# regular hashes. Only override an
|
||||
# existing entry if its value is empty
|
||||
# (Lora hashes is the more reliable
|
||||
# source when Hashes JSON has blanks).
|
||||
key = f"lora:{lora_name}"
|
||||
existing = metadata["hashes"].get(key, "")
|
||||
if not existing:
|
||||
metadata["hashes"][key] = lora_hash
|
||||
|
||||
# Remove lora hashes from params section
|
||||
params_section = params_section.replace(lora_hashes_match.group(0), '')
|
||||
except Exception as e:
|
||||
@@ -362,6 +377,12 @@ class AutomaticMetadataParser(RecipeMetadataParser):
|
||||
# Only process lora or hypernet types
|
||||
if not hash_key.startswith(("lora:", "hypernet:")):
|
||||
continue
|
||||
|
||||
# Skip entries without a hash value — they can't be
|
||||
# resolved via CivitAI and would only produce a
|
||||
# useless "Deleted" entry in the recipe.
|
||||
if not lora_hash:
|
||||
continue
|
||||
|
||||
lora_type, lora_name = hash_key.split(':', 1)
|
||||
|
||||
@@ -387,11 +408,7 @@ class AutomaticMetadataParser(RecipeMetadataParser):
|
||||
# Try to get info from Civitai
|
||||
if metadata_provider:
|
||||
try:
|
||||
if lora_hash:
|
||||
# If we have hash, use it for lookup
|
||||
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
|
||||
else:
|
||||
civitai_info = None
|
||||
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Dict, Any, Union
|
||||
from ..base import RecipeMetadataParser
|
||||
from ..constants import GEN_PARAM_KEYS
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
from ...config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -42,6 +43,7 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
"height",
|
||||
"Model",
|
||||
"Model hash",
|
||||
"modelVersionIds",
|
||||
)
|
||||
return any(key in payload for key in civitai_image_fields)
|
||||
|
||||
@@ -72,7 +74,8 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
return False
|
||||
|
||||
async def parse_metadata( # type: ignore[override]
|
||||
self, user_comment, recipe_scanner=None, civitai_client=None
|
||||
self, user_comment, recipe_scanner=None, civitai_client=None,
|
||||
local_cache: dict[str, Any] | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Parse metadata from Civitai image format
|
||||
|
||||
@@ -80,6 +83,8 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
user_comment: The metadata from the image (dict)
|
||||
recipe_scanner: Optional recipe scanner service
|
||||
civitai_client: Optional Civitai API client (deprecated, use metadata_provider instead)
|
||||
local_cache: Optional dict mapping sha256/autov3 hash → scanner cache item.
|
||||
When provided, matching models skip CivitAI API calls.
|
||||
|
||||
Returns:
|
||||
Dict containing parsed recipe data
|
||||
@@ -184,8 +189,77 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
# Process standard resources array
|
||||
if "resources" in metadata and isinstance(metadata["resources"], list):
|
||||
for resource in metadata["resources"]:
|
||||
resource_type = resource.get("type", "lora")
|
||||
|
||||
# Track resources with type "model" — these are checkpoint models.
|
||||
# The resources array is the most reliable source for checkpoint
|
||||
# identification because it has an explicit type field and hash,
|
||||
# unlike modelVersionIds which is a flat list with no type info.
|
||||
if resource_type == "model":
|
||||
checkpoint_entry = {
|
||||
"id": 0,
|
||||
"modelId": 0,
|
||||
"name": resource.get("name", "Unknown Model"),
|
||||
"version": "",
|
||||
"type": resource.get("type", "model"),
|
||||
"existsLocally": False,
|
||||
"localPath": None,
|
||||
"file_name": resource.get("name", ""),
|
||||
"hash": resource.get("hash", "") or "",
|
||||
"thumbnailUrl": "/loras_static/images/no-preview.png",
|
||||
"baseModel": "",
|
||||
"size": 0,
|
||||
"downloadUrl": "",
|
||||
"isDeleted": False,
|
||||
}
|
||||
|
||||
# Try to look up base model from the checkpoint hash
|
||||
cp_hash = checkpoint_entry.get("hash")
|
||||
if cp_hash and metadata_provider:
|
||||
local_cached = local_cache.get(cp_hash) if local_cache else None
|
||||
if local_cached:
|
||||
self._populate_entry_from_cache(
|
||||
checkpoint_entry, local_cached
|
||||
)
|
||||
bm = checkpoint_entry.get("baseModel", "")
|
||||
if bm and not result["base_model"]:
|
||||
result["base_model"] = bm
|
||||
else:
|
||||
try:
|
||||
civitai_info = (
|
||||
await metadata_provider.get_model_by_hash(
|
||||
cp_hash
|
||||
)
|
||||
)
|
||||
civitai_data, error_msg = (
|
||||
(civitai_info, None)
|
||||
if not isinstance(civitai_info, tuple)
|
||||
else civitai_info
|
||||
)
|
||||
if civitai_data and error_msg != "Model not found":
|
||||
if 'model' in civitai_data and 'name' in civitai_data['model']:
|
||||
checkpoint_entry['name'] = civitai_data['model']['name']
|
||||
checkpoint_entry['id'] = civitai_data.get('id', 0)
|
||||
checkpoint_entry['modelId'] = civitai_data.get('modelId', 0)
|
||||
if 'name' in civitai_data:
|
||||
checkpoint_entry['version'] = civitai_data['name']
|
||||
base_model = civitai_data.get('baseModel', '')
|
||||
if base_model:
|
||||
checkpoint_entry['baseModel'] = base_model
|
||||
if not result['base_model']:
|
||||
result['base_model'] = base_model
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error fetching checkpoint info for hash "
|
||||
f"{cp_hash}: {e}"
|
||||
)
|
||||
|
||||
if result["model"] is None:
|
||||
result["model"] = checkpoint_entry
|
||||
continue
|
||||
|
||||
# Modified to process resources without a type field as potential LoRAs
|
||||
if resource.get("type", "lora") == "lora":
|
||||
if resource_type == "lora":
|
||||
lora_hash = resource.get("hash", "")
|
||||
|
||||
# Try to get hash from the hashes field if not present in resource
|
||||
@@ -219,34 +293,45 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
}
|
||||
|
||||
# Try to get info from Civitai if hash is available
|
||||
if lora_entry["hash"] and metadata_provider:
|
||||
try:
|
||||
civitai_info = (
|
||||
await metadata_provider.get_model_by_hash(lora_hash)
|
||||
if lora_hash and metadata_provider:
|
||||
local_cached = local_cache.get(lora_hash) if local_cache else None
|
||||
if local_cached:
|
||||
self._populate_entry_from_cache(
|
||||
lora_entry, local_cached
|
||||
)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts,
|
||||
lora_hash,
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
# If we have a version ID from Civitai, track it for deduplication
|
||||
if "id" in lora_entry and lora_entry["id"]:
|
||||
# Track by version ID for deduplication
|
||||
if lora_entry.get("id"):
|
||||
added_loras[str(lora_entry["id"])] = len(
|
||||
result["loras"]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
civitai_info = (
|
||||
await metadata_provider.get_model_by_hash(lora_hash)
|
||||
)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts,
|
||||
lora_hash,
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
# If we have a version ID from Civitai, track it for deduplication
|
||||
if "id" in lora_entry and lora_entry["id"]:
|
||||
added_loras[str(lora_entry["id"])] = len(
|
||||
result["loras"]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}"
|
||||
)
|
||||
|
||||
# Track by hash if we have it
|
||||
if lora_hash:
|
||||
@@ -429,6 +514,113 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# Process modelVersionIds from Civitai image API.
|
||||
# These are version IDs returned at root level of the API response.
|
||||
# When resources or civitaiResources are already present in metadata
|
||||
# (which they are when ?withMeta=true is passed), those sections have
|
||||
# complete hash/type information — modelVersionIds is a fallback for
|
||||
# when meta is null and only the flat ID list is available. Skipping
|
||||
# it here avoids duplicates: the same file hash often resolves to
|
||||
# different version IDs via hash lookup (resources) vs the original
|
||||
# version ID in modelVersionIds, and both paths would create entries.
|
||||
if (
|
||||
"modelVersionIds" in metadata
|
||||
and isinstance(metadata["modelVersionIds"], list)
|
||||
and not result.get("loras")
|
||||
):
|
||||
|
||||
for version_id in metadata["modelVersionIds"]:
|
||||
version_id_str = str(version_id)
|
||||
|
||||
# Skip if we've already added this LoRA by version ID
|
||||
if version_id_str in added_loras:
|
||||
continue
|
||||
|
||||
# Skip if this version ID is already the recipe's checkpoint
|
||||
# (resolved earlier from embedded resources/Model hash,
|
||||
# avoiding a duplicate CivitAI API call).
|
||||
existing_model = result.get("model")
|
||||
if existing_model and str(existing_model.get("id")) == version_id_str:
|
||||
continue
|
||||
|
||||
# Initialize lora entry with version ID
|
||||
lora_entry = {
|
||||
"id": version_id,
|
||||
"modelId": 0,
|
||||
"name": "Unknown LoRA",
|
||||
"version": "",
|
||||
"type": "lora",
|
||||
"weight": 1.0,
|
||||
"existsLocally": False,
|
||||
"thumbnailUrl": "/loras_static/images/no-preview.png",
|
||||
"baseModel": "",
|
||||
"size": 0,
|
||||
"downloadUrl": "",
|
||||
"isDeleted": False,
|
||||
}
|
||||
|
||||
# Fetch model info from Civitai
|
||||
if metadata_provider and version_id_str:
|
||||
try:
|
||||
civitai_info = (
|
||||
await metadata_provider.get_model_version_info(
|
||||
version_id_str
|
||||
)
|
||||
)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts,
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
# Not a LoRA — try as checkpoint (only if we
|
||||
# don't already have one). Reuses the same
|
||||
# civitai_info from the API call above so no
|
||||
# extra query is made.
|
||||
if result["model"] is None:
|
||||
checkpoint_entry = {
|
||||
"id": version_id,
|
||||
"modelId": 0,
|
||||
"name": "Unknown Model",
|
||||
"version": "",
|
||||
"type": "checkpoint",
|
||||
"existsLocally": False,
|
||||
"localPath": None,
|
||||
"file_name": "",
|
||||
"hash": "",
|
||||
"thumbnailUrl": (
|
||||
"/loras_static/images/no-preview.png"
|
||||
),
|
||||
"baseModel": "",
|
||||
"size": 0,
|
||||
"downloadUrl": "",
|
||||
"isDeleted": False,
|
||||
}
|
||||
cp_populated = await (
|
||||
self.populate_checkpoint_from_civitai(
|
||||
checkpoint_entry, civitai_info
|
||||
)
|
||||
)
|
||||
if cp_populated.get("modelId"):
|
||||
result["model"] = cp_populated
|
||||
continue # Not a LoRA, don't add to loras
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error fetching Civitai info for model version {version_id}: {e}"
|
||||
)
|
||||
|
||||
# Track this LoRA for deduplication
|
||||
if version_id_str:
|
||||
added_loras[version_id_str] = len(result["loras"])
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# If we found LoRA hashes in the metadata but haven't already
|
||||
# populated entries for them, fall back to creating LoRAs from
|
||||
# the hashes section. Some Civitai image responses only include
|
||||
@@ -565,3 +757,41 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing Civitai image metadata: {e}", exc_info=True)
|
||||
return {"error": str(e), "loras": []}
|
||||
|
||||
@staticmethod
|
||||
def _populate_entry_from_cache(
|
||||
entry: dict[str, Any],
|
||||
cache_item: dict[str, Any],
|
||||
) -> None:
|
||||
"""Fill a lora/checkpoint entry from a scanner cache item.
|
||||
|
||||
Avoids CivitAI API calls for models that exist locally.
|
||||
Mirrors the population logic in
|
||||
``RecipeMetadataParser.populate_lora_from_civitai()`` but operates
|
||||
entirely on cached data.
|
||||
"""
|
||||
civ = cache_item.get("civitai") or {}
|
||||
if isinstance(civ, dict):
|
||||
if civ.get("id") is not None:
|
||||
entry["id"] = civ["id"]
|
||||
if civ.get("modelId") is not None:
|
||||
entry["modelId"] = civ["modelId"]
|
||||
if civ.get("name"):
|
||||
entry["version"] = civ["name"]
|
||||
cached_name = cache_item.get("model_name")
|
||||
if cached_name:
|
||||
entry["name"] = cached_name
|
||||
entry["existsLocally"] = True
|
||||
local_path = cache_item.get("file_path")
|
||||
if local_path:
|
||||
entry["localPath"] = local_path
|
||||
sha256 = cache_item.get("sha256")
|
||||
if sha256:
|
||||
entry["hash"] = sha256
|
||||
if "preview_url" in cache_item:
|
||||
entry["thumbnailUrl"] = config.get_preview_static_url(
|
||||
cache_item["preview_url"]
|
||||
)
|
||||
base_model = cache_item.get("base_model", "")
|
||||
if base_model:
|
||||
entry["baseModel"] = base_model
|
||||
|
||||
188
py/recipes/parsers/sui_image_params.py
Normal file
188
py/recipes/parsers/sui_image_params.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""Parser for SuiImage (Stable Diffusion WebUI) metadata format."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
from ..base import RecipeMetadataParser
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SuiImageParamsParser(RecipeMetadataParser):
|
||||
"""Parser for SuiImage metadata JSON format.
|
||||
|
||||
This format is used by some Stable Diffusion WebUI variants.
|
||||
Structure:
|
||||
{
|
||||
"sui_image_params": {
|
||||
"prompt": "...",
|
||||
"negativeprompt": "...",
|
||||
"model": "...",
|
||||
"seed": ...,
|
||||
"steps": ...,
|
||||
...
|
||||
},
|
||||
"sui_models": [
|
||||
{"name": "...", "param": "model", "hash": "..."},
|
||||
...
|
||||
],
|
||||
"sui_extra_data": {...}
|
||||
}
|
||||
"""
|
||||
|
||||
def is_metadata_matching(self, user_comment: str) -> bool:
|
||||
"""Check if the user comment matches the SuiImage metadata format"""
|
||||
try:
|
||||
data = json.loads(user_comment)
|
||||
return isinstance(data, dict) and 'sui_image_params' in data
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return False
|
||||
|
||||
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""Parse metadata from SuiImage metadata format"""
|
||||
try:
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
data = json.loads(user_comment)
|
||||
params = data.get('sui_image_params', {})
|
||||
models = data.get('sui_models', [])
|
||||
|
||||
# Extract prompt and negative prompt
|
||||
prompt = params.get('prompt', '')
|
||||
negative_prompt = params.get('negativeprompt', '') or params.get('negative_prompt', '')
|
||||
|
||||
# Extract generation parameters
|
||||
gen_params = {}
|
||||
if prompt:
|
||||
gen_params['prompt'] = prompt
|
||||
if negative_prompt:
|
||||
gen_params['negative_prompt'] = negative_prompt
|
||||
|
||||
# Map standard parameters
|
||||
param_mapping = {
|
||||
'steps': 'steps',
|
||||
'seed': 'seed',
|
||||
'cfgscale': 'cfg_scale',
|
||||
'cfg_scale': 'cfg_scale',
|
||||
'width': 'width',
|
||||
'height': 'height',
|
||||
'sampler': 'sampler',
|
||||
'scheduler': 'scheduler',
|
||||
'model': 'model',
|
||||
'vae': 'vae',
|
||||
}
|
||||
|
||||
for src_key, dest_key in param_mapping.items():
|
||||
if src_key in params and params[src_key] is not None:
|
||||
gen_params[dest_key] = params[src_key]
|
||||
|
||||
# Add size info if available
|
||||
if 'width' in gen_params and 'height' in gen_params:
|
||||
gen_params['size'] = f"{gen_params['width']}x{gen_params['height']}"
|
||||
|
||||
# Process models - extract checkpoint and loras
|
||||
loras: List[Dict[str, Any]] = []
|
||||
checkpoint: Optional[Dict[str, Any]] = None
|
||||
|
||||
for model in models:
|
||||
model_name = model.get('name', '')
|
||||
param_type = model.get('param', '')
|
||||
model_hash = model.get('hash', '')
|
||||
|
||||
# Remove .safetensors extension for cleaner name
|
||||
clean_name = model_name.replace('.safetensors', '') if model_name else ''
|
||||
|
||||
# Check if this is a LoRA by looking at the name or param type
|
||||
is_lora = 'lora' in model_name.lower() or param_type.lower().startswith('lora')
|
||||
|
||||
if is_lora:
|
||||
lora_entry = {
|
||||
'id': 0,
|
||||
'modelId': 0,
|
||||
'name': clean_name,
|
||||
'version': '',
|
||||
'type': 'lora',
|
||||
'weight': 1.0,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': model_name,
|
||||
'hash': model_hash.replace('0x', '') if model_hash.startswith('0x') else model_hash,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Try to get additional info from metadata provider
|
||||
if metadata_provider and model_hash:
|
||||
try:
|
||||
civitai_info = await metadata_provider.get_model_by_hash(
|
||||
model_hash.replace('0x', '') if model_hash.startswith('0x') else model_hash
|
||||
)
|
||||
if civitai_info:
|
||||
lora_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry, civitai_info, recipe_scanner
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error fetching info for LoRA {clean_name}: {e}")
|
||||
|
||||
if lora_entry:
|
||||
loras.append(lora_entry)
|
||||
elif param_type == 'model' or 'lora' not in model_name.lower():
|
||||
# This is likely a checkpoint
|
||||
checkpoint_entry = {
|
||||
'id': 0,
|
||||
'modelId': 0,
|
||||
'name': clean_name,
|
||||
'version': '',
|
||||
'type': 'checkpoint',
|
||||
'hash': model_hash.replace('0x', '') if model_hash.startswith('0x') else model_hash,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': model_name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Try to get additional info from metadata provider
|
||||
if metadata_provider and model_hash:
|
||||
try:
|
||||
civitai_info = await metadata_provider.get_model_by_hash(
|
||||
model_hash.replace('0x', '') if model_hash.startswith('0x') else model_hash
|
||||
)
|
||||
if civitai_info:
|
||||
checkpoint_entry = await self.populate_checkpoint_from_civitai(
|
||||
checkpoint_entry, civitai_info
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error fetching info for checkpoint {clean_name}: {e}")
|
||||
|
||||
checkpoint = checkpoint_entry
|
||||
|
||||
# Determine base model from loras or checkpoint
|
||||
base_model = None
|
||||
if loras:
|
||||
base_models = [lora.get('baseModel') for lora in loras if lora.get('baseModel')]
|
||||
if base_models:
|
||||
from collections import Counter
|
||||
base_model_counts = Counter(base_models)
|
||||
base_model = base_model_counts.most_common(1)[0][0]
|
||||
elif checkpoint and checkpoint.get('baseModel'):
|
||||
base_model = checkpoint['baseModel']
|
||||
|
||||
return {
|
||||
'base_model': base_model,
|
||||
'loras': loras,
|
||||
'checkpoint': checkpoint,
|
||||
'gen_params': gen_params,
|
||||
'from_sui_image_params': True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing SuiImage metadata: {e}", exc_info=True)
|
||||
return {"error": str(e), "loras": []}
|
||||
@@ -251,7 +251,7 @@ class BaseModelRoutes(ABC):
|
||||
|
||||
def _find_model_file(self, files):
|
||||
"""Find the appropriate model file from the files list - can be overridden by subclasses."""
|
||||
return next((file for file in files if file.get("type") == "Model" and file.get("primary") is True), None)
|
||||
return next((file for file in files if file.get("type") in ("Model", "Diffusion Model") and file.get("primary") is True), None)
|
||||
|
||||
def get_handler(self, name: str) -> Callable[[web.Request], web.StreamResponse]:
|
||||
"""Expose handlers for subclasses or tests."""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Base infrastructure shared across recipe routes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
@@ -16,12 +17,14 @@ from ..services.recipes import (
|
||||
RecipePersistenceService,
|
||||
RecipeSharingService,
|
||||
)
|
||||
from ..services.batch_import_service import BatchImportService
|
||||
from ..services.server_i18n import server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from .handlers.recipe_handlers import (
|
||||
BatchImportHandler,
|
||||
RecipeAnalysisHandler,
|
||||
RecipeHandlerSet,
|
||||
RecipeListingHandler,
|
||||
@@ -116,7 +119,10 @@ class BaseRecipeRoutes:
|
||||
recipe_scanner_getter = lambda: self.recipe_scanner
|
||||
civitai_client_getter = lambda: self.civitai_client
|
||||
|
||||
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
standalone_mode = (
|
||||
os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
|
||||
or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
)
|
||||
if not standalone_mode:
|
||||
from ..metadata_collector import get_metadata # type: ignore[import-not-found]
|
||||
from ..metadata_collector.metadata_processor import ( # type: ignore[import-not-found]
|
||||
@@ -190,6 +196,22 @@ class BaseRecipeRoutes:
|
||||
sharing_service=sharing_service,
|
||||
)
|
||||
|
||||
from ..services.websocket_manager import ws_manager
|
||||
|
||||
batch_import_service = BatchImportService(
|
||||
analysis_service=analysis_service,
|
||||
persistence_service=persistence_service,
|
||||
ws_manager=ws_manager,
|
||||
logger=logger,
|
||||
)
|
||||
batch_import = BatchImportHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
logger=logger,
|
||||
batch_import_service=batch_import_service,
|
||||
)
|
||||
|
||||
return RecipeHandlerSet(
|
||||
page_view=page_view,
|
||||
listing=listing,
|
||||
@@ -197,4 +219,5 @@ class BaseRecipeRoutes:
|
||||
management=management,
|
||||
analysis=analysis,
|
||||
sharing=sharing,
|
||||
batch_import=batch_import,
|
||||
)
|
||||
|
||||
167
py/routes/handlers/agent_handlers.py
Normal file
167
py/routes/handlers/agent_handlers.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""HTTP route handlers for agent skill endpoints.
|
||||
|
||||
These handlers expose the :class:`AgentService` via HTTP, allowing the
|
||||
frontend to list available skills and execute them on selected models.
|
||||
Progress is reported via WebSocket broadcast.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ...services.agent import AgentService, AgentProgressReporter
|
||||
from ...services.llm_service import LLMNotConfiguredError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentHandler:
|
||||
"""HTTP handler for agent skill operations."""
|
||||
|
||||
def __init__(self, agent_service: AgentService | None = None) -> None:
|
||||
self._agent_service = agent_service
|
||||
|
||||
async def _ensure_service(self) -> AgentService:
|
||||
if self._agent_service is None:
|
||||
self._agent_service = await AgentService.get_instance()
|
||||
return self._agent_service
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# GET /api/lm/agent/skills
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_agent_skills(self, request: web.Request) -> web.Response:
|
||||
"""Return a list of available agent skills."""
|
||||
|
||||
service = await self._ensure_service()
|
||||
skills = await service.list_skills()
|
||||
return web.json_response({"skills": skills})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# POST /api/lm/agent/execute/{skill_name}
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def execute_agent_skill(self, request: web.Request) -> web.Response:
|
||||
"""Execute an agent skill on the provided model paths.
|
||||
|
||||
Request body::
|
||||
|
||||
{"model_paths": ["/path/to/model1.safetensors", ...], "options": {}}
|
||||
|
||||
Returns immediately with a task ID. Execution runs in the
|
||||
background; progress and completion are pushed via WebSocket
|
||||
events of type ``agent_progress``.
|
||||
"""
|
||||
|
||||
skill_name = request.match_info.get("skill_name", "")
|
||||
if not skill_name:
|
||||
return web.json_response(
|
||||
{"error": "Skill name is required"}, status_code=400
|
||||
)
|
||||
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return web.json_response(
|
||||
{"error": "Invalid JSON body"}, status_code=400
|
||||
)
|
||||
|
||||
model_paths = body.get("model_paths", [])
|
||||
if not model_paths or not isinstance(model_paths, list):
|
||||
return web.json_response(
|
||||
{"error": "model_paths must be a non-empty array"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
service = await self._ensure_service()
|
||||
|
||||
# Validate LLM configuration early for skills that need it
|
||||
# (fail fast rather than after starting background work)
|
||||
try:
|
||||
from ...services.llm_service import LLMService
|
||||
|
||||
llm = await LLMService.get_instance()
|
||||
if not llm.is_configured():
|
||||
return web.json_response(
|
||||
{
|
||||
"error": "LLM provider is not configured. "
|
||||
"Enable it in Settings → AI Provider.",
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to check LLM configuration: %s", exc)
|
||||
|
||||
# Launch execution in the background
|
||||
progress_reporter = AgentProgressReporter()
|
||||
logger.info(
|
||||
"Agent skill '%s' starting for %d model(s) in background task",
|
||||
skill_name, len(model_paths),
|
||||
)
|
||||
|
||||
async def _run() -> None:
|
||||
logger.info("_run background task started for skill '%s'", skill_name)
|
||||
try:
|
||||
result = await service.execute_skill(
|
||||
skill_name=skill_name,
|
||||
input_data={"model_paths": model_paths},
|
||||
progress_callback=progress_reporter,
|
||||
)
|
||||
logger.info(
|
||||
"Agent skill '%s' finished: success=%s, summary='%s', errors=%s",
|
||||
skill_name, result.success, result.summary, result.errors,
|
||||
)
|
||||
except LLMNotConfiguredError as exc:
|
||||
logger.warning("Agent skill '%s' not configured: %s", skill_name, exc)
|
||||
await progress_reporter.on_progress(
|
||||
{
|
||||
"type": "agent_progress",
|
||||
"skill": skill_name,
|
||||
"status": "error",
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Agent skill '%s' failed: %s", skill_name, exc, exc_info=True)
|
||||
await progress_reporter.on_progress(
|
||||
{
|
||||
"type": "agent_progress",
|
||||
"skill": skill_name,
|
||||
"status": "error",
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
|
||||
# Fire and forget — progress comes via WebSocket
|
||||
task = asyncio.create_task(_run())
|
||||
logger.info("Agent skill '%s' background task created (id=%s)", skill_name, task)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"status": "started",
|
||||
"skill": skill_name,
|
||||
"model_count": len(model_paths),
|
||||
}
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# POST /api/lm/agent/cancel
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def cancel_agent_skill(self, request: web.Request) -> web.Response:
|
||||
"""Cancel a running agent skill.
|
||||
|
||||
NOTE: Cancellation is a stub for now — the AgentService processes
|
||||
models sequentially and does not yet support mid-execution
|
||||
cancellation. This endpoint exists for API completeness.
|
||||
"""
|
||||
|
||||
# TODO: implement cooperative cancellation in AgentService
|
||||
return web.json_response(
|
||||
{"status": "acknowledged", "note": "Cancellation not yet implemented"},
|
||||
status_code=200,
|
||||
)
|
||||
141
py/routes/handlers/base_model_handlers.py
Normal file
141
py/routes/handlers/base_model_handlers.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Handlers for base model related endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Dict
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ...services.civitai_base_model_service import get_civitai_base_model_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseModelHandlerSet:
|
||||
"""Collection of handlers for base model operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_model_service_factory: Callable[[], Any] = get_civitai_base_model_service,
|
||||
) -> None:
|
||||
self._base_model_service_factory = base_model_service_factory
|
||||
|
||||
def to_route_mapping(
|
||||
self,
|
||||
) -> Dict[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
|
||||
"""Return mapping of route names to handler methods."""
|
||||
return {
|
||||
"get_base_models": self.get_base_models,
|
||||
"refresh_base_models": self.refresh_base_models,
|
||||
"get_base_model_categories": self.get_base_model_categories,
|
||||
"get_base_model_cache_status": self.get_base_model_cache_status,
|
||||
}
|
||||
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
"""Get merged base models (hardcoded + remote from Civitai).
|
||||
|
||||
Query Parameters:
|
||||
refresh: If 'true', force refresh from API
|
||||
|
||||
Returns:
|
||||
JSON response with:
|
||||
- models: List of base model names
|
||||
- source: 'cache', 'api', or 'fallback'
|
||||
- last_updated: ISO timestamp
|
||||
- counts: hardcoded_count, remote_count, merged_count
|
||||
"""
|
||||
try:
|
||||
service = await self._base_model_service_factory()
|
||||
|
||||
# Check for refresh parameter
|
||||
force_refresh = request.query.get("refresh", "").lower() == "true"
|
||||
|
||||
result = await service.get_base_models(force_refresh=force_refresh)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"data": result,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_base_models: {e}")
|
||||
return web.json_response(
|
||||
{"success": False, "error": str(e)},
|
||||
status=500,
|
||||
)
|
||||
|
||||
async def refresh_base_models(self, request: web.Request) -> web.Response:
|
||||
"""Force refresh base models from Civitai API.
|
||||
|
||||
Returns:
|
||||
JSON response with refreshed data
|
||||
"""
|
||||
try:
|
||||
service = await self._base_model_service_factory()
|
||||
result = await service.refresh_cache()
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"data": result,
|
||||
"message": "Base models cache refreshed successfully",
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in refresh_base_models: {e}")
|
||||
return web.json_response(
|
||||
{"success": False, "error": str(e)},
|
||||
status=500,
|
||||
)
|
||||
|
||||
async def get_base_model_categories(self, request: web.Request) -> web.Response:
|
||||
"""Get categorized base models.
|
||||
|
||||
Returns:
|
||||
JSON response with categorized models
|
||||
"""
|
||||
try:
|
||||
service = await self._base_model_service_factory()
|
||||
categories = service.get_model_categories()
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"data": categories,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_base_model_categories: {e}")
|
||||
return web.json_response(
|
||||
{"success": False, "error": str(e)},
|
||||
status=500,
|
||||
)
|
||||
|
||||
async def get_base_model_cache_status(self, request: web.Request) -> web.Response:
|
||||
"""Get cache status for base models.
|
||||
|
||||
Returns:
|
||||
JSON response with cache status
|
||||
"""
|
||||
try:
|
||||
service = await self._base_model_service_factory()
|
||||
status = service.get_cache_status()
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"data": status,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_base_model_cache_status: {e}")
|
||||
return web.json_response(
|
||||
{"success": False, "error": str(e)},
|
||||
status=500,
|
||||
)
|
||||
417
py/routes/handlers/hf_handlers.py
Normal file
417
py/routes/handlers/hf_handlers.py
Normal file
@@ -0,0 +1,417 @@
|
||||
"""Handlers for Hugging Face model listing and download.
|
||||
|
||||
Minimal MVP implementation — uses direct HTTP to the HF API for file
|
||||
listing and the project's existing aiohttp-based Downloader for
|
||||
downloading. No huggingface_hub dependency required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
||||
from ...config import config
|
||||
from ...services.downloader import (
|
||||
DownloadProgress,
|
||||
get_downloader,
|
||||
)
|
||||
from ...services.aria2_downloader import Aria2Downloader
|
||||
from ...services.settings_manager import get_settings_manager
|
||||
from ...services.service_registry import ServiceRegistry
|
||||
from ...services.websocket_manager import ws_manager
|
||||
from ...utils.constants import MODEL_FILE_EXTENSIONS
|
||||
from ...utils.metadata_manager import MetadataManager
|
||||
from ...utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_MODEL_CLASS = LoraMetadata
|
||||
_DEFAULT_SCANNER_GETTER = "get_lora_scanner"
|
||||
|
||||
# Shared aiohttp session for HF API calls (created on first use)
|
||||
_hf_api_session: aiohttp.ClientSession | None = None
|
||||
|
||||
|
||||
async def _get_hf_api_session() -> aiohttp.ClientSession:
|
||||
"""Get or create the shared aiohttp session for HF API calls."""
|
||||
global _hf_api_session # needed because we reassign the module-level name
|
||||
if _hf_api_session is None or _hf_api_session.closed:
|
||||
_hf_api_session = aiohttp.ClientSession(
|
||||
headers={"User-Agent": "ComfyUI-LoRA-Manager/1.0"},
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
)
|
||||
return _hf_api_session
|
||||
|
||||
|
||||
async def close_hf_api_session() -> None:
|
||||
"""Close the shared HF API session, if it was ever created."""
|
||||
global _hf_api_session
|
||||
if _hf_api_session is not None and not _hf_api_session.closed:
|
||||
await _hf_api_session.close()
|
||||
_hf_api_session = None
|
||||
|
||||
|
||||
def _infer_model_type(model_root: str) -> tuple[Any, str]:
|
||||
"""Determine model class and scanner by matching ``model_root`` against the
|
||||
configured root paths for each model type (from ``Config``).
|
||||
|
||||
The ``model_root`` value comes from the frontend's model-root dropdown,
|
||||
which is populated from the current page's scanner roots. By checking
|
||||
which scanner's root list it belongs to, we avoid fragile heuristics
|
||||
like substring-matching path names.
|
||||
"""
|
||||
norm = os.path.normpath(model_root).replace(os.sep, "/")
|
||||
|
||||
# LoRA roots
|
||||
for p in (config.loras_roots or []) + (config.extra_loras_roots or []):
|
||||
if os.path.normpath(p).replace(os.sep, "/") == norm:
|
||||
return LoraMetadata, "get_lora_scanner"
|
||||
|
||||
# Checkpoint / UNet roots
|
||||
for p in (
|
||||
(config.checkpoints_roots or [])
|
||||
+ (config.extra_checkpoints_roots or [])
|
||||
+ (config.unet_roots or [])
|
||||
+ (config.extra_unet_roots or [])
|
||||
):
|
||||
if os.path.normpath(p).replace(os.sep, "/") == norm:
|
||||
return CheckpointMetadata, "get_checkpoint_scanner"
|
||||
|
||||
# Embedding roots
|
||||
for p in (config.embeddings_roots or []) + (config.extra_embeddings_roots or []):
|
||||
if os.path.normpath(p).replace(os.sep, "/") == norm:
|
||||
return EmbeddingMetadata, "get_embedding_scanner"
|
||||
|
||||
# Fallback — should not happen in normal use
|
||||
logger.warning(
|
||||
"Could not determine model type for root '%s'; defaulting to LoRA",
|
||||
model_root,
|
||||
)
|
||||
return _DEFAULT_MODEL_CLASS, _DEFAULT_SCANNER_GETTER
|
||||
|
||||
|
||||
async def _save_hf_metadata(dest_path: str, repo: str, model_root: str) -> None:
|
||||
"""Create a proper .metadata.json and add the model to the scanner cache.
|
||||
|
||||
Uses ``MetadataManager.create_default_metadata()`` which computes the
|
||||
SHA256 hash, extracts safetensors header metadata (base_model), and
|
||||
produces a fully-populated ``LoraMetadata`` (or ``CheckpointMetadata`` /
|
||||
``EmbeddingMetadata``) object. We then overlay HF-specific fields and
|
||||
register the model in the in-memory scanner cache so it appears
|
||||
immediately without a full filesystem walk.
|
||||
"""
|
||||
try:
|
||||
hf_url = f"https://huggingface.co/{repo}"
|
||||
model_class, scanner_getter_name = _infer_model_type(model_root)
|
||||
|
||||
# 1. Create proper metadata (computes SHA256, reads safetensors headers)
|
||||
metadata = await MetadataManager.create_default_metadata(
|
||||
dest_path, model_class=model_class
|
||||
)
|
||||
if metadata is None:
|
||||
logger.warning("create_default_metadata returned None for %s", dest_path)
|
||||
return
|
||||
|
||||
# 2. Overlay HF-specific fields
|
||||
metadata._unknown_fields["hf_url"] = hf_url
|
||||
metadata.from_civitai = False # HF models are not from CivitAI
|
||||
|
||||
# 3. Save metadata atomically
|
||||
await MetadataManager.save_metadata(dest_path, metadata)
|
||||
logger.info("Saved HF metadata (with hf_url) for %s", dest_path)
|
||||
|
||||
# 4. Determine relative folder path for cache
|
||||
# model_root is an absolute path; dest_path is under it
|
||||
folder = ""
|
||||
if os.path.isabs(model_root) and dest_path.startswith(model_root):
|
||||
rel = os.path.relpath(os.path.dirname(dest_path), model_root)
|
||||
folder = rel.replace(os.sep, "/") if rel != "." else ""
|
||||
|
||||
# 5. Add to scanner cache (same as CivitAI's _execute_download does)
|
||||
scanner_getter = getattr(ServiceRegistry, scanner_getter_name, None)
|
||||
if scanner_getter is not None:
|
||||
scanner = await scanner_getter()
|
||||
if scanner is not None:
|
||||
metadata_dict = metadata.to_dict()
|
||||
metadata_dict["hf_url"] = hf_url
|
||||
await scanner.add_model_to_cache(metadata_dict, folder)
|
||||
logger.info("Added %s to scanner cache (folder=%s)", dest_path, folder)
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to save HF metadata for %s: %s", dest_path, exc)
|
||||
|
||||
|
||||
class HfHandler:
|
||||
"""Handle Hugging Face model browsing and download."""
|
||||
|
||||
async def get_hf_repo_files(self, request: web.Request) -> web.Response:
|
||||
"""List model-weight files from a HF repo with real file sizes.
|
||||
|
||||
Uses the HF tree API endpoint which returns accurate file sizes
|
||||
(including LFS-tracked files), unlike the model info endpoint.
|
||||
"""
|
||||
repo = request.query.get("repo", "").strip()
|
||||
if not repo or "/" not in repo:
|
||||
return web.json_response(
|
||||
{"error": "Missing or invalid 'repo' parameter (expected user/repo)"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
url = f"https://huggingface.co/api/models/{repo}/tree/main"
|
||||
|
||||
try:
|
||||
session = await _get_hf_api_session()
|
||||
async with session.get(url) as resp:
|
||||
if resp.status == 404:
|
||||
return web.json_response(
|
||||
{"error": f"Repo '{repo}' not found"}, status=404
|
||||
)
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
return web.json_response(
|
||||
{"error": f"HF API error {resp.status}: {text[:200]}"},
|
||||
status=resp.status,
|
||||
)
|
||||
tree: list[dict[str, Any]] = await resp.json()
|
||||
except Exception as exc:
|
||||
logger.error("Failed to fetch HF repo files: %s", exc)
|
||||
return web.json_response({"error": str(exc)}, status=502)
|
||||
|
||||
files: list[dict[str, Any]] = []
|
||||
for entry in tree:
|
||||
path: str = entry.get("path", "")
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
if ext not in MODEL_FILE_EXTENSIONS:
|
||||
continue
|
||||
size = entry.get("size", 0) or 0
|
||||
if size == 0 and "lfs" in entry:
|
||||
size = entry["lfs"].get("size", 0) or 0
|
||||
files.append({
|
||||
"filename": path,
|
||||
"size": size,
|
||||
})
|
||||
|
||||
files.sort(key=lambda f: f["size"], reverse=True)
|
||||
return web.json_response(files)
|
||||
|
||||
async def download_hf_model(self, request: web.Request) -> web.Response:
|
||||
"""Download a single file from Hugging Face into the model directory.
|
||||
|
||||
POST JSON body::
|
||||
|
||||
{
|
||||
"repo": "dx8152/Flux2-Klein-9B-Consistency",
|
||||
"filename": "Flux2-Klein-9B-consistency-V2.safetensors",
|
||||
"revision": "main",
|
||||
"model_root": "loras",
|
||||
"relative_path": "",
|
||||
"use_default_paths": false,
|
||||
"download_id": "optional-batch-id"
|
||||
}
|
||||
|
||||
If ``download_id`` is provided, real-time progress (bytes, speed,
|
||||
percentage) is broadcast via the WebSocket progress system, matching
|
||||
the CivitAI download experience.
|
||||
|
||||
Respects the ``download_backend`` setting (``aria2`` or ``default``).
|
||||
"""
|
||||
try:
|
||||
payload: dict[str, Any] = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
return web.json_response({"error": "Invalid JSON"}, status=400)
|
||||
|
||||
repo = (payload.get("repo") or "").strip()
|
||||
filename = (payload.get("filename") or "").strip()
|
||||
revision = (payload.get("revision") or "main").strip()
|
||||
model_root = (payload.get("model_root") or "").strip()
|
||||
relative_path = (payload.get("relative_path") or "").strip()
|
||||
use_default_paths = bool(payload.get("use_default_paths", False))
|
||||
download_id: str | None = payload.get("download_id")
|
||||
|
||||
logger.info(
|
||||
"download_hf_model: repo=%s file=%s root=%s download_id=%s",
|
||||
repo, filename, model_root, download_id,
|
||||
)
|
||||
|
||||
if not repo or not filename:
|
||||
return web.json_response(
|
||||
{"error": "Missing required fields: 'repo' and 'filename'"}, status=400
|
||||
)
|
||||
|
||||
# Validate repo format — must be user/repo_name
|
||||
if repo.count("/") != 1 or not re.match(r"^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$", repo):
|
||||
return web.json_response({"error": f"Invalid repo format: {repo}"}, status=400)
|
||||
author, repo_name = repo.split("/", 1)
|
||||
if ".." in (author, repo_name) or "." in (author, repo_name):
|
||||
return web.json_response({"error": f"Invalid repo format: {repo}"}, status=400)
|
||||
|
||||
# Validate filename — must not contain path separators or ..
|
||||
if "/" in filename or "\\" in filename or ".." in filename:
|
||||
return web.json_response({"error": "Invalid filename"}, status=400)
|
||||
|
||||
# Validate relative_path — must not be absolute or escape base directory
|
||||
if relative_path:
|
||||
if os.path.isabs(relative_path):
|
||||
return web.json_response({"error": "relative_path must not be absolute"}, status=400)
|
||||
if ".." in relative_path.split("/") or "\\" in relative_path:
|
||||
return web.json_response({"error": "Invalid relative_path"}, status=400)
|
||||
|
||||
# Validate model_root — must not contain path traversal
|
||||
if not os.path.isabs(model_root):
|
||||
# For relative model_root, check it doesn't escape
|
||||
resolved_model_root = os.path.realpath(
|
||||
os.path.join(os.getcwd(), "models", model_root)
|
||||
)
|
||||
else:
|
||||
resolved_model_root = os.path.realpath(model_root)
|
||||
|
||||
# Verify model_root is within a configured scanner root
|
||||
allowed_roots = set()
|
||||
for root_list in (
|
||||
config.loras_roots or [],
|
||||
config.extra_loras_roots or [],
|
||||
config.checkpoints_roots or [],
|
||||
config.extra_checkpoints_roots or [],
|
||||
config.unet_roots or [],
|
||||
config.extra_unet_roots or [],
|
||||
config.embeddings_roots or [],
|
||||
config.extra_embeddings_roots or [],
|
||||
):
|
||||
for r in root_list:
|
||||
allowed_roots.add(os.path.realpath(r))
|
||||
|
||||
if not any(resolved_model_root == root or resolved_model_root.startswith(root + os.sep) for root in allowed_roots):
|
||||
logger.warning("Invalid model_root rejected: %s", model_root)
|
||||
return web.json_response({"error": f"Invalid model_root: {model_root}"}, status=400)
|
||||
|
||||
base_dir = resolved_model_root
|
||||
|
||||
if use_default_paths:
|
||||
target_dir = os.path.join(base_dir, "huggingface", author, repo_name)
|
||||
elif relative_path:
|
||||
target_dir = os.path.join(base_dir, relative_path)
|
||||
else:
|
||||
target_dir = base_dir
|
||||
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
dest_path = os.path.join(target_dir, filename)
|
||||
|
||||
# Resolve symlinks and check for path traversal escape
|
||||
real_dest = os.path.realpath(dest_path)
|
||||
real_base = os.path.realpath(target_dir)
|
||||
if not real_dest.startswith(real_base + os.sep):
|
||||
logger.warning("Path traversal blocked: %s -> %s", dest_path, real_dest)
|
||||
return web.json_response({"error": "Path traversal detected"}, status=400)
|
||||
|
||||
# Check if already exists (simple skip)
|
||||
if os.path.exists(dest_path) and os.path.getsize(dest_path) > 0:
|
||||
logger.info("download_hf_model: file already exists, skipping — %s", dest_path)
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"File already exists: {dest_path}",
|
||||
"path": dest_path,
|
||||
})
|
||||
|
||||
# Build HF resolve URL
|
||||
resolve_url = (
|
||||
f"https://huggingface.co/{repo}/resolve/{revision}/{filename}"
|
||||
)
|
||||
|
||||
# Set up progress callback if download_id is provided
|
||||
progress_callback = None
|
||||
if download_id:
|
||||
|
||||
async def _progress_callback(
|
||||
progress: float | DownloadProgress,
|
||||
snapshot: DownloadProgress | None = None,
|
||||
) -> None:
|
||||
percent = 0.0
|
||||
metrics = snapshot if isinstance(snapshot, DownloadProgress) else None
|
||||
|
||||
if isinstance(progress, DownloadProgress):
|
||||
percent = progress.percent_complete
|
||||
metrics = progress
|
||||
elif isinstance(snapshot, DownloadProgress):
|
||||
percent = snapshot.percent_complete
|
||||
else:
|
||||
percent = float(progress)
|
||||
|
||||
broadcast: dict[str, Any] = {
|
||||
"status": "progress",
|
||||
"progress": round(percent),
|
||||
}
|
||||
if metrics:
|
||||
broadcast["bytes_downloaded"] = metrics.bytes_downloaded
|
||||
broadcast["total_bytes"] = metrics.total_bytes
|
||||
broadcast["bytes_per_second"] = metrics.bytes_per_second
|
||||
|
||||
await ws_manager.broadcast_download_progress(download_id, broadcast)
|
||||
|
||||
progress_callback = _progress_callback
|
||||
|
||||
# Respect download backend setting (aria2 vs default)
|
||||
download_backend = (
|
||||
get_settings_manager().get("download_backend", "default")
|
||||
)
|
||||
|
||||
if download_backend == "aria2":
|
||||
aria2 = await Aria2Downloader.get_instance()
|
||||
aid = download_id or f"hf_{repo}_{filename}"
|
||||
try:
|
||||
hf_success, hf_result = await aria2.download_file(
|
||||
url=resolve_url,
|
||||
save_path=dest_path,
|
||||
download_id=aid,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
if hf_success:
|
||||
await _save_hf_metadata(dest_path, repo, model_root)
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Downloaded to {dest_path}",
|
||||
"path": dest_path,
|
||||
})
|
||||
else:
|
||||
return web.json_response(
|
||||
{"success": False, "error": hf_result or "aria2 download failed"},
|
||||
status=500,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("HF download (aria2) failed: %s", exc)
|
||||
return web.json_response(
|
||||
{"success": False, "error": str(exc)}, status=500
|
||||
)
|
||||
|
||||
# Default: use built-in aiohttp Downloader
|
||||
downloader = await get_downloader()
|
||||
try:
|
||||
success, result = await downloader.download_file(
|
||||
url=resolve_url,
|
||||
save_path=dest_path,
|
||||
use_auth=False,
|
||||
allow_resume=True,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
if success:
|
||||
await _save_hf_metadata(dest_path, repo, model_root)
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Downloaded to {result}",
|
||||
"path": result,
|
||||
})
|
||||
else:
|
||||
return web.json_response(
|
||||
{"success": False, "error": result or "Download failed"},
|
||||
status=500,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("HF download failed: %s", exc)
|
||||
return web.json_response(
|
||||
{"success": False, "error": str(exc)}, status=500
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -16,9 +16,14 @@ import jinja2
|
||||
|
||||
from ...config import config
|
||||
from ...services.download_coordinator import DownloadCoordinator
|
||||
from ...services.connectivity_guard import (
|
||||
OFFLINE_FRIENDLY_MESSAGE,
|
||||
is_expected_offline_error,
|
||||
)
|
||||
from ...services.metadata_sync_service import MetadataSyncService
|
||||
from ...services.model_file_service import ModelMoveService
|
||||
from ...services.preview_asset_service import PreviewAssetService
|
||||
from ...services.service_registry import ServiceRegistry
|
||||
from ...services.settings_manager import SettingsManager, get_settings_manager
|
||||
from ...services.tag_update_service import TagUpdateService
|
||||
from ...services.use_cases import (
|
||||
@@ -32,6 +37,7 @@ from ...services.use_cases import (
|
||||
)
|
||||
from ...services.websocket_manager import WebSocketManager
|
||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||
from ...services.download_queue_service import DownloadQueueService
|
||||
from ...services.errors import RateLimitError, ResourceNotFoundError
|
||||
from ...utils.civitai_utils import resolve_license_payload
|
||||
from ...utils.file_utils import calculate_sha256
|
||||
@@ -64,7 +70,6 @@ class ModelPageView:
|
||||
self._settings = settings_service
|
||||
self._server_i18n = server_i18n
|
||||
self._logger = logger
|
||||
self._app_version = self._get_app_version()
|
||||
|
||||
def _load_supporters(self) -> dict:
|
||||
"""Load supporters data from JSON file."""
|
||||
@@ -155,7 +160,7 @@ class ModelPageView:
|
||||
"request": request,
|
||||
"folders": [],
|
||||
"t": self._server_i18n.get_translation,
|
||||
"version": self._app_version,
|
||||
"version": self._get_app_version(),
|
||||
}
|
||||
|
||||
if not is_initializing:
|
||||
@@ -198,11 +203,17 @@ class ModelListingHandler:
|
||||
result = await self._service.get_paginated_data(**params)
|
||||
|
||||
format_start = time.perf_counter()
|
||||
formatted_raw = [
|
||||
await self._service.format_response(entry)
|
||||
for entry in result["items"]
|
||||
]
|
||||
# Filter out None entries returned for corrupted cache rows (issue #730).
|
||||
# Note: "total" intentionally remains the pre-filter count to reflect
|
||||
# the true number of models in the cache; corrupted entries are rare
|
||||
# and adjusting total would cause pagination drift on every page.
|
||||
formatted_items = [item for item in formatted_raw if item is not None]
|
||||
formatted_result = {
|
||||
"items": [
|
||||
await self._service.format_response(item)
|
||||
for item in result["items"]
|
||||
],
|
||||
"items": formatted_items,
|
||||
"total": result["total"],
|
||||
"page": result["page"],
|
||||
"page_size": result["page_size"],
|
||||
@@ -224,6 +235,48 @@ class ModelListingHandler:
|
||||
)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def get_excluded_models(self, request: web.Request) -> web.Response:
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
params = self._parse_common_params(request)
|
||||
# group_by_model is meaningless for excluded view; strip it
|
||||
params.pop("group_by_model", None)
|
||||
result = await self._service.get_excluded_paginated_data(**params)
|
||||
|
||||
format_start = time.perf_counter()
|
||||
formatted_raw = [
|
||||
await self._service.format_response(entry)
|
||||
for entry in result["items"]
|
||||
]
|
||||
# Filter out None entries returned for corrupted cache rows (issue #730).
|
||||
# "total" stays at the pre-filter count; see get_models for rationale.
|
||||
formatted_items = [item for item in formatted_raw if item is not None]
|
||||
formatted_result = {
|
||||
"items": formatted_items,
|
||||
"total": result["total"],
|
||||
"page": result["page"],
|
||||
"page_size": result["page_size"],
|
||||
"total_pages": result["total_pages"],
|
||||
}
|
||||
format_duration = time.perf_counter() - format_start
|
||||
|
||||
duration = time.perf_counter() - start_time
|
||||
self._logger.debug(
|
||||
"Request for %s/excluded took %.3fs (formatting: %.3fs)",
|
||||
self._service.model_type,
|
||||
duration,
|
||||
format_duration,
|
||||
)
|
||||
return web.json_response(formatted_result)
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error retrieving excluded %ss: %s",
|
||||
self._service.model_type,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
def _parse_common_params(self, request: web.Request) -> Dict:
|
||||
page = int(request.query.get("page", "1"))
|
||||
page_size = min(int(request.query.get("page_size", "20")), 100)
|
||||
@@ -261,6 +314,15 @@ class ModelListingHandler:
|
||||
for tag in exclude_tags:
|
||||
if tag:
|
||||
tag_filters[tag] = "exclude"
|
||||
|
||||
auto_tag_filters: Dict[str, str] = {}
|
||||
for tag in request.query.getall("auto_tag_include", []):
|
||||
if tag:
|
||||
auto_tag_filters[tag] = "include"
|
||||
for tag in request.query.getall("auto_tag_exclude", []):
|
||||
if tag:
|
||||
auto_tag_filters[tag] = "exclude"
|
||||
|
||||
favorites_only = request.query.get("favorites_only", "false").lower() == "true"
|
||||
|
||||
search_options = {
|
||||
@@ -309,6 +371,26 @@ class ModelListingHandler:
|
||||
else:
|
||||
allow_selling_generated_content = None # None means no filter applied
|
||||
|
||||
# Name pattern filters for LoRA Pool
|
||||
name_pattern_include = request.query.getall("name_pattern_include", [])
|
||||
name_pattern_exclude = request.query.getall("name_pattern_exclude", [])
|
||||
name_pattern_use_regex = (
|
||||
request.query.get("name_pattern_use_regex", "false").lower() == "true"
|
||||
)
|
||||
|
||||
# Group-by-model flag: deduplicate versions sharing the same civitai modelId
|
||||
group_by_model = (
|
||||
request.query.get("group_by_model", "false").lower() == "true"
|
||||
)
|
||||
|
||||
# View-local-versions filter: show all local versions of a specific model
|
||||
civitai_model_id = request.query.get("civitai_model_id")
|
||||
if civitai_model_id is not None:
|
||||
try:
|
||||
civitai_model_id = int(civitai_model_id)
|
||||
except (TypeError, ValueError):
|
||||
civitai_model_id = None
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
@@ -320,6 +402,7 @@ class ModelListingHandler:
|
||||
"fuzzy_search": fuzzy_search,
|
||||
"base_models": base_models,
|
||||
"tags": tag_filters,
|
||||
"auto_tags": auto_tag_filters,
|
||||
"tag_logic": tag_logic,
|
||||
"search_options": search_options,
|
||||
"hash_filters": hash_filters,
|
||||
@@ -328,6 +411,11 @@ class ModelListingHandler:
|
||||
"credit_required": credit_required,
|
||||
"allow_selling_generated_content": allow_selling_generated_content,
|
||||
"model_types": model_types,
|
||||
"name_pattern_include": name_pattern_include,
|
||||
"name_pattern_exclude": name_pattern_exclude,
|
||||
"name_pattern_use_regex": name_pattern_use_regex,
|
||||
"group_by_model": group_by_model,
|
||||
"civitai_model_id": civitai_model_id,
|
||||
**self._parse_specific_params(request),
|
||||
}
|
||||
|
||||
@@ -382,6 +470,21 @@ class ModelManagementHandler:
|
||||
self._logger.error("Error excluding model: %s", exc, exc_info=True)
|
||||
return web.Response(text=str(exc), status=500)
|
||||
|
||||
async def unexclude_model(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get("file_path")
|
||||
if not file_path:
|
||||
return web.Response(text="Model path is required", status=400)
|
||||
|
||||
result = await self._lifecycle_service.unexclude_model(file_path)
|
||||
return web.json_response(result)
|
||||
except ValueError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error restoring model: %s", exc, exc_info=True)
|
||||
return web.Response(text=str(exc), status=500)
|
||||
|
||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
data = await request.json()
|
||||
@@ -440,9 +543,19 @@ class ModelManagementHandler:
|
||||
if not success:
|
||||
return web.json_response({"success": False, "error": error})
|
||||
|
||||
formatted_metadata = await self._service.format_response(model_data)
|
||||
return web.json_response({"success": True, "metadata": formatted_metadata})
|
||||
formatted = await self._service.format_response(model_data)
|
||||
if formatted is None:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Model entry is corrupted (missing file_path)"},
|
||||
status=500,
|
||||
)
|
||||
return web.json_response({"success": True, "metadata": formatted})
|
||||
except Exception as exc:
|
||||
if is_expected_offline_error(str(exc)):
|
||||
return web.json_response(
|
||||
{"success": False, "error": OFFLINE_FRIENDLY_MESSAGE},
|
||||
status=503,
|
||||
)
|
||||
self._logger.error("Error fetching from CivitAI: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
@@ -489,6 +602,11 @@ class ModelManagementHandler:
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
if is_expected_offline_error(str(exc)):
|
||||
return web.json_response(
|
||||
{"success": False, "error": OFFLINE_FRIENDLY_MESSAGE},
|
||||
status=503,
|
||||
)
|
||||
self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
@@ -703,7 +821,7 @@ class ModelManagementHandler:
|
||||
|
||||
metadata_updates = {k: v for k, v in data.items() if k != "file_path"}
|
||||
|
||||
await self._metadata_sync.save_metadata_updates(
|
||||
updated_metadata = await self._metadata_sync.save_metadata_updates(
|
||||
file_path=file_path,
|
||||
updates=metadata_updates,
|
||||
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||
@@ -714,7 +832,12 @@ class ModelManagementHandler:
|
||||
cache = await self._service.scanner.get_cached_data()
|
||||
await cache.resort()
|
||||
|
||||
return web.json_response({"success": True})
|
||||
from ...services.auto_tag_service import extract_auto_tags
|
||||
auto_tags = extract_auto_tags(updated_metadata)
|
||||
|
||||
return web.json_response(
|
||||
{"success": True, "auto_tags": auto_tags}
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error saving metadata: %s", exc, exc_info=True)
|
||||
return web.Response(text=str(exc), status=500)
|
||||
@@ -731,14 +854,16 @@ class ModelManagementHandler:
|
||||
if not isinstance(new_tags, list):
|
||||
return web.Response(text="Tags must be a list", status=400)
|
||||
|
||||
tags = await self._tag_update_service.add_tags(
|
||||
tags, auto_tags = await self._tag_update_service.add_tags(
|
||||
file_path=file_path,
|
||||
new_tags=new_tags,
|
||||
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||
update_cache=self._service.scanner.update_single_model_cache,
|
||||
)
|
||||
|
||||
return web.json_response({"success": True, "tags": tags})
|
||||
return web.json_response(
|
||||
{"success": True, "tags": tags, "auto_tags": auto_tags}
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error adding tags: %s", exc, exc_info=True)
|
||||
return web.Response(text=str(exc), status=500)
|
||||
@@ -849,7 +974,7 @@ class ModelQueryHandler:
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
limit = int(request.query.get("limit", "20"))
|
||||
if limit < 1 or limit > 100:
|
||||
if limit < 0 or limit > 100:
|
||||
limit = 20
|
||||
base_models = await self._service.get_base_models(limit)
|
||||
return web.json_response({"success": True, "base_models": base_models})
|
||||
@@ -981,10 +1106,12 @@ class ModelQueryHandler:
|
||||
# Sort: originals first, copies last
|
||||
sorted_models = self._sort_duplicate_group(filtered)
|
||||
|
||||
# Format response
|
||||
# Format response, filtering out corrupted entries (issue #730)
|
||||
group = {"hash": sha256, "models": []}
|
||||
for model in sorted_models:
|
||||
group["models"].append(await self._service.format_response(model))
|
||||
formatted = await self._service.format_response(model)
|
||||
if formatted is not None:
|
||||
group["models"].append(formatted)
|
||||
|
||||
# Only include groups with 2+ models after filtering
|
||||
if len(group["models"]) > 1:
|
||||
@@ -1085,6 +1212,12 @@ class ModelQueryHandler:
|
||||
|
||||
async def find_filename_conflicts(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
settings = get_settings_manager()
|
||||
if settings.get("lora_syntax_format", "legacy") == "full":
|
||||
return web.json_response(
|
||||
{"success": True, "conflicts": [], "count": 0}
|
||||
)
|
||||
|
||||
duplicates = self._service.find_duplicate_filenames()
|
||||
result = []
|
||||
cache = await self._service.scanner.get_cached_data()
|
||||
@@ -1095,9 +1228,9 @@ class ModelQueryHandler:
|
||||
(m for m in cache.raw_data if m["file_path"] == path), None
|
||||
)
|
||||
if model:
|
||||
group["models"].append(
|
||||
await self._service.format_response(model)
|
||||
)
|
||||
formatted = await self._service.format_response(model)
|
||||
if formatted is not None:
|
||||
group["models"].append(formatted)
|
||||
hash_val = self._service.scanner.get_hash_by_filename(filename)
|
||||
if hash_val:
|
||||
main_path = self._service.get_path_by_hash(hash_val)
|
||||
@@ -1107,9 +1240,9 @@ class ModelQueryHandler:
|
||||
None,
|
||||
)
|
||||
if main_model:
|
||||
group["models"].insert(
|
||||
0, await self._service.format_response(main_model)
|
||||
)
|
||||
formatted = await self._service.format_response(main_model)
|
||||
if formatted is not None:
|
||||
group["models"].insert(0, formatted)
|
||||
if group["models"]:
|
||||
result.append(group)
|
||||
return web.json_response(
|
||||
@@ -1173,6 +1306,14 @@ class ModelQueryHandler:
|
||||
license_flags = (model_data or {}).get("license_flags")
|
||||
if license_flags is not None:
|
||||
response_payload["license_flags"] = int(license_flags)
|
||||
# Include the user's license icon style preference so the
|
||||
# ComfyUI tooltip can pick the right set without a separate
|
||||
# API call.
|
||||
try:
|
||||
settings = get_settings_manager()
|
||||
response_payload["use_new_license_icons"] = settings.get("use_new_license_icons", True)
|
||||
except Exception:
|
||||
pass
|
||||
return web.json_response(response_payload)
|
||||
return web.json_response(
|
||||
{
|
||||
@@ -1268,8 +1409,11 @@ class ModelQueryHandler:
|
||||
async def get_relative_paths(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
search = request.query.get("search", "").strip()
|
||||
limit = min(int(request.query.get("limit", "15")), 50)
|
||||
matching_paths = await self._service.search_relative_paths(search, limit)
|
||||
limit = min(int(request.query.get("limit", "15")), 100)
|
||||
offset = max(0, int(request.query.get("offset", "0")))
|
||||
matching_paths = await self._service.search_relative_paths(
|
||||
search, limit, offset
|
||||
)
|
||||
return web.json_response(
|
||||
{"success": True, "relative_paths": matching_paths}
|
||||
)
|
||||
@@ -1371,6 +1515,21 @@ class ModelDownloadHandler:
|
||||
)
|
||||
return web.Response(status=500, text=str(exc))
|
||||
|
||||
async def skip_download_get(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
download_id = request.query.get("download_id")
|
||||
if not download_id:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Download ID is required"}, status=400
|
||||
)
|
||||
result = await self._download_coordinator.skip_download(download_id)
|
||||
return web.json_response(result)
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error skipping download via GET: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def cancel_download_get(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
download_id = request.query.get("download_id")
|
||||
@@ -1451,6 +1610,291 @@ class ModelDownloadHandler:
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Download queue / history handlers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_download_queue(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
service = await DownloadQueueService.get_instance()
|
||||
queue = await service.get_queue()
|
||||
stats = await service.get_stats()
|
||||
return web.json_response({"success": True, "queue": queue, "stats": stats})
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error getting download queue: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def add_to_download_queue(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
import uuid
|
||||
|
||||
download_id = request.query.get("download_id") or str(uuid.uuid4())
|
||||
model_id_str = request.query.get("model_id")
|
||||
model_version_id_str = request.query.get("model_version_id")
|
||||
model_name = request.query.get("model_name", "")
|
||||
version_name = request.query.get("version_name", "")
|
||||
thumbnail_url = request.query.get("thumbnail_url", "")
|
||||
source = request.query.get("source")
|
||||
file_params_json = request.query.get("file_params")
|
||||
|
||||
model_id = int(model_id_str) if model_id_str else None
|
||||
model_version_id = int(model_version_id_str) if model_version_id_str else None
|
||||
file_params = json.loads(file_params_json) if file_params_json else None
|
||||
|
||||
service = await DownloadQueueService.get_instance()
|
||||
item = await service.add_to_queue(
|
||||
download_id=download_id,
|
||||
model_id=model_id,
|
||||
model_version_id=model_version_id,
|
||||
model_name=model_name,
|
||||
version_name=version_name,
|
||||
thumbnail_url=thumbnail_url,
|
||||
source=source,
|
||||
file_params=file_params,
|
||||
)
|
||||
return web.json_response({"success": True, "item": item})
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error adding to download queue: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def remove_from_download_queue(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
download_id = request.query.get("download_id")
|
||||
if not download_id:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "download_id is required"}, status=400
|
||||
)
|
||||
|
||||
service = await DownloadQueueService.get_instance()
|
||||
removed = await service.remove_from_queue(download_id)
|
||||
return web.json_response({"success": removed})
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error removing from download queue: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def move_queue_item_to_top(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
download_id = request.query.get("download_id")
|
||||
if not download_id:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "download_id is required"}, status=400
|
||||
)
|
||||
|
||||
service = await DownloadQueueService.get_instance()
|
||||
moved = await service.move_to_top(download_id)
|
||||
return web.json_response({"success": moved})
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error moving queue item to top: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def move_queue_item_to_end(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
download_id = request.query.get("download_id")
|
||||
if not download_id:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "download_id is required"}, status=400
|
||||
)
|
||||
|
||||
service = await DownloadQueueService.get_instance()
|
||||
moved = await service.move_to_end(download_id)
|
||||
return web.json_response({"success": moved})
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error moving queue item to end: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def clear_download_queue(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
status_filter = request.query.get("status") or None
|
||||
service = await DownloadQueueService.get_instance()
|
||||
cleared = await service.clear_queue(status_filter=status_filter)
|
||||
return web.json_response({"success": True, "cleared": cleared})
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error clearing download queue: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_download_history(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
limit = min(int(request.query.get("limit", "50")), 500)
|
||||
offset = int(request.query.get("offset", "0"))
|
||||
status_filter = request.query.get("status") or None
|
||||
service = await DownloadQueueService.get_instance()
|
||||
result = await service.get_history(
|
||||
limit=limit, offset=offset, status_filter=status_filter
|
||||
)
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"items": result["items"],
|
||||
"total": result["total"],
|
||||
"limit": result["limit"],
|
||||
"offset": result["offset"],
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error getting download history: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def clear_download_history(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
status_filter = request.query.get("status") or None
|
||||
service = await DownloadQueueService.get_instance()
|
||||
cleared = await service.clear_history(status_filter=status_filter)
|
||||
return web.json_response({"success": True, "cleared": cleared})
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error clearing download history: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def delete_download_history_item(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
item_id = int(request.query.get("id", "0"))
|
||||
if not item_id:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "id is required"}, status=400
|
||||
)
|
||||
|
||||
service = await DownloadQueueService.get_instance()
|
||||
deleted = await service.delete_history_item(item_id)
|
||||
return web.json_response({"success": deleted})
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error deleting download history item: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def retry_download_from_history(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
item_id = int(request.query.get("id", "0"))
|
||||
if not item_id:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "id is required"}, status=400
|
||||
)
|
||||
|
||||
service = await DownloadQueueService.get_instance()
|
||||
item = await service.retry_from_history(item_id)
|
||||
if item is None:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "History item not found or not retryable"},
|
||||
status=404,
|
||||
)
|
||||
return web.json_response({"success": True, "item": item})
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error retrying download from history: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def retry_all_failed_downloads(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
service = await DownloadQueueService.get_instance()
|
||||
retry_count = await service.retry_all_failed()
|
||||
return web.json_response({"success": True, "retry_count": retry_count})
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error retrying all failed downloads: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def complete_download_in_queue(self, request: web.Request) -> web.Response:
|
||||
"""Atomically move a download from queue to history with terminal status."""
|
||||
try:
|
||||
download_id = request.query.get("download_id")
|
||||
if not download_id:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "download_id is required"}, status=400
|
||||
)
|
||||
status = request.query.get("status", "completed")
|
||||
error = request.query.get("error")
|
||||
file_path = request.query.get("file_path")
|
||||
try:
|
||||
bytes_downloaded = int(request.query.get("bytes_downloaded", "0"))
|
||||
except (TypeError, ValueError):
|
||||
bytes_downloaded = 0
|
||||
total_bytes_raw = request.query.get("total_bytes")
|
||||
total_bytes = int(total_bytes_raw) if total_bytes_raw else None
|
||||
completed_at_raw = request.query.get("completed_at")
|
||||
completed_at = float(completed_at_raw) if completed_at_raw else None
|
||||
|
||||
service = await DownloadQueueService.get_instance()
|
||||
item = await service.complete_download(
|
||||
download_id=download_id,
|
||||
status=status,
|
||||
error=error,
|
||||
file_path=file_path,
|
||||
bytes_downloaded=bytes_downloaded,
|
||||
total_bytes=total_bytes,
|
||||
completed_at=completed_at,
|
||||
)
|
||||
if item is None:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Download not found in queue"}, status=404
|
||||
)
|
||||
return web.json_response({"success": True, "item": item})
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error completing download: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_download_stats(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
service = await DownloadQueueService.get_instance()
|
||||
stats = await service.get_stats()
|
||||
return web.json_response({"success": True, "stats": stats})
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error getting download stats: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def update_download_queue_status(self, request: web.Request) -> web.Response:
|
||||
"""Update the status of a queue item (non-terminal transitions).
|
||||
|
||||
Supported transitions include ``queued → downloading``,
|
||||
``downloading → paused``, ``paused → downloading``, etc.
|
||||
Terminal transitions (``completed``, ``failed``, ``canceled``)
|
||||
should use ``complete_download_in_queue`` instead.
|
||||
"""
|
||||
try:
|
||||
download_id = request.query.get("download_id")
|
||||
status = request.query.get("status")
|
||||
if not download_id or not status:
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": "download_id and status are required",
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
service = await DownloadQueueService.get_instance()
|
||||
updated = await service.update_status(download_id, status)
|
||||
if not updated:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Download not found in queue"},
|
||||
status=404,
|
||||
)
|
||||
return web.json_response({"success": True})
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error updating download queue status: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class ModelCivitaiHandler:
|
||||
"""CivitAI integration endpoints."""
|
||||
@@ -1492,7 +1936,9 @@ class ModelCivitaiHandler:
|
||||
return web.json_response(result)
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error in fetch_all_civitai for %ss: %s", self._service.model_type, exc
|
||||
"Error in fetch_all_civitai for %ss: %s",
|
||||
self._service.model_type, exc,
|
||||
exc_info=True,
|
||||
)
|
||||
return web.Response(text=str(exc), status=500)
|
||||
|
||||
@@ -1519,6 +1965,20 @@ class ModelCivitaiHandler:
|
||||
|
||||
cache = await self._service.scanner.get_cached_data()
|
||||
version_index = cache.version_index
|
||||
downloaded_version_ids: set[int] = set()
|
||||
try:
|
||||
history_service = await ServiceRegistry.get_downloaded_version_history_service()
|
||||
downloaded_version_ids = set(
|
||||
await history_service.get_downloaded_version_ids(
|
||||
self._service.model_type,
|
||||
model_id,
|
||||
)
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._logger.debug(
|
||||
"Failed to load download history for CivitAI versions: %s",
|
||||
exc,
|
||||
)
|
||||
|
||||
for version in versions:
|
||||
version_id = None
|
||||
@@ -1535,6 +1995,9 @@ class ModelCivitaiHandler:
|
||||
else None
|
||||
)
|
||||
version["existsLocally"] = cache_entry is not None
|
||||
version["hasBeenDownloaded"] = (
|
||||
version_id in downloaded_version_ids if version_id is not None else False
|
||||
)
|
||||
if cache_entry and isinstance(cache_entry, Mapping):
|
||||
local_path = cache_entry.get("file_path")
|
||||
if local_path:
|
||||
@@ -1777,6 +2240,11 @@ class ModelUpdateHandler:
|
||||
status=429,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive log
|
||||
if is_expected_offline_error(str(exc)):
|
||||
return web.json_response(
|
||||
{"success": False, "error": OFFLINE_FRIENDLY_MESSAGE},
|
||||
status=503,
|
||||
)
|
||||
self._logger.error("Failed to fetch license info: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
@@ -1837,6 +2305,10 @@ class ModelUpdateHandler:
|
||||
if target_model_ids:
|
||||
target_model_ids = sorted(set(target_model_ids))
|
||||
|
||||
folder_path: Optional[str] = payload.get("folder_path")
|
||||
if folder_path is not None and not isinstance(folder_path, str):
|
||||
folder_path = None
|
||||
|
||||
provider = await self._get_civitai_provider()
|
||||
if provider is None:
|
||||
return web.json_response(
|
||||
@@ -1851,6 +2323,7 @@ class ModelUpdateHandler:
|
||||
provider,
|
||||
force_refresh=force_refresh,
|
||||
target_model_ids=target_model_ids or None,
|
||||
folder_path=folder_path,
|
||||
)
|
||||
if self._service.scanner.is_cancelled():
|
||||
return web.json_response(
|
||||
@@ -1865,15 +2338,29 @@ class ModelUpdateHandler:
|
||||
{"success": False, "error": str(exc) or "Rate limited"}, status=429
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._logger.error(
|
||||
"Failed to refresh model updates: %s", exc, exc_info=True
|
||||
)
|
||||
if is_expected_offline_error(str(exc)):
|
||||
return web.json_response(
|
||||
{"success": False, "error": OFFLINE_FRIENDLY_MESSAGE},
|
||||
status=503,
|
||||
)
|
||||
self._logger.error("Failed to refresh model updates: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
hide_early_access = False
|
||||
if self._settings is not None:
|
||||
try:
|
||||
hide_early_access = bool(
|
||||
self._settings.get("hide_early_access_updates", False)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
serialized_records = []
|
||||
for record in records.values():
|
||||
has_update_fn = getattr(record, "has_update", None)
|
||||
if callable(has_update_fn) and has_update_fn():
|
||||
if callable(has_update_fn) and has_update_fn(
|
||||
hide_early_access=hide_early_access
|
||||
):
|
||||
serialized_records.append(self._serialize_record(record))
|
||||
|
||||
return web.json_response(
|
||||
@@ -2253,7 +2740,7 @@ class ModelUpdateHandler:
|
||||
self,
|
||||
record,
|
||||
*,
|
||||
version_context: Optional[Dict[int, Dict[str, Optional[str]]]] = None,
|
||||
version_context: Optional[Dict[int, Dict[str, Any]]] = None,
|
||||
) -> Dict:
|
||||
context = version_context or {}
|
||||
# Check user setting for hiding early access versions
|
||||
@@ -2282,7 +2769,7 @@ class ModelUpdateHandler:
|
||||
|
||||
@staticmethod
|
||||
def _serialize_version(
|
||||
version, context: Optional[Dict[str, Optional[str]]]
|
||||
version, context: Optional[Dict[str, Any]]
|
||||
) -> Dict:
|
||||
context = context or {}
|
||||
preview_override = context.get("preview_override")
|
||||
@@ -2316,17 +2803,42 @@ class ModelUpdateHandler:
|
||||
"sizeBytes": version.size_bytes,
|
||||
"previewUrl": preview_url,
|
||||
"isInLibrary": version.is_in_library,
|
||||
"hasBeenDownloaded": bool(context.get("has_been_downloaded", False)),
|
||||
"shouldIgnore": version.should_ignore,
|
||||
"earlyAccessEndsAt": version.early_access_ends_at,
|
||||
"isEarlyAccess": is_early_access,
|
||||
"usageControl": version.usage_control,
|
||||
"filePath": context.get("file_path"),
|
||||
"fileName": context.get("file_name"),
|
||||
}
|
||||
|
||||
async def _build_version_context(
|
||||
self, record
|
||||
) -> Dict[int, Dict[str, Optional[str]]]:
|
||||
context: Dict[int, Dict[str, Optional[str]]] = {}
|
||||
) -> Dict[int, Dict[str, Any]]:
|
||||
context: Dict[int, Dict[str, Any]] = {}
|
||||
downloaded_version_ids: set[int] = set()
|
||||
try:
|
||||
history_service = await ServiceRegistry.get_downloaded_version_history_service()
|
||||
downloaded_version_ids = set(
|
||||
await history_service.get_downloaded_version_ids(
|
||||
record.model_type,
|
||||
record.model_id,
|
||||
)
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._logger.debug(
|
||||
"Failed to load download history while building version context: %s",
|
||||
exc,
|
||||
)
|
||||
|
||||
for version in record.versions:
|
||||
context[version.version_id] = {
|
||||
"file_path": None,
|
||||
"file_name": None,
|
||||
"preview_override": None,
|
||||
"has_been_downloaded": version.version_id in downloaded_version_ids,
|
||||
}
|
||||
|
||||
try:
|
||||
cache = await self._service.scanner.get_cached_data()
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
@@ -2345,16 +2857,21 @@ class ModelUpdateHandler:
|
||||
cache_entry = version_index.get(version.version_id)
|
||||
if isinstance(cache_entry, Mapping):
|
||||
preview = cache_entry.get("preview_url")
|
||||
context_entry: Dict[str, Optional[str]] = {
|
||||
"file_path": cache_entry.get("file_path"),
|
||||
"file_name": cache_entry.get("file_name"),
|
||||
"preview_override": None,
|
||||
}
|
||||
context_entry = context.setdefault(
|
||||
version.version_id,
|
||||
{
|
||||
"file_path": None,
|
||||
"file_name": None,
|
||||
"preview_override": None,
|
||||
"has_been_downloaded": version.version_id in downloaded_version_ids,
|
||||
},
|
||||
)
|
||||
context_entry["file_path"] = cache_entry.get("file_path")
|
||||
context_entry["file_name"] = cache_entry.get("file_name")
|
||||
if isinstance(preview, str) and preview:
|
||||
context_entry["preview_override"] = config.get_preview_static_url(
|
||||
preview
|
||||
)
|
||||
context[version.version_id] = context_entry
|
||||
return context
|
||||
|
||||
|
||||
@@ -2378,8 +2895,10 @@ class ModelHandlerSet:
|
||||
return {
|
||||
"handle_models_page": self.page_view.handle,
|
||||
"get_models": self.listing.get_models,
|
||||
"get_excluded_models": self.listing.get_excluded_models,
|
||||
"delete_model": self.management.delete_model,
|
||||
"exclude_model": self.management.exclude_model,
|
||||
"unexclude_model": self.management.unexclude_model,
|
||||
"fetch_civitai": self.management.fetch_civitai,
|
||||
"fetch_all_civitai": self.civitai.fetch_all_civitai,
|
||||
"relink_civitai": self.management.relink_civitai,
|
||||
@@ -2403,9 +2922,24 @@ class ModelHandlerSet:
|
||||
"download_model": self.download.download_model,
|
||||
"download_model_get": self.download.download_model_get,
|
||||
"cancel_download_get": self.download.cancel_download_get,
|
||||
"skip_download_get": self.download.skip_download_get,
|
||||
"pause_download_get": self.download.pause_download_get,
|
||||
"resume_download_get": self.download.resume_download_get,
|
||||
"get_download_progress": self.download.get_download_progress,
|
||||
"get_download_queue": self.download.get_download_queue,
|
||||
"add_to_download_queue": self.download.add_to_download_queue,
|
||||
"remove_from_download_queue": self.download.remove_from_download_queue,
|
||||
"move_queue_item_to_top": self.download.move_queue_item_to_top,
|
||||
"move_queue_item_to_end": self.download.move_queue_item_to_end,
|
||||
"clear_download_queue": self.download.clear_download_queue,
|
||||
"get_download_history": self.download.get_download_history,
|
||||
"clear_download_history": self.download.clear_download_history,
|
||||
"delete_download_history_item": self.download.delete_download_history_item,
|
||||
"retry_download_from_history": self.download.retry_download_from_history,
|
||||
"retry_all_failed_downloads": self.download.retry_all_failed_downloads,
|
||||
"complete_download_in_queue": self.download.complete_download_in_queue,
|
||||
"get_download_stats": self.download.get_download_stats,
|
||||
"update_download_queue_status": self.download.update_download_queue_status,
|
||||
"get_civitai_versions": self.civitai.get_civitai_versions,
|
||||
"get_civitai_model_by_version": self.civitai.get_civitai_model_by_version,
|
||||
"get_civitai_model_by_hash": self.civitai.get_civitai_model_by_hash,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import mimetypes
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
|
||||
@@ -12,6 +13,12 @@ from ...config import config as global_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CHUNK_SIZE = 1024 * 1024 # 1 MB — balance between streaming iteration overhead and per-chunk memory
|
||||
|
||||
# Video file extensions that bypass native sendfile on Windows
|
||||
# to avoid IOCP/ProactorEventLoop crashes during client disconnect.
|
||||
_VIDEO_EXTENSIONS = frozenset({".mp4", ".webm", ".mov", ".avi", ".mkv"})
|
||||
|
||||
|
||||
class PreviewHandler:
|
||||
"""Serve preview assets for the active library at request time."""
|
||||
@@ -48,8 +55,58 @@ class PreviewHandler:
|
||||
logger.debug("Preview file not found at %s", str(resolved))
|
||||
raise web.HTTPNotFound(text="Preview file not found")
|
||||
|
||||
# aiohttp's FileResponse handles range requests and content headers for us.
|
||||
return web.FileResponse(path=resolved, chunk_size=256 * 1024)
|
||||
# aiohttp's FileResponse handles range requests, content headers, and
|
||||
# uses kernel sendfile (zero-copy DMA) on Linux/macOS. On Windows it
|
||||
# uses IOCP-based _sendfile_native which can crash when the client
|
||||
# disconnects mid-transfer during fast scrolling. The _stream_file()
|
||||
# fallback is kept for a future compat toggle.
|
||||
#
|
||||
# Set explicit Cache-Control so the browser can cache video (and image)
|
||||
# previews across VirtualScroller recycling cycles. Without this,
|
||||
# Chrome does not cache 206 Partial Content responses for <video>
|
||||
# elements, causing the same video to be re-downloaded on every scroll.
|
||||
resp = web.FileResponse(path=resolved, chunk_size=_CHUNK_SIZE)
|
||||
resp.headers["Cache-Control"] = "public, max-age=86400"
|
||||
return resp
|
||||
|
||||
async def _stream_file(
|
||||
self, request: web.Request, path: Path
|
||||
) -> web.StreamResponse:
|
||||
"""Stream a file chunk-by-chunk, bypassing native sendfile.
|
||||
|
||||
This avoids the Windows IOCP ``_sendfile_native`` crash that occurs
|
||||
when the client disconnects during a large file transfer.
|
||||
"""
|
||||
content_type, _ = mimetypes.guess_type(str(path))
|
||||
if content_type is None:
|
||||
content_type = "application/octet-stream"
|
||||
|
||||
file_size = path.stat().st_size
|
||||
resp = web.StreamResponse()
|
||||
resp.content_type = content_type
|
||||
resp.content_length = file_size
|
||||
|
||||
# Allow browser caching: video previews rarely change during a session.
|
||||
# The frontend already appends ?t={version} to bust cache on update.
|
||||
resp.headers["Cache-Control"] = "public, max-age=86400"
|
||||
|
||||
await resp.prepare(request)
|
||||
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(_CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
await resp.write(chunk)
|
||||
except (ConnectionResetError, ConnectionAbortedError):
|
||||
# Client disconnected during streaming — expected when scrolling
|
||||
# rapidly through a library with animated previews.
|
||||
pass
|
||||
except OSError as exc:
|
||||
logger.debug("I/O error streaming preview %s: %s", path, exc)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
__all__ = ["PreviewHandler"]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -22,11 +22,17 @@ class RouteDefinition:
|
||||
MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/settings", "get_settings"),
|
||||
RouteDefinition("POST", "/api/lm/settings", "update_settings"),
|
||||
RouteDefinition("GET", "/api/lm/doctor/diagnostics", "get_doctor_diagnostics"),
|
||||
RouteDefinition("POST", "/api/lm/doctor/repair-cache", "repair_doctor_cache"),
|
||||
RouteDefinition("POST", "/api/lm/doctor/resolve-filename-conflicts", "resolve_doctor_filename_conflicts"),
|
||||
RouteDefinition("POST", "/api/lm/doctor/export-bundle", "export_doctor_bundle"),
|
||||
RouteDefinition("GET", "/api/lm/priority-tags", "get_priority_tags"),
|
||||
RouteDefinition("GET", "/api/lm/settings/libraries", "get_settings_libraries"),
|
||||
RouteDefinition("POST", "/api/lm/settings/libraries/activate", "activate_library"),
|
||||
RouteDefinition("GET", "/api/lm/health-check", "health_check"),
|
||||
RouteDefinition("GET", "/api/lm/supporters", "get_supporters"),
|
||||
RouteDefinition("GET", "/api/lm/wildcards/search", "search_wildcards"),
|
||||
RouteDefinition("POST", "/api/lm/wildcards/open-location", "open_wildcards_location"),
|
||||
RouteDefinition("POST", "/api/lm/open-file-location", "open_file_location"),
|
||||
RouteDefinition("POST", "/api/lm/update-usage-stats", "update_usage_stats"),
|
||||
RouteDefinition("GET", "/api/lm/get-usage-stats", "get_usage_stats"),
|
||||
@@ -37,13 +43,74 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("POST", "/api/lm/update-node-widget", "update_node_widget"),
|
||||
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
|
||||
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
|
||||
RouteDefinition("GET", "/api/lm/check-models-exist", "check_models_exist"),
|
||||
RouteDefinition(
|
||||
"GET",
|
||||
"/api/lm/model-version-download-status",
|
||||
"get_model_version_download_status",
|
||||
),
|
||||
RouteDefinition(
|
||||
"POST",
|
||||
"/api/lm/model-version-download-status",
|
||||
"set_model_version_download_status",
|
||||
),
|
||||
RouteDefinition(
|
||||
"GET",
|
||||
"/api/lm/set-model-version-download-status",
|
||||
"set_model_version_download_status",
|
||||
),
|
||||
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),
|
||||
RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"),
|
||||
RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"),
|
||||
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),
|
||||
RouteDefinition("GET", "/api/lm/model-versions-status", "get_model_versions_status"),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/download-metadata-archive", "download_metadata_archive"
|
||||
),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"
|
||||
),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"
|
||||
),
|
||||
RouteDefinition("GET", "/api/lm/backup/status", "get_backup_status"),
|
||||
RouteDefinition("POST", "/api/lm/backup/export", "export_backup"),
|
||||
RouteDefinition("POST", "/api/lm/backup/import", "import_backup"),
|
||||
RouteDefinition("POST", "/api/lm/backup/open-location", "open_backup_location"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/model-versions-status", "get_model_versions_status"
|
||||
),
|
||||
RouteDefinition("POST", "/api/lm/settings/open-location", "open_settings_location"),
|
||||
RouteDefinition("GET", "/api/lm/custom-words/search", "search_custom_words"),
|
||||
RouteDefinition("GET", "/api/lm/example-workflows", "get_example_workflows"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/example-workflows/{filename}", "get_example_workflow"
|
||||
),
|
||||
# Base model management routes
|
||||
RouteDefinition("GET", "/api/lm/base-models", "get_base_models"),
|
||||
RouteDefinition("POST", "/api/lm/base-models/refresh", "refresh_base_models"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/base-models/categories", "get_base_model_categories"
|
||||
),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/base-models/cache-status", "get_base_model_cache_status"
|
||||
),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/delete-model-version", "delete_model_version"
|
||||
),
|
||||
# Hugging Face model endpoints
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/hf-repo-files", "get_hf_repo_files"
|
||||
),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/download-hf-model", "download_hf_model"
|
||||
),
|
||||
# Agent skill endpoints
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/agent/skills", "get_agent_skills"
|
||||
),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/agent/execute/{skill_name}", "execute_agent_skill"
|
||||
),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/agent/cancel", "cancel_agent_skill"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -67,7 +134,11 @@ class MiscRouteRegistrar:
|
||||
definitions: Iterable[RouteDefinition] = MISC_ROUTE_DEFINITIONS,
|
||||
) -> None:
|
||||
for definition in definitions:
|
||||
self._bind(definition.method, definition.path, handler_lookup[definition.handler_name])
|
||||
self._bind(
|
||||
definition.method,
|
||||
definition.path,
|
||||
handler_lookup[definition.handler_name],
|
||||
)
|
||||
|
||||
def _bind(self, method: str, path: str, handler: Callable) -> None:
|
||||
add_method_name = self._METHOD_MAP[method.upper()]
|
||||
|
||||
@@ -19,9 +19,12 @@ from ..services.downloader import get_downloader
|
||||
from ..utils.usage_stats import UsageStats
|
||||
from .handlers.misc_handlers import (
|
||||
CustomWordsHandler,
|
||||
DoctorHandler,
|
||||
ExampleWorkflowsHandler,
|
||||
FileSystemHandler,
|
||||
HealthCheckHandler,
|
||||
LoraCodeHandler,
|
||||
BackupHandler,
|
||||
MetadataArchiveHandler,
|
||||
MiscHandlerSet,
|
||||
ModelExampleFilesHandler,
|
||||
@@ -32,15 +35,20 @@ from .handlers.misc_handlers import (
|
||||
SupportersHandler,
|
||||
TrainedWordsHandler,
|
||||
UsageStatsHandler,
|
||||
WildcardsHandler,
|
||||
build_service_registry_adapter,
|
||||
)
|
||||
from .handlers.base_model_handlers import BaseModelHandlerSet
|
||||
from .handlers.hf_handlers import HfHandler
|
||||
from .handlers.agent_handlers import AgentHandler
|
||||
from .misc_route_registrar import MiscRouteRegistrar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get(
|
||||
"HF_HUB_DISABLE_TELEMETRY", "0"
|
||||
) == "0"
|
||||
standalone_mode = (
|
||||
os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
|
||||
or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
)
|
||||
|
||||
|
||||
class MiscRoutes:
|
||||
@@ -75,7 +83,9 @@ class MiscRoutes:
|
||||
self._node_registry = node_registry or NodeRegistry()
|
||||
self._standalone_mode = standalone_mode_flag
|
||||
|
||||
self._handler_mapping: Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]] | None = None
|
||||
self._handler_mapping: (
|
||||
Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]] | None
|
||||
) = None
|
||||
|
||||
@staticmethod
|
||||
def setup_routes(app: web.Application) -> None:
|
||||
@@ -87,7 +97,9 @@ class MiscRoutes:
|
||||
registrar = self._registrar_factory(app)
|
||||
registrar.register_routes(self._ensure_handler_mapping())
|
||||
|
||||
def _ensure_handler_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
|
||||
def _ensure_handler_mapping(
|
||||
self,
|
||||
) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
|
||||
if self._handler_mapping is None:
|
||||
handler_set = self._create_handler_set()
|
||||
self._handler_mapping = handler_set.to_route_mapping()
|
||||
@@ -109,6 +121,7 @@ class MiscRoutes:
|
||||
settings_service=self._settings,
|
||||
metadata_provider_updater=self._metadata_provider_updater,
|
||||
)
|
||||
backup = BackupHandler()
|
||||
filesystem = FileSystemHandler(settings_service=self._settings)
|
||||
node_registry_handler = NodeRegistryHandler(
|
||||
node_registry=self._node_registry,
|
||||
@@ -120,7 +133,13 @@ class MiscRoutes:
|
||||
metadata_provider_factory=self._metadata_provider_factory,
|
||||
)
|
||||
custom_words = CustomWordsHandler()
|
||||
wildcards = WildcardsHandler()
|
||||
supporters = SupportersHandler()
|
||||
doctor = DoctorHandler(settings_service=self._settings)
|
||||
example_workflows = ExampleWorkflowsHandler()
|
||||
base_model = BaseModelHandlerSet()
|
||||
hf_handler = HfHandler()
|
||||
agent_handler = AgentHandler()
|
||||
|
||||
return self._handler_set_factory(
|
||||
health=health,
|
||||
@@ -132,9 +151,16 @@ class MiscRoutes:
|
||||
node_registry=node_registry_handler,
|
||||
model_library=model_library,
|
||||
metadata_archive=metadata_archive,
|
||||
backup=backup,
|
||||
filesystem=filesystem,
|
||||
custom_words=custom_words,
|
||||
wildcards=wildcards,
|
||||
supporters=supporters,
|
||||
doctor=doctor,
|
||||
example_workflows=example_workflows,
|
||||
base_model=base_model,
|
||||
hf_handler=hf_handler,
|
||||
agent_handler=agent_handler,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -22,8 +22,10 @@ class RouteDefinition:
|
||||
|
||||
COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/list", "get_models"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/excluded", "get_excluded_models"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/delete", "delete_model"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/exclude", "exclude_model"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/unexclude", "unexclude_model"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/fetch-civitai", "fetch_civitai"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/fetch-all-civitai", "fetch_all_civitai"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/relink-civitai", "relink_civitai"),
|
||||
@@ -99,11 +101,46 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
|
||||
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
|
||||
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),
|
||||
RouteDefinition("GET", "/api/lm/skip-download", "skip_download_get"),
|
||||
RouteDefinition("GET", "/api/lm/pause-download", "pause_download_get"),
|
||||
RouteDefinition("GET", "/api/lm/resume-download", "resume_download_get"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/download-progress/{download_id}", "get_download_progress"
|
||||
),
|
||||
RouteDefinition("GET", "/api/lm/downloads/queue", "get_download_queue"),
|
||||
RouteDefinition("GET", "/api/lm/downloads/queue/add", "add_to_download_queue"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/downloads/queue/remove", "remove_from_download_queue"
|
||||
),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/downloads/queue/move-to-top", "move_queue_item_to_top"
|
||||
),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/downloads/queue/move-to-end", "move_queue_item_to_end"
|
||||
),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/downloads/queue/clear", "clear_download_queue"
|
||||
),
|
||||
RouteDefinition("GET", "/api/lm/downloads/history", "get_download_history"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/downloads/history/clear", "clear_download_history"
|
||||
),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/downloads/history/delete", "delete_download_history_item"
|
||||
),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/downloads/history/retry", "retry_download_from_history"
|
||||
),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/downloads/history/retry-all", "retry_all_failed_downloads"
|
||||
),
|
||||
RouteDefinition("GET", "/api/lm/downloads/stats", "get_download_stats"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/downloads/queue/complete", "complete_download_in_queue"
|
||||
),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/downloads/queue/status", "update_download_queue_status"
|
||||
),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/cancel-task", "cancel_task"),
|
||||
RouteDefinition("GET", "/{prefix}", "handle_models_page"),
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Route registrar for recipe endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
@@ -22,7 +23,9 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/import-remote", "import_remote_recipe"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"
|
||||
),
|
||||
RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"),
|
||||
RouteDefinition("DELETE", "/api/lm/recipe/{recipe_id}", "delete_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/top-tags", "get_top_tags"),
|
||||
@@ -30,9 +33,13 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/recipes/roots", "get_roots"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/folders", "get_folders"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/folder-tree", "get_folder_tree"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/unified-folder-tree", "get_unified_folder_tree"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/recipes/unified-folder-tree", "get_unified_folder_tree"
|
||||
),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share", "share_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"
|
||||
),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/syntax", "get_recipe_syntax"),
|
||||
RouteDefinition("PUT", "/api/lm/recipe/{recipe_id}/update", "update_recipe"),
|
||||
RouteDefinition("POST", "/api/lm/recipe/move", "move_recipe"),
|
||||
@@ -40,13 +47,40 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("POST", "/api/lm/recipe/lora/reconnect", "reconnect_lora"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/find-duplicates", "find_duplicates"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/bulk-delete", "bulk_delete"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"
|
||||
),
|
||||
RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/recipes/for-checkpoint", "get_recipes_for_checkpoint"
|
||||
),
|
||||
RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/repair", "repair_recipes"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/cancel-repair", "cancel_repair"),
|
||||
RouteDefinition("POST", "/api/lm/recipe/{recipe_id}/repair", "repair_recipe"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/repair-bulk", "repair_recipes_bulk"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/repair-progress", "get_repair_progress"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/batch-import/start", "start_batch_import"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/recipes/batch-import/progress", "get_batch_import_progress"
|
||||
),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/recipes/batch-import/cancel", "cancel_batch_import"
|
||||
),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/recipes/batch-import/directory", "start_directory_import"
|
||||
),
|
||||
RouteDefinition("POST", "/api/lm/recipes/browse-directory", "browse_directory"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/recipes/check-image-exists", "check_image_exists"
|
||||
),
|
||||
RouteDefinition("GET", "/api/lm/recipes/import-from-url", "import_from_url"),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/recipes/create-from-example", "create_from_example"
|
||||
),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/recipe/{recipe_id}/reimport", "reimport_recipe"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -63,7 +97,9 @@ class RecipeRouteRegistrar:
|
||||
def __init__(self, app: web.Application) -> None:
|
||||
self._app = app
|
||||
|
||||
def register_routes(self, handler_lookup: Mapping[str, Callable[[web.Request], object]]) -> None:
|
||||
def register_routes(
|
||||
self, handler_lookup: Mapping[str, Callable[[web.Request], object]]
|
||||
) -> None:
|
||||
for definition in ROUTE_DEFINITIONS:
|
||||
handler = handler_lookup[definition.handler_name]
|
||||
self._bind_route(definition.method, definition.path, handler)
|
||||
|
||||
@@ -11,6 +11,8 @@ from ..config import config
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from ..services.server_i18n import server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..services.model_query import normalize_sub_type, resolve_sub_type
|
||||
from ..utils.constants import VALID_LORA_SUB_TYPES, VALID_CHECKPOINT_SUB_TYPES
|
||||
from ..utils.usage_stats import UsageStats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -140,6 +142,21 @@ class StatsRoutes:
|
||||
# Get usage statistics
|
||||
usage_data = await self.usage_stats.get_stats()
|
||||
|
||||
# CivitAI model type distribution across all model types
|
||||
# Use the same logic as the filter panel: normalize_sub_type(resolve_sub_type(entry))
|
||||
# with sub-type validation per model type
|
||||
model_types_counter: Counter[str] = Counter()
|
||||
for entry in lora_cache.raw_data:
|
||||
ntype = normalize_sub_type(resolve_sub_type(entry))
|
||||
if ntype and ntype in VALID_LORA_SUB_TYPES:
|
||||
model_types_counter[ntype] += 1
|
||||
for entry in checkpoint_cache.raw_data:
|
||||
ntype = normalize_sub_type(resolve_sub_type(entry))
|
||||
if ntype and ntype in VALID_CHECKPOINT_SUB_TYPES:
|
||||
model_types_counter[ntype] += 1
|
||||
# Embeddings: always count as "embedding" regardless of CivitAI sub-type
|
||||
model_types_counter['embedding'] = len(embedding_cache.raw_data)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
@@ -154,7 +171,8 @@ class StatsRoutes:
|
||||
'total_generations': usage_data.get('total_executions', 0),
|
||||
'unused_loras': self._count_unused_models(lora_cache.raw_data, usage_data.get('loras', {})),
|
||||
'unused_checkpoints': self._count_unused_models(checkpoint_cache.raw_data, usage_data.get('checkpoints', {})),
|
||||
'unused_embeddings': self._count_unused_models(embedding_cache.raw_data, usage_data.get('embeddings', {}))
|
||||
'unused_embeddings': self._count_unused_models(embedding_cache.raw_data, usage_data.get('embeddings', {})),
|
||||
'model_types_distribution': dict(model_types_counter.most_common())
|
||||
}
|
||||
})
|
||||
|
||||
@@ -459,9 +477,12 @@ class StatsRoutes:
|
||||
if unused_lora_percent > 50:
|
||||
insights.append({
|
||||
'type': 'warning',
|
||||
'title': 'High Number of Unused LoRAs',
|
||||
'description': f'{unused_lora_percent:.1f}% of your LoRAs ({unused_loras}/{total_loras}) have never been used.',
|
||||
'suggestion': 'Consider organizing or archiving unused models to free up storage space.'
|
||||
'key': 'insights.unusedLoras.high',
|
||||
'params': {
|
||||
'percent': f'{unused_lora_percent:.1f}',
|
||||
'count': str(unused_loras),
|
||||
'total': str(total_loras)
|
||||
}
|
||||
})
|
||||
|
||||
if total_checkpoints > 0:
|
||||
@@ -469,9 +490,12 @@ class StatsRoutes:
|
||||
if unused_checkpoint_percent > 30:
|
||||
insights.append({
|
||||
'type': 'warning',
|
||||
'title': 'Unused Checkpoints Detected',
|
||||
'description': f'{unused_checkpoint_percent:.1f}% of your checkpoints ({unused_checkpoints}/{total_checkpoints}) have never been used.',
|
||||
'suggestion': 'Review and consider removing checkpoints you no longer need.'
|
||||
'key': 'insights.unusedCheckpoints.detected',
|
||||
'params': {
|
||||
'percent': f'{unused_checkpoint_percent:.1f}',
|
||||
'count': str(unused_checkpoints),
|
||||
'total': str(total_checkpoints)
|
||||
}
|
||||
})
|
||||
|
||||
if total_embeddings > 0:
|
||||
@@ -479,9 +503,12 @@ class StatsRoutes:
|
||||
if unused_embedding_percent > 50:
|
||||
insights.append({
|
||||
'type': 'warning',
|
||||
'title': 'High Number of Unused Embeddings',
|
||||
'description': f'{unused_embedding_percent:.1f}% of your embeddings ({unused_embeddings}/{total_embeddings}) have never been used.',
|
||||
'suggestion': 'Consider organizing or archiving unused embeddings to optimize your collection.'
|
||||
'key': 'insights.unusedEmbeddings.high',
|
||||
'params': {
|
||||
'percent': f'{unused_embedding_percent:.1f}',
|
||||
'count': str(unused_embeddings),
|
||||
'total': str(total_embeddings)
|
||||
}
|
||||
})
|
||||
|
||||
# Storage insights
|
||||
@@ -492,18 +519,20 @@ class StatsRoutes:
|
||||
if total_size > 100 * 1024 * 1024 * 1024: # 100GB
|
||||
insights.append({
|
||||
'type': 'info',
|
||||
'title': 'Large Collection Detected',
|
||||
'description': f'Your model collection is using {self._format_size(total_size)} of storage.',
|
||||
'suggestion': 'Consider using external storage or cloud solutions for better organization.'
|
||||
'key': 'insights.collection.large',
|
||||
'params': {
|
||||
'size': self._format_size(total_size)
|
||||
}
|
||||
})
|
||||
|
||||
# Recent activity insight
|
||||
if usage_data.get('total_executions', 0) > 100:
|
||||
insights.append({
|
||||
'type': 'success',
|
||||
'title': 'Active User',
|
||||
'description': f'You\'ve completed {usage_data["total_executions"]} generations so far!',
|
||||
'suggestion': 'Keep exploring and creating amazing content with your models.'
|
||||
'key': 'insights.activity.active',
|
||||
'params': {
|
||||
'count': str(usage_data['total_executions'])
|
||||
}
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
import logging
|
||||
import toml
|
||||
import git
|
||||
import zipfile
|
||||
import shutil
|
||||
import tempfile
|
||||
@@ -11,11 +10,33 @@ from typing import Dict, List
|
||||
|
||||
from ..utils.settings_paths import ensure_settings_file
|
||||
from ..services.downloader import get_downloader
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NETWORK_EXCEPTIONS = (ClientError, OSError, asyncio.TimeoutError)
|
||||
|
||||
# User-managed directories that live inside the plugin folder (portable
|
||||
# mode) and must survive a Git-based update. ``git clean -fd`` would
|
||||
# otherwise delete them because they are untracked and, in released tags,
|
||||
# not listed in ``.gitignore``. ``-e`` excludes a path from cleaning
|
||||
# regardless of whether it is ignored.
|
||||
_PRESERVE_DIRS = ('settings.json', 'civitai', 'wildcards', 'backups', 'stats', 'logs', 'cache', 'model_cache')
|
||||
|
||||
|
||||
def _clean_excludes() -> List[str]:
|
||||
"""Build the ``-e`` arguments for ``git clean`` from :data:`_PRESERVE_DIRS`."""
|
||||
excludes: List[str] = []
|
||||
for name in _PRESERVE_DIRS:
|
||||
excludes.append('-e')
|
||||
excludes.append(name)
|
||||
# For directories, also exclude nested matches explicitly
|
||||
# (``-e dir`` alone matches the dir entry; ``-e dir/**`` guards
|
||||
# contents under all git versions as defense-in-depth).
|
||||
excludes.append('-e')
|
||||
excludes.append(f'{name}/**')
|
||||
return excludes
|
||||
|
||||
|
||||
class UpdateRoutes:
|
||||
"""Routes for handling plugin update checks"""
|
||||
@@ -212,8 +233,19 @@ class UpdateRoutes:
|
||||
|
||||
zip_path = tmp_zip_path
|
||||
|
||||
# Skip both settings.json, civitai and model cache folder
|
||||
UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json', 'civitai', 'model_cache'])
|
||||
# Close the downloaded-versions SQLite connection before cleaning,
|
||||
# so that shutil.rmtree() does not fail on Windows (the process
|
||||
# cannot delete a file with an outstanding open handle).
|
||||
try:
|
||||
history_svc = ServiceRegistry._services.get("downloaded_version_history_service")
|
||||
if history_svc is not None:
|
||||
history_svc.close()
|
||||
logger.info("Closed downloaded-version history database connection")
|
||||
except Exception:
|
||||
logger.debug("Could not close downloaded-version history database", exc_info=True)
|
||||
|
||||
# Skip settings.json, civitai, model cache and runtime cache folders
|
||||
UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json', 'civitai', 'model_cache', 'cache', 'wildcards', 'backups', 'stats'])
|
||||
|
||||
# Extract ZIP to temp dir
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@@ -222,16 +254,17 @@ class UpdateRoutes:
|
||||
# Find extracted folder (GitHub ZIP contains a root folder)
|
||||
extracted_root = next(os.scandir(tmp_dir)).path
|
||||
|
||||
# Copy files, skipping settings.json and civitai folder
|
||||
# Copy files, skipping user data that should be preserved
|
||||
skip_items = {'settings.json', 'civitai', 'wildcards', 'backups', 'stats'}
|
||||
for item in os.listdir(extracted_root):
|
||||
if item == 'settings.json' or item == 'civitai':
|
||||
if item in skip_items:
|
||||
continue
|
||||
src = os.path.join(extracted_root, item)
|
||||
dst = os.path.join(plugin_root, item)
|
||||
if os.path.isdir(src):
|
||||
if os.path.exists(dst):
|
||||
shutil.rmtree(dst)
|
||||
shutil.copytree(src, dst, ignore=shutil.ignore_patterns('settings.json', 'civitai'))
|
||||
shutil.copytree(src, dst, ignore=shutil.ignore_patterns(*skip_items))
|
||||
else:
|
||||
shutil.copy2(src, dst)
|
||||
|
||||
@@ -239,15 +272,17 @@ class UpdateRoutes:
|
||||
# for ComfyUI Manager to work properly
|
||||
tracking_info_file = os.path.join(plugin_root, '.tracking')
|
||||
tracking_files = []
|
||||
skip_tracked = {'civitai', 'wildcards', 'backups', 'stats'}
|
||||
for root, dirs, files in os.walk(extracted_root):
|
||||
# Skip civitai folder and its contents
|
||||
# Skip user data directories and their contents
|
||||
rel_root = os.path.relpath(root, extracted_root)
|
||||
if rel_root == 'civitai' or rel_root.startswith('civitai' + os.sep):
|
||||
top_dir = rel_root.split(os.sep)[0] if rel_root != '.' else ''
|
||||
if top_dir in skip_tracked:
|
||||
continue
|
||||
for file in files:
|
||||
rel_path = os.path.relpath(os.path.join(root, file), extracted_root)
|
||||
# Skip settings.json and any file under civitai
|
||||
if rel_path == 'settings.json' or rel_path.startswith('civitai' + os.sep):
|
||||
# Skip settings.json and any file under user data dirs
|
||||
if rel_path == 'settings.json' or rel_path.split(os.sep)[0] in skip_tracked:
|
||||
continue
|
||||
tracking_files.append(rel_path.replace("\\", "/"))
|
||||
with open(tracking_info_file, "w", encoding='utf-8') as file:
|
||||
@@ -342,6 +377,17 @@ class UpdateRoutes:
|
||||
Returns:
|
||||
tuple: (success, new_version)
|
||||
"""
|
||||
try:
|
||||
import git
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"GitPython is not available: the git executable was not found in PATH. "
|
||||
"Install git or set $GIT_PYTHON_GIT_EXECUTABLE to the git binary path."
|
||||
)
|
||||
return False, ""
|
||||
|
||||
clean_excludes = _clean_excludes()
|
||||
|
||||
try:
|
||||
# Open the Git repository
|
||||
repo = git.Repo(plugin_root)
|
||||
@@ -353,8 +399,9 @@ class UpdateRoutes:
|
||||
if nightly:
|
||||
# Reset to discard any local changes
|
||||
repo.git.reset('--hard')
|
||||
# Clean untracked files
|
||||
repo.git.clean('-fd')
|
||||
# Clean untracked files, but preserve user-managed directories
|
||||
# (wildcards, backups, stats, civitai, caches, settings.json).
|
||||
repo.git.clean('-fd', *clean_excludes)
|
||||
|
||||
# Switch to main branch and pull latest
|
||||
main_branch = 'main'
|
||||
@@ -371,8 +418,9 @@ class UpdateRoutes:
|
||||
else:
|
||||
# Reset to discard any local changes
|
||||
repo.git.reset('--hard')
|
||||
# Clean untracked files
|
||||
repo.git.clean('-fd')
|
||||
# Clean untracked files, but preserve user-managed directories
|
||||
# (wildcards, backups, stats, civitai, caches, settings.json).
|
||||
repo.git.clean('-fd', *clean_excludes)
|
||||
|
||||
# Get latest release tag
|
||||
tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime, reverse=True)
|
||||
@@ -438,6 +486,7 @@ class UpdateRoutes:
|
||||
if not os.path.exists(os.path.join(plugin_root, '.git')):
|
||||
return git_info
|
||||
|
||||
import git
|
||||
repo = git.Repo(plugin_root)
|
||||
commit = repo.head.commit
|
||||
git_info['commit_hash'] = commit.hexsha
|
||||
|
||||
23
py/services/agent/__init__.py
Normal file
23
py/services/agent/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Agent-powered skill system for LoRA Manager.
|
||||
|
||||
This package provides the orchestration layer for LLM/agent-powered features.
|
||||
Skills define *what* to do (prompt template). The :class:`AgentService`
|
||||
handles *how* (LLM calls, context gathering, validation, progress).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .skill_definition import SkillDefinition, SkillPermissions
|
||||
from .skill_registry import SkillRegistry
|
||||
from .agent_service import AgentService, AgentProgressReporter, SkillResult
|
||||
from .post_processor import PostProcessor
|
||||
|
||||
__all__ = [
|
||||
"AgentProgressReporter",
|
||||
"AgentService",
|
||||
"PostProcessor",
|
||||
"SkillDefinition",
|
||||
"SkillPermissions",
|
||||
"SkillRegistry",
|
||||
"SkillResult",
|
||||
]
|
||||
413
py/services/agent/agent_service.py
Normal file
413
py/services/agent/agent_service.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""Agent orchestration service.
|
||||
|
||||
The :class:`AgentService` coordinates skill execution:
|
||||
|
||||
1. Look up the skill in :class:`SkillRegistry`
|
||||
2. Validate input against the skill's ``input_schema``
|
||||
3. Prepare context via :mod:`~py.agent_cli` (read metadata, list base models, fetch HF README)
|
||||
4. If ``llm_required``: call :class:`LLMService` with the rendered prompt
|
||||
5. Post-process via :class:`PostProcessor` (delegates I/O to :mod:`~py.agent_cli`)
|
||||
6. Broadcast progress and completion via :class:`WebSocketManager`
|
||||
|
||||
Skills define *what* to do (prompt template). The AgentService handles *how*
|
||||
(LLM calls, context gathering, validation, progress).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ..llm_service import LLMService
|
||||
from ..websocket_manager import ws_manager
|
||||
from .post_processor import PostProcessor
|
||||
from .skill_registry import SkillRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentProgressReporter:
|
||||
"""Protocol-compatible progress reporter backed by WebSocket broadcast."""
|
||||
|
||||
async def on_progress(self, payload: Dict[str, Any]) -> None:
|
||||
await ws_manager.broadcast(payload)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillResult:
|
||||
"""Outcome of a skill execution."""
|
||||
|
||||
success: bool
|
||||
updated_models: List[Dict[str, Any]] = field(default_factory=list)
|
||||
errors: List[str] = field(default_factory=list)
|
||||
summary: str = ""
|
||||
|
||||
|
||||
def _validate_schema(data: Any, schema: Dict[str, Any], path: str = "") -> List[str]:
|
||||
"""Minimal JSON schema validator.
|
||||
|
||||
Supports a subset of JSON Schema: ``type``, ``properties``, ``required``,
|
||||
``items``, ``enum``. Returns a list of error messages (empty = valid).
|
||||
"""
|
||||
|
||||
errors: List[str] = []
|
||||
if not schema:
|
||||
return errors
|
||||
|
||||
expected_type = schema.get("type")
|
||||
if expected_type:
|
||||
type_map = {
|
||||
"string": str,
|
||||
"number": (int, float),
|
||||
"integer": int,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict,
|
||||
"null": type(None),
|
||||
}
|
||||
expected_py = type_map.get(expected_type)
|
||||
if expected_py is not None and not isinstance(data, expected_py):
|
||||
errors.append(f"{path or 'root'}: expected {expected_type}, got {type(data).__name__}")
|
||||
return errors
|
||||
|
||||
if expected_type == "object" and isinstance(data, dict):
|
||||
properties = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
for req_key in required:
|
||||
if req_key not in data:
|
||||
errors.append(f"{path or 'root'}: missing required property '{req_key}'")
|
||||
for key, value in data.items():
|
||||
if key in properties:
|
||||
errors.extend(_validate_schema(value, properties[key], f"{path}.{key}"))
|
||||
|
||||
if expected_type == "array" and isinstance(data, list):
|
||||
items_schema = schema.get("items")
|
||||
if items_schema:
|
||||
for i, item in enumerate(data):
|
||||
errors.extend(_validate_schema(item, items_schema, f"{path}[{i}]"))
|
||||
|
||||
if "enum" in schema and data not in schema["enum"]:
|
||||
errors.append(f"{path or 'root'}: value '{data}' not in enum {schema['enum']}")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prompt template rendering
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _render_prompt(template: str, variables: Dict[str, Any]) -> str:
|
||||
"""Render a prompt template with ``{{variable}}`` placeholders.
|
||||
|
||||
Uses simple regex substitution — no Jinja2 dependency needed.
|
||||
"""
|
||||
|
||||
def replace(match: re.Match) -> str:
|
||||
key = match.group(1).strip()
|
||||
value = variables.get(key, "")
|
||||
if isinstance(value, (dict, list)):
|
||||
return json.dumps(value, ensure_ascii=False, indent=2)
|
||||
return str(value)
|
||||
|
||||
return re.sub(r"\{\{(\w+)\}\}", replace, template)
|
||||
|
||||
|
||||
class AgentService:
|
||||
"""Orchestrate agent skill execution.
|
||||
|
||||
Usage::
|
||||
|
||||
service = await AgentService.get_instance()
|
||||
result = await service.execute_skill(
|
||||
skill_name="enrich_hf_metadata",
|
||||
input_data={"model_paths": ["/path/to/model.safetensors"]},
|
||||
progress_callback=AgentProgressReporter(),
|
||||
)
|
||||
"""
|
||||
|
||||
_instance: Optional["AgentService"] = None
|
||||
_lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
skill_registry: Optional[SkillRegistry] = None,
|
||||
llm_service: Optional[LLMService] = None,
|
||||
) -> None:
|
||||
self._registry = skill_registry
|
||||
self._llm_service = llm_service
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> "AgentService":
|
||||
"""Return the lazily-initialised global ``AgentService``."""
|
||||
|
||||
if cls._instance is None:
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls(
|
||||
skill_registry=await SkillRegistry.get_instance(),
|
||||
llm_service=await LLMService.get_instance(),
|
||||
)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls) -> None:
|
||||
"""Reset the cached singleton — primarily for tests."""
|
||||
|
||||
cls._instance = None
|
||||
|
||||
async def _ensure_registry(self) -> SkillRegistry:
|
||||
if self._registry is None:
|
||||
self._registry = await SkillRegistry.get_instance()
|
||||
return self._registry
|
||||
|
||||
async def _ensure_llm(self) -> LLMService:
|
||||
if self._llm_service is None:
|
||||
self._llm_service = await LLMService.get_instance()
|
||||
return self._llm_service
|
||||
|
||||
async def list_skills(self) -> List[Dict[str, Any]]:
|
||||
"""Return a JSON-serialisable list of available skills."""
|
||||
|
||||
registry = await self._ensure_registry()
|
||||
return [
|
||||
{
|
||||
"name": s.name,
|
||||
"title": s.title,
|
||||
"description": s.description,
|
||||
"llm_required": s.llm_required,
|
||||
"model_type_filter": s.model_type_filter,
|
||||
}
|
||||
for s in registry.list_skills()
|
||||
]
|
||||
|
||||
async def execute_skill(
|
||||
self,
|
||||
*,
|
||||
skill_name: str,
|
||||
input_data: Dict[str, Any],
|
||||
progress_callback: Optional[AgentProgressReporter] = None,
|
||||
) -> SkillResult:
|
||||
"""Execute an agent skill.
|
||||
|
||||
Args:
|
||||
skill_name: Name of the skill to execute
|
||||
input_data: Input validated against the skill's ``input_schema``
|
||||
progress_callback: Optional WebSocket progress reporter
|
||||
|
||||
Returns:
|
||||
:class:`SkillResult` with success status and updated model info
|
||||
"""
|
||||
|
||||
registry = await self._ensure_registry()
|
||||
logger.info("execute_skill '%s': looking up skill", skill_name)
|
||||
skill = registry.get_skill(skill_name)
|
||||
if skill is None:
|
||||
return SkillResult(
|
||||
success=False,
|
||||
errors=[f"Skill not found: {skill_name}"],
|
||||
summary=f"Skill '{skill_name}' does not exist",
|
||||
)
|
||||
|
||||
input_errors = _validate_schema(input_data, skill.input_schema)
|
||||
if input_errors:
|
||||
return SkillResult(
|
||||
success=False,
|
||||
errors=input_errors,
|
||||
summary=f"Invalid input: {'; '.join(input_errors)}",
|
||||
)
|
||||
|
||||
model_paths = input_data.get("model_paths", [])
|
||||
if not model_paths:
|
||||
return SkillResult(
|
||||
success=False,
|
||||
errors=["No model_paths provided"],
|
||||
summary="No models to process",
|
||||
)
|
||||
|
||||
total = len(model_paths)
|
||||
processed = 0
|
||||
success_count = 0
|
||||
updated_models: List[Dict[str, Any]] = []
|
||||
errors: List[str] = []
|
||||
post_processor = PostProcessor()
|
||||
|
||||
logger.info("execute_skill '%s': starting with %d model(s)", skill_name, total)
|
||||
await self._emit_progress(
|
||||
progress_callback, skill_name, status="started",
|
||||
total=total, processed=0, success=0,
|
||||
)
|
||||
|
||||
llm = await self._ensure_llm()
|
||||
llm_configured = llm.is_configured() if skill.llm_required else True
|
||||
|
||||
for model_path in model_paths:
|
||||
logger.info(
|
||||
"execute_skill '%s': processing model %d/%d: %s",
|
||||
skill_name, processed + 1, total, model_path,
|
||||
)
|
||||
try:
|
||||
from ...agent_cli import read_metadata
|
||||
metadata = await read_metadata(model_path)
|
||||
|
||||
prompt_vars: Dict[str, Any] = {"model_path": model_path}
|
||||
if skill.llm_required and llm_configured:
|
||||
prompt_vars = await self._build_prompt_context(
|
||||
skill_name, model_path, metadata, registry, llm,
|
||||
)
|
||||
|
||||
llm_response: Optional[Dict[str, Any]] = None
|
||||
if skill.llm_required and llm_configured:
|
||||
prompt_template = registry.load_prompt(skill_name)
|
||||
rendered = _render_prompt(prompt_template, prompt_vars)
|
||||
logger.info(
|
||||
"execute_skill '%s': LLM call for %s (prompt=%d chars)",
|
||||
skill_name, model_path, len(rendered),
|
||||
)
|
||||
llm_response = await llm.chat_completion_json(
|
||||
system_prompt=prompt_vars.get(
|
||||
"system_prompt",
|
||||
"You are a helpful assistant that extracts structured metadata.",
|
||||
),
|
||||
user_prompt=rendered,
|
||||
)
|
||||
|
||||
model_result = await post_processor.process(
|
||||
skill_name=skill_name,
|
||||
model_path=model_path,
|
||||
llm_output=llm_response or {},
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
if model_result.get("success", True):
|
||||
success_count += 1
|
||||
uf = model_result.get("updated_fields", [])
|
||||
if uf:
|
||||
updated_models.append({"path": model_path, "updated_fields": uf})
|
||||
else:
|
||||
errors.extend(
|
||||
model_result.get("errors", [model_result.get("error", "Unknown error")])
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Skill %s failed for %s: %s", skill_name, model_path, exc)
|
||||
errors.append(f"{model_path}: {exc}")
|
||||
|
||||
processed += 1
|
||||
await self._emit_progress(
|
||||
progress_callback, skill_name, status="processing",
|
||||
total=total, processed=processed, success=success_count,
|
||||
current_path=model_path,
|
||||
)
|
||||
|
||||
result = SkillResult(
|
||||
success=success_count > 0,
|
||||
updated_models=updated_models,
|
||||
errors=errors,
|
||||
summary=f"Processed {processed}/{total} models, {success_count} succeeded",
|
||||
)
|
||||
|
||||
logger.info("execute_skill '%s': done — %s", skill_name, result.summary)
|
||||
await self._emit_progress(
|
||||
progress_callback, skill_name, status="completed",
|
||||
total=total, processed=processed, success=success_count,
|
||||
updated_models=updated_models, errors=errors, summary=result.summary,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _build_prompt_context(
|
||||
self,
|
||||
skill_name: str,
|
||||
model_path: str,
|
||||
metadata: Dict[str, Any],
|
||||
registry: SkillRegistry,
|
||||
llm: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Gather variables for the skill's prompt template.
|
||||
|
||||
Reads metadata, fetches the HF README (if applicable), lists available
|
||||
base models, and returns a dict that maps to ``{{variable}}``
|
||||
placeholders in ``prompt.md``.
|
||||
"""
|
||||
from ...agent_cli import list_base_models
|
||||
|
||||
context: Dict[str, Any] = {
|
||||
"model_path": model_path,
|
||||
"hf_url": "",
|
||||
"repo": "",
|
||||
"readme_content": "",
|
||||
"current_metadata": {},
|
||||
"base_models": [],
|
||||
}
|
||||
|
||||
context["current_metadata"] = {
|
||||
"file_name": metadata.get("file_name", ""),
|
||||
"base_model": metadata.get("base_model", ""),
|
||||
"tags": metadata.get("tags", []),
|
||||
"modelDescription": metadata.get("modelDescription", ""),
|
||||
"trainedWords": metadata.get("trainedWords", []),
|
||||
"sha256": (metadata.get("sha256") or "")[:16] + "..." if metadata.get("sha256") else "",
|
||||
"size": metadata.get("size", 0),
|
||||
}
|
||||
|
||||
hf_url = metadata.get("hf_url", "")
|
||||
context["hf_url"] = hf_url
|
||||
repo = self._extract_repo_from_url(hf_url) if hf_url else ""
|
||||
context["repo"] = repo or ""
|
||||
if repo:
|
||||
readme = await self._fetch_readme(repo)
|
||||
context["readme_content"] = readme[:8000] if readme else "(README not available)"
|
||||
|
||||
try:
|
||||
context["base_models"] = await list_base_models()
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to list base models: %s", exc)
|
||||
|
||||
return context
|
||||
|
||||
@staticmethod
|
||||
def _extract_repo_from_url(hf_url: str) -> Optional[str]:
|
||||
"""Extract ``user/repo`` from a HuggingFace URL."""
|
||||
if not hf_url:
|
||||
return None
|
||||
m = re.match(r"https?://huggingface\.co/([^/]+/[^/]+)", hf_url)
|
||||
return m.group(1) if m else None
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_readme(repo: str) -> str:
|
||||
"""Fetch README.md from HuggingFace (tries ``main``, then ``master``)."""
|
||||
async with aiohttp.ClientSession(
|
||||
headers={"User-Agent": "ComfyUI-LoRA-Manager/1.0"},
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
) as session:
|
||||
for branch in ("main", "master"):
|
||||
url = f"https://huggingface.co/{repo}/raw/{branch}/README.md"
|
||||
try:
|
||||
async with session.get(url) as resp:
|
||||
if resp.status == 200:
|
||||
return await resp.text()
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to fetch README from %s: %s", url, exc)
|
||||
return ""
|
||||
|
||||
async def _emit_progress(
|
||||
self,
|
||||
callback: Optional[AgentProgressReporter],
|
||||
skill_name: str,
|
||||
*,
|
||||
status: str,
|
||||
**extra: Any,
|
||||
) -> None:
|
||||
"""Send a progress update via WebSocket (if callback is set)."""
|
||||
payload: Dict[str, Any] = {"type": "agent_progress", "skill": skill_name, "status": status}
|
||||
payload.update(extra)
|
||||
if callback is not None:
|
||||
await callback.on_progress(payload)
|
||||
168
py/services/agent/post_processor.py
Normal file
168
py/services/agent/post_processor.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Post-processing engine for agent skill outputs.
|
||||
|
||||
The :class:`PostProcessor` takes the LLM's structured JSON output and applies
|
||||
it to a model's on-disk metadata via the :mod:`~py.agent_cli` functions.
|
||||
|
||||
It handles all the skill-specific business logic — conditions, transformations,
|
||||
and orchestration of multiple side-effects (write metadata, download preview,
|
||||
refresh cache). All actual I/O is delegated to :mod:`~py.agent_cli`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostProcessor:
|
||||
"""Deterministic post-processor for agent skill outputs.
|
||||
|
||||
Usage (called by :class:`~py.services.agent.agent_service.AgentService`)::
|
||||
|
||||
processor = PostProcessor()
|
||||
result = await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/path/to/model.safetensors",
|
||||
llm_output={...},
|
||||
metadata={...}, # from agent_cli.read_metadata()
|
||||
)
|
||||
"""
|
||||
|
||||
async def process(
|
||||
self,
|
||||
*,
|
||||
skill_name: str,
|
||||
model_path: str,
|
||||
llm_output: Dict[str, Any],
|
||||
metadata: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Route *llm_output* to the correct skill post-processor.
|
||||
|
||||
Returns a dict with keys ``success`` (bool), ``updated_fields`` (list),
|
||||
``preview_downloaded`` (bool), and ``errors`` (list).
|
||||
"""
|
||||
if skill_name == "enrich_hf_metadata":
|
||||
return await self._process_enrich_hf_metadata(
|
||||
model_path, llm_output, metadata,
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"updated_fields": [],
|
||||
"errors": [f"No post-processor registered for skill: {skill_name}"],
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# enrich_hf_metadata
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _process_enrich_hf_metadata(
|
||||
self,
|
||||
model_path: str,
|
||||
llm_output: Dict[str, Any],
|
||||
metadata: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
from ...agent_cli import (
|
||||
apply_metadata_updates,
|
||||
download_preview,
|
||||
refresh_cache,
|
||||
)
|
||||
|
||||
updated_fields: List[str] = []
|
||||
preview_downloaded = False
|
||||
|
||||
# -- Determine whether this is an HF-sourced model -----------------
|
||||
is_hf_model = not metadata.get("from_civitai", True)
|
||||
|
||||
# -- Collect updates -----------------------------------------------
|
||||
updates: Dict[str, Any] = {}
|
||||
|
||||
# base_model
|
||||
new_base = (llm_output.get("base_model") or "").strip()
|
||||
current_base = metadata.get("base_model", "") or ""
|
||||
if new_base and self._should_overwrite(current_base, is_hf_model):
|
||||
updates["base_model"] = new_base
|
||||
|
||||
# trainedWords / trigger words
|
||||
new_triggers = llm_output.get("trigger_words", [])
|
||||
if isinstance(new_triggers, list):
|
||||
cleaned = [t.strip() for t in new_triggers if t.strip()]
|
||||
if cleaned:
|
||||
current_triggers = metadata.get("trainedWords") or []
|
||||
if self._should_overwrite_list(current_triggers, is_hf_model):
|
||||
updates["trainedWords"] = cleaned
|
||||
|
||||
# modelDescription
|
||||
new_desc = (llm_output.get("description") or "").strip()
|
||||
if new_desc:
|
||||
current_desc = metadata.get("modelDescription", "") or ""
|
||||
if self._should_overwrite(current_desc, is_hf_model):
|
||||
updates["modelDescription"] = new_desc
|
||||
|
||||
# tags — merge with existing, deduplicate (case-insensitive)
|
||||
new_tags = llm_output.get("tags", [])
|
||||
if isinstance(new_tags, list) and new_tags:
|
||||
existing_tags = metadata.get("tags") or []
|
||||
merged = self._merge_tags(existing_tags, new_tags)
|
||||
if len(merged) > len(existing_tags) or is_hf_model:
|
||||
updates["tags"] = merged
|
||||
|
||||
# metadata_source & llm_enriched_at (always set)
|
||||
updates["metadata_source"] = "agent:enrich_hf_metadata"
|
||||
updates["llm_enriched_at"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# -- Persist updates ------------------------------------------------
|
||||
if updates:
|
||||
updated_fields = await apply_metadata_updates(model_path, updates)
|
||||
|
||||
# -- Download preview -----------------------------------------------
|
||||
preview_url = (llm_output.get("preview_url") or "").strip()
|
||||
current_preview = metadata.get("preview_url") or ""
|
||||
if preview_url and not (current_preview and os.path.exists(current_preview)):
|
||||
preview_downloaded = await download_preview(model_path, preview_url)
|
||||
|
||||
# -- Refresh scanner cache ------------------------------------------
|
||||
if updated_fields or preview_downloaded:
|
||||
await refresh_cache(model_path)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"updated_fields": updated_fields,
|
||||
"preview_downloaded": preview_downloaded,
|
||||
"errors": [],
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _should_overwrite(current_value: str, is_hf_model: bool) -> bool:
|
||||
"""Return ``True`` when a scalar field should be overwritten."""
|
||||
return is_hf_model or not current_value or current_value.lower() in (
|
||||
"", "unknown",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _should_overwrite_list(current_list: List[str], is_hf_model: bool) -> bool:
|
||||
"""Return ``True`` when a list field should be overwritten."""
|
||||
return is_hf_model or not current_list
|
||||
|
||||
@staticmethod
|
||||
def _merge_tags(existing: List[str], new: List[str]) -> List[str]:
|
||||
"""Merge *new* tags into *existing*, all lowercased.
|
||||
|
||||
This matches the behaviour of :class:`TagUpdateService` which
|
||||
normalises every tag to lowercase for case-insensitive dedup.
|
||||
"""
|
||||
merged: List[str] = []
|
||||
seen: set = set()
|
||||
for tag in list(existing) + list(new):
|
||||
t = tag.strip().lower()
|
||||
if t and t not in seen:
|
||||
merged.append(t)
|
||||
seen.add(t)
|
||||
return merged
|
||||
45
py/services/agent/skill_definition.py
Normal file
45
py/services/agent/skill_definition.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Skill definition data structures.
|
||||
|
||||
Each skill is described by a :class:`SkillDefinition` that declares its
|
||||
input/output schemas, whether it needs an LLM call, and what permissions
|
||||
its post-processor has.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SkillPermissions:
|
||||
"""Declarative permission scope for a skill's post-processor.
|
||||
|
||||
These are auditable constraints — the :class:`AgentService` checks them
|
||||
before invoking the handler. They are defense-in-depth, not a sandbox.
|
||||
"""
|
||||
|
||||
write_metadata: bool = True
|
||||
write_previews: bool = True
|
||||
network_domains: Tuple[str, ...] = ()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SkillDefinition:
|
||||
"""Immutable description of an agent skill."""
|
||||
|
||||
name: str
|
||||
title: str
|
||||
description: str
|
||||
llm_required: bool
|
||||
input_schema: Dict[str, Any] = field(default_factory=dict)
|
||||
output_schema: Dict[str, Any] = field(default_factory=dict)
|
||||
model_type_filter: Optional[List[str]] = None
|
||||
permissions: SkillPermissions = field(default_factory=SkillPermissions)
|
||||
|
||||
def applies_to_model_type(self, model_type: str) -> bool:
|
||||
"""Return ``True`` if this skill can run on the given model type."""
|
||||
|
||||
if self.model_type_filter is None:
|
||||
return True
|
||||
return model_type in self.model_type_filter
|
||||
184
py/services/agent/skill_registry.py
Normal file
184
py/services/agent/skill_registry.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Discovery and loading of agent skills.
|
||||
|
||||
Skills live in ``py/services/agent/skills/<name>/`` directories. Each
|
||||
directory must contain a ``SKILL.md`` file with YAML frontmatter::
|
||||
|
||||
---
|
||||
name: my_skill
|
||||
title: "My Skill"
|
||||
description: "What this skill does"
|
||||
llm_required: true
|
||||
---
|
||||
|
||||
Prompt template with ``{{variable}}`` placeholders.
|
||||
|
||||
The registry scans the skills directory on first access and caches results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from .skill_definition import SkillDefinition, SkillPermissions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Directory where built-in skills are stored
|
||||
_SKILLS_DIR = Path(__file__).parent / "skills"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frontmatter parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FRONTMATTER_RE = re.compile(
|
||||
r"^---\s*\n(.*?\n)---\s*\n?(.*)", re.DOTALL
|
||||
)
|
||||
|
||||
|
||||
def _parse_skill_file(path: Path) -> tuple[dict, str]:
|
||||
"""Read a ``SKILL.md`` file and return (frontmatter_dict, body_text).
|
||||
|
||||
Raises ``ValueError`` if the file lacks valid YAML frontmatter.
|
||||
"""
|
||||
text = path.read_text(encoding="utf-8")
|
||||
m = _FRONTMATTER_RE.match(text)
|
||||
if not m:
|
||||
raise ValueError(f"Missing or invalid YAML frontmatter in {path}")
|
||||
frontmatter = yaml.safe_load(m.group(1))
|
||||
if not isinstance(frontmatter, dict):
|
||||
raise ValueError(f"Frontmatter in {path} is not a mapping")
|
||||
body = m.group(2).strip()
|
||||
return frontmatter, body
|
||||
|
||||
|
||||
class SkillRegistry:
|
||||
"""Discover and load agent skills from the filesystem."""
|
||||
|
||||
_instance: Optional["SkillRegistry"] = None
|
||||
_lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
def __init__(self, skills_dir: Path = _SKILLS_DIR) -> None:
|
||||
self._skills_dir = skills_dir
|
||||
self._skills: Dict[str, SkillDefinition] = {}
|
||||
self._loaded: bool = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Singleton access
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> "SkillRegistry":
|
||||
"""Return the lazily-initialised global ``SkillRegistry``."""
|
||||
|
||||
if cls._instance is None:
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
registry = cls()
|
||||
registry._discover()
|
||||
cls._instance = registry
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls) -> None:
|
||||
"""Reset the cached singleton — primarily for tests."""
|
||||
|
||||
cls._instance = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Discovery
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _discover(self) -> None:
|
||||
"""Scan the skills directory and load all valid skill definitions."""
|
||||
|
||||
self._skills.clear()
|
||||
if not self._skills_dir.is_dir():
|
||||
logger.warning("Skills directory does not exist: %s", self._skills_dir)
|
||||
self._loaded = True
|
||||
return
|
||||
|
||||
for entry in sorted(self._skills_dir.iterdir()):
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
skill_md = entry / "SKILL.md"
|
||||
if not skill_md.exists():
|
||||
continue
|
||||
try:
|
||||
definition = self._load_skill_definition(skill_md)
|
||||
if definition is not None:
|
||||
self._skills[definition.name] = definition
|
||||
logger.debug("Loaded skill: %s", definition.name)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to load skill from %s: %s", skill_md, exc)
|
||||
|
||||
self._loaded = True
|
||||
logger.info("Discovered %d agent skills", len(self._skills))
|
||||
|
||||
def _load_skill_definition(self, path: Path) -> Optional[SkillDefinition]:
|
||||
"""Parse a ``SKILL.md`` frontmatter into a :class:`SkillDefinition`."""
|
||||
|
||||
try:
|
||||
data, _body = _parse_skill_file(path)
|
||||
except (ValueError, yaml.YAMLError) as exc:
|
||||
logger.warning("Failed to parse SKILL.md %s: %s", path, exc)
|
||||
return None
|
||||
|
||||
if "name" not in data:
|
||||
logger.warning("SKILL.md missing required 'name' field: %s", path)
|
||||
return None
|
||||
|
||||
perm_data = data.get("permissions", {})
|
||||
permissions = SkillPermissions(
|
||||
write_metadata=perm_data.get("write_metadata", True),
|
||||
write_previews=perm_data.get("write_previews", True),
|
||||
network_domains=tuple(perm_data.get("network_domains", [])),
|
||||
)
|
||||
|
||||
return SkillDefinition(
|
||||
name=data["name"],
|
||||
title=data.get("title", data["name"]),
|
||||
description=data.get("description", ""),
|
||||
llm_required=data.get("llm_required", False),
|
||||
input_schema=data.get("input_schema", {}),
|
||||
output_schema=data.get("output_schema", {}),
|
||||
model_type_filter=data.get("model_type_filter"),
|
||||
permissions=permissions,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def list_skills(self) -> List[SkillDefinition]:
|
||||
"""Return all discovered skill definitions."""
|
||||
|
||||
if not self._loaded:
|
||||
self._discover()
|
||||
return list(self._skills.values())
|
||||
|
||||
def get_skill(self, name: str) -> Optional[SkillDefinition]:
|
||||
"""Return the skill definition for ``name``, or ``None`` if not found."""
|
||||
|
||||
if not self._loaded:
|
||||
self._discover()
|
||||
return self._skills.get(name)
|
||||
|
||||
def load_prompt(self, name: str) -> str:
|
||||
"""Load and return the prompt template body from a skill's ``SKILL.md``."""
|
||||
|
||||
skill_dir = self._skills_dir / name
|
||||
skill_path = skill_dir / "SKILL.md"
|
||||
if not skill_path.exists():
|
||||
raise FileNotFoundError(f"SKILL.md not found: {skill_path}")
|
||||
try:
|
||||
_frontmatter, body = _parse_skill_file(skill_path)
|
||||
return body
|
||||
except (ValueError, yaml.YAMLError) as exc:
|
||||
raise ValueError(f"Failed to parse prompt from {skill_path}: {exc}") from exc
|
||||
89
py/services/agent/skills/enrich_hf_metadata/SKILL.md
Normal file
89
py/services/agent/skills/enrich_hf_metadata/SKILL.md
Normal file
@@ -0,0 +1,89 @@
|
||||
---
|
||||
name: enrich_hf_metadata
|
||||
title: "Enrich Metadata from HuggingFace"
|
||||
description: >
|
||||
Parse the HuggingFace model card via LLM to extract description, trigger
|
||||
words, base model, tags, and preview image URL.
|
||||
llm_required: true
|
||||
---
|
||||
|
||||
You are an expert assistant for AI image generation models. Your task is to extract structured metadata from a HuggingFace model card (README.md).
|
||||
|
||||
## Model Information
|
||||
|
||||
- **Repository**: {{hf_url}}
|
||||
- **Model file path**: {{model_path}}
|
||||
- **Repository ID**: {{repo}}
|
||||
|
||||
## Current Metadata (may be incomplete)
|
||||
|
||||
```json
|
||||
{{current_metadata}}
|
||||
```
|
||||
|
||||
## Available Base Models
|
||||
|
||||
The following base models are currently valid in this system:
|
||||
{{base_models}}
|
||||
|
||||
## HuggingFace README Content
|
||||
|
||||
```
|
||||
{{readme_content}}
|
||||
```
|
||||
|
||||
## Extraction Instructions
|
||||
|
||||
Extract the following information from the README content above:
|
||||
|
||||
### base_model
|
||||
The base model this LoRA/checkpoint was trained on. Use EXACTLY one of the names from the **Available Base Models** list above. Do not invent new names or use aliases.
|
||||
|
||||
Check the YAML frontmatter (between --- markers) for `base_model:` first, then look at the description text and safetensors metadata. If you cannot determine it, return an empty string.
|
||||
|
||||
### trigger_words
|
||||
The trigger words or activation prompts needed to use this LoRA. Look for:
|
||||
- `instance_prompt:` in the YAML frontmatter
|
||||
- Phrases like "trigger word:", "trigger:", "use this prompt:", "activation prompt:"
|
||||
- Example prompts at the start (usually the first word or phrase before any description)
|
||||
Return as an array of strings. If none found, return an empty array.
|
||||
|
||||
### description
|
||||
A concise 1-2 sentence summary of what this model does. Extract from the "Model description" section or the first paragraph. Return empty string if the README is too minimal.
|
||||
|
||||
### tags
|
||||
3-8 relevant tags for categorizing this model. Extract from:
|
||||
- The YAML frontmatter `tags:` list (often contains excellent categorization tags)
|
||||
- The model type (e.g. "lora", "checkpoint", "flux", "sdxl")
|
||||
- The style/subject (e.g. "anime", "photorealistic", "style", "character")
|
||||
All lowercase, no spaces. Return empty array if none found.
|
||||
|
||||
### preview_url
|
||||
The URL of the most suitable preview image from the README. Look for image tags (e.g. ``) and the YAML frontmatter `widget:` section (which often has `output.url` fields). Choose the first image that appears to be a generation example (not a logo or diagram). Construct the absolute URL as `https://huggingface.co/{{repo}}/resolve/main/{filename}`. If no suitable image is found, return an empty string.
|
||||
|
||||
### confidence
|
||||
Your confidence level in the extracted data:
|
||||
- "high" — most fields were explicitly stated in the README
|
||||
- "medium" — some fields were inferred from context
|
||||
- "low" — most fields are guesses based on limited information
|
||||
|
||||
## Output Format
|
||||
|
||||
Return ONLY a JSON object with exactly these fields (no markdown fences, no extra text):
|
||||
|
||||
```json
|
||||
{
|
||||
"model_path": "{{model_path}}",
|
||||
"base_model": "<canonical name or empty string>",
|
||||
"trigger_words": ["<word1>", "<word2>"],
|
||||
"description": "<1-2 sentence summary>",
|
||||
"tags": ["<tag1>", "<tag2>"],
|
||||
"preview_url": "<image URL or empty string>",
|
||||
"confidence": "<high|medium|low>"
|
||||
}
|
||||
```
|
||||
|
||||
Important:
|
||||
- Only include the JSON object, no other text
|
||||
- If a field cannot be determined, use an empty string or empty array
|
||||
- Do not fabricate information not supported by the README
|
||||
672
py/services/aria2_downloader.py
Normal file
672
py/services/aria2_downloader.py
Normal file
@@ -0,0 +1,672 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import shutil
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .downloader import DownloadProgress, get_downloader, is_ssl_cert_verify_error
|
||||
from .aria2_transfer_state import Aria2TransferStateStore
|
||||
from .settings_manager import get_settings_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _try_certifi_ca_path() -> str | None:
|
||||
"""Return the certifi CA bundle path if available, else None."""
|
||||
try:
|
||||
import certifi # type: ignore[import-untyped]
|
||||
|
||||
path = certifi.where()
|
||||
if os.path.isfile(path):
|
||||
logger.debug(
|
||||
"aria2 --ca-certificate: using certifi CA bundle at %s", path
|
||||
)
|
||||
return path
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
logger.debug("aria2 --ca-certificate: certifi not available")
|
||||
return None
|
||||
|
||||
|
||||
CIVITAI_DOWNLOAD_URL_PREFIXES = (
|
||||
"https://civitai.com/api/download/",
|
||||
"https://civitai.red/api/download/",
|
||||
)
|
||||
|
||||
|
||||
class Aria2Error(RuntimeError):
|
||||
"""Raised when aria2 integration fails."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Aria2Transfer:
|
||||
"""Track an aria2 download registered by the Python coordinator."""
|
||||
|
||||
gid: str
|
||||
save_path: str
|
||||
|
||||
|
||||
class Aria2Downloader:
|
||||
"""Manage an aria2 RPC daemon for recommended model downloads."""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> "Aria2Downloader":
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if hasattr(self, "_initialized"):
|
||||
return
|
||||
|
||||
self._initialized = True
|
||||
self._process: Optional[asyncio.subprocess.Process] = None
|
||||
self._rpc_port: Optional[int] = None
|
||||
self._rpc_secret = ""
|
||||
self._rpc_url = ""
|
||||
self._rpc_session: Optional[aiohttp.ClientSession] = None
|
||||
self._rpc_session_lock = asyncio.Lock()
|
||||
self._process_lock = asyncio.Lock()
|
||||
self._transfers: Dict[str, Aria2Transfer] = {}
|
||||
self._poll_interval = 0.5
|
||||
self._state_store = Aria2TransferStateStore()
|
||||
self._stderr_reader_task: Optional[asyncio.Task] = None
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return self._process is not None and self._process.returncode is None
|
||||
|
||||
async def download_file(
|
||||
self,
|
||||
url: str,
|
||||
save_path: str,
|
||||
*,
|
||||
download_id: str,
|
||||
progress_callback=None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
"""Download a file using aria2 RPC and wait for completion."""
|
||||
|
||||
await self._ensure_process()
|
||||
save_path = os.path.abspath(save_path)
|
||||
transfer = self._transfers.get(download_id)
|
||||
if transfer is None or os.path.abspath(transfer.save_path) != save_path:
|
||||
gid = await self._schedule_download(
|
||||
url,
|
||||
save_path,
|
||||
download_id=download_id,
|
||||
headers=headers,
|
||||
)
|
||||
transfer = Aria2Transfer(gid=gid, save_path=save_path)
|
||||
self._transfers[download_id] = transfer
|
||||
|
||||
try:
|
||||
while True:
|
||||
status = await self._get_status_with_retry(download_id)
|
||||
if status is None:
|
||||
return False, "aria2 download not found"
|
||||
|
||||
snapshot = self._build_progress_snapshot(status)
|
||||
if progress_callback is not None:
|
||||
await self._dispatch_progress(progress_callback, snapshot)
|
||||
|
||||
state = status.get("status", "")
|
||||
if state == "complete":
|
||||
completed_path = self._resolve_completed_path(status, save_path)
|
||||
return True, completed_path
|
||||
if state == "error":
|
||||
return False, status.get("errorMessage") or "aria2 download failed"
|
||||
if state == "removed":
|
||||
return False, "Download was cancelled"
|
||||
|
||||
await asyncio.sleep(self._poll_interval)
|
||||
finally:
|
||||
self._transfers.pop(download_id, None)
|
||||
|
||||
async def _get_status_with_retry(
|
||||
self, download_id: str, *, max_retries: int = 4, retry_delay: float = 3.0
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Call get_status with retry for transient RPC failures.
|
||||
|
||||
Only retries on :exc:`Aria2Error` (RPC-level failure). Returns
|
||||
``None`` immediately when the download_id is not tracked (a missing
|
||||
transfer is not a transient condition, so retrying is pointless).
|
||||
|
||||
A single failed RPC call should not immediately fail the download,
|
||||
because aria2 may be temporarily busy (e.g. finalizing multiple
|
||||
concurrent downloads) and a retry will often succeed.
|
||||
"""
|
||||
last_exc: Optional[Exception] = None
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return await self.get_status(download_id)
|
||||
except Aria2Error as exc:
|
||||
last_exc = exc
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(
|
||||
"aria2 get_status transient failure (attempt %d/%d) for %s: %s",
|
||||
attempt + 1, max_retries, download_id, exc,
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
raise Aria2Error(
|
||||
f"Failed to query aria2 download status after {max_retries} attempts: {last_exc}"
|
||||
) from last_exc
|
||||
|
||||
async def _schedule_download(
|
||||
self,
|
||||
url: str,
|
||||
save_path: str,
|
||||
*,
|
||||
download_id: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> str:
|
||||
save_dir = os.path.dirname(save_path)
|
||||
out_name = os.path.basename(save_path)
|
||||
|
||||
Path(save_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
resolved_url = url
|
||||
request_headers = headers
|
||||
if headers and url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES):
|
||||
resolved_url = await self._resolve_authenticated_redirect_url(url, headers)
|
||||
if resolved_url != url:
|
||||
request_headers = None
|
||||
logger.debug(
|
||||
"Resolved Civitai download %s to signed URL for aria2",
|
||||
download_id,
|
||||
)
|
||||
|
||||
options: Dict[str, str] = {
|
||||
"dir": save_dir,
|
||||
"out": out_name,
|
||||
"continue": "true",
|
||||
"max-connection-per-server": "4",
|
||||
"split": "4",
|
||||
"min-split-size": "1M",
|
||||
"allow-overwrite": "true",
|
||||
"auto-file-renaming": "false",
|
||||
"file-allocation": "none",
|
||||
}
|
||||
if request_headers:
|
||||
options["header"] = [
|
||||
f"{key}: {value}" for key, value in request_headers.items()
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
"Submitting aria2 download %s -> %s (auth=%s, civitai_signed=%s)",
|
||||
download_id,
|
||||
save_path,
|
||||
bool(request_headers),
|
||||
resolved_url != url,
|
||||
)
|
||||
|
||||
try:
|
||||
gid = await self._rpc_call("aria2.addUri", [[resolved_url], options])
|
||||
except Exception as exc:
|
||||
raise Aria2Error(f"Failed to schedule aria2 download: {exc}") from exc
|
||||
|
||||
logger.debug("aria2 accepted download %s with gid %s", download_id, gid)
|
||||
await self._state_store.upsert(
|
||||
download_id,
|
||||
{
|
||||
"gid": gid,
|
||||
"save_path": save_path,
|
||||
"status": "downloading",
|
||||
"url": url,
|
||||
},
|
||||
)
|
||||
return gid
|
||||
|
||||
async def get_status(self, download_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Return the raw aria2 status payload for a known download."""
|
||||
|
||||
transfer = self._transfers.get(download_id)
|
||||
if transfer is None:
|
||||
return None
|
||||
|
||||
keys = [
|
||||
"gid",
|
||||
"status",
|
||||
"totalLength",
|
||||
"completedLength",
|
||||
"downloadSpeed",
|
||||
"errorMessage",
|
||||
"files",
|
||||
]
|
||||
try:
|
||||
status = await self._rpc_call("aria2.tellStatus", [transfer.gid, keys])
|
||||
except Exception as exc:
|
||||
raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc
|
||||
|
||||
if isinstance(status, dict):
|
||||
return status
|
||||
return None
|
||||
|
||||
async def get_status_by_gid(self, gid: str) -> Optional[Dict[str, Any]]:
|
||||
keys = [
|
||||
"gid",
|
||||
"status",
|
||||
"totalLength",
|
||||
"completedLength",
|
||||
"downloadSpeed",
|
||||
"errorMessage",
|
||||
"files",
|
||||
]
|
||||
try:
|
||||
status = await self._rpc_call("aria2.tellStatus", [gid, keys])
|
||||
except Exception as exc:
|
||||
message = str(exc)
|
||||
if "cannot be found" in message.lower() or "not found" in message.lower():
|
||||
return None
|
||||
raise Aria2Error(f"Failed to query aria2 download status: {exc}") from exc
|
||||
|
||||
if isinstance(status, dict):
|
||||
return status
|
||||
return None
|
||||
|
||||
async def restore_transfer(self, download_id: str, gid: str, save_path: str) -> None:
|
||||
await self._ensure_process()
|
||||
self._transfers[download_id] = Aria2Transfer(
|
||||
gid=gid,
|
||||
save_path=os.path.abspath(save_path),
|
||||
)
|
||||
|
||||
async def reassign_transfer(
|
||||
self, from_download_id: str, to_download_id: str
|
||||
) -> Optional[Aria2Transfer]:
|
||||
transfer = self._transfers.get(from_download_id)
|
||||
if transfer is None:
|
||||
return None
|
||||
|
||||
self._transfers[to_download_id] = transfer
|
||||
if from_download_id != to_download_id:
|
||||
self._transfers.pop(from_download_id, None)
|
||||
return transfer
|
||||
|
||||
async def has_transfer(self, download_id: str) -> bool:
|
||||
return download_id in self._transfers
|
||||
|
||||
async def pause_download(self, download_id: str) -> Dict[str, Any]:
|
||||
transfer = self._transfers.get(download_id)
|
||||
if transfer is None:
|
||||
return {"success": False, "error": "Download task not found"}
|
||||
|
||||
try:
|
||||
await self._rpc_call("aria2.forcePause", [transfer.gid])
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
await self._state_store.upsert(download_id, {"status": "paused"})
|
||||
return {"success": True, "message": "Download paused successfully"}
|
||||
|
||||
async def resume_download(self, download_id: str) -> Dict[str, Any]:
|
||||
transfer = self._transfers.get(download_id)
|
||||
if transfer is None:
|
||||
return {"success": False, "error": "Download task not found"}
|
||||
|
||||
try:
|
||||
await self._rpc_call("aria2.unpause", [transfer.gid])
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
await self._state_store.upsert(download_id, {"status": "downloading"})
|
||||
return {"success": True, "message": "Download resumed successfully"}
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||
transfer = self._transfers.get(download_id)
|
||||
if transfer is None:
|
||||
return {"success": False, "error": "Download task not found"}
|
||||
|
||||
try:
|
||||
await self._rpc_call("aria2.forceRemove", [transfer.gid])
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
await self._state_store.remove(download_id)
|
||||
return {"success": True, "message": "Download cancelled successfully"}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Shut down the RPC process and session."""
|
||||
|
||||
# Cancel the background stderr reader first so it stops reading
|
||||
# from the pipe before the subprocess is terminated.
|
||||
if self._stderr_reader_task is not None:
|
||||
self._stderr_reader_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(self._stderr_reader_task, timeout=2.0)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
pass
|
||||
self._stderr_reader_task = None
|
||||
|
||||
if self._rpc_session is not None:
|
||||
await self._rpc_session.close()
|
||||
self._rpc_session = None
|
||||
|
||||
process = self._process
|
||||
self._process = None
|
||||
self._transfers.clear()
|
||||
|
||||
if process is None:
|
||||
return
|
||||
|
||||
if process.returncode is None:
|
||||
process.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(process.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
|
||||
async def _drain_stderr(self) -> None:
|
||||
"""Continuously drain aria2's stderr pipe so it never blocks.
|
||||
|
||||
When the 64 KB pipe buffer fills up, aria2's ``write()`` to stderr
|
||||
blocks, which freezes the entire ``aria2c`` process — including its
|
||||
RPC handler. This background task reads lines from stderr as they
|
||||
arrive and forwards them to Python's logger.
|
||||
"""
|
||||
try:
|
||||
assert self._process is not None and self._process.stderr is not None
|
||||
async for line in self._process.stderr:
|
||||
text = line.decode("utf-8", errors="replace").rstrip()
|
||||
if text:
|
||||
logger.debug("aria2 stderr: %s", text)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _dispatch_progress(self, callback, snapshot: DownloadProgress) -> None:
|
||||
try:
|
||||
result = callback(snapshot, snapshot)
|
||||
except TypeError:
|
||||
result = callback(snapshot.percent_complete)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
elif hasattr(result, "__await__"):
|
||||
await result
|
||||
|
||||
def _build_progress_snapshot(self, status: Dict[str, Any]) -> DownloadProgress:
|
||||
completed = self._parse_int(status.get("completedLength"))
|
||||
total = self._parse_int(status.get("totalLength"))
|
||||
speed = float(self._parse_int(status.get("downloadSpeed")))
|
||||
percent = 0.0
|
||||
if total > 0:
|
||||
percent = (completed / total) * 100.0
|
||||
|
||||
return DownloadProgress(
|
||||
percent_complete=max(0.0, min(percent, 100.0)),
|
||||
bytes_downloaded=completed,
|
||||
total_bytes=total or None,
|
||||
bytes_per_second=speed,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
)
|
||||
|
||||
def _resolve_completed_path(self, status: Dict[str, Any], default_path: str) -> str:
|
||||
files = status.get("files")
|
||||
if isinstance(files, list) and files:
|
||||
first = files[0]
|
||||
if isinstance(first, dict):
|
||||
candidate = first.get("path")
|
||||
if isinstance(candidate, str) and candidate:
|
||||
return candidate
|
||||
return default_path
|
||||
|
||||
@staticmethod
|
||||
def _parse_int(value: Any) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
async def _resolve_authenticated_redirect_url(
|
||||
self,
|
||||
url: str,
|
||||
headers: Dict[str, str],
|
||||
) -> str:
|
||||
downloader = await get_downloader()
|
||||
session = await downloader.session
|
||||
request_headers = dict(downloader.default_headers)
|
||||
request_headers.update(headers)
|
||||
request_headers["Accept-Encoding"] = "identity"
|
||||
|
||||
try:
|
||||
async with session.get(
|
||||
url,
|
||||
headers=request_headers,
|
||||
allow_redirects=False,
|
||||
proxy=downloader.proxy_url,
|
||||
) as response:
|
||||
if response.status in {301, 302, 303, 307, 308}:
|
||||
location = response.headers.get("Location")
|
||||
if location:
|
||||
return location
|
||||
raise Aria2Error(
|
||||
"Authenticated Civitai redirect did not include a Location header"
|
||||
)
|
||||
|
||||
if response.status == 200:
|
||||
return url
|
||||
|
||||
body = await response.text()
|
||||
raise Aria2Error(
|
||||
f"Failed to resolve authenticated Civitai redirect: status={response.status} body={body[:300]}"
|
||||
)
|
||||
except aiohttp.ClientError as exc:
|
||||
if is_ssl_cert_verify_error(exc):
|
||||
logger.error(
|
||||
"SSL certificate verification failed during Civitai redirect "
|
||||
"resolution for %s. This is usually caused by an outdated CA "
|
||||
"certificate bundle. Recommended fixes:\n"
|
||||
" 1. pip install --upgrade certifi\n"
|
||||
" 2. pip install pip-system-certs",
|
||||
url,
|
||||
)
|
||||
raise Aria2Error(
|
||||
f"Failed to resolve authenticated Civitai redirect: {exc}"
|
||||
) from exc
|
||||
|
||||
async def _ensure_process(self) -> None:
|
||||
async with self._process_lock:
|
||||
if self.is_running and await self._ping():
|
||||
return
|
||||
|
||||
await self.close()
|
||||
|
||||
executable = self._resolve_executable()
|
||||
self._rpc_port = self._find_free_port()
|
||||
self._rpc_secret = secrets.token_hex(16)
|
||||
self._rpc_url = f"http://127.0.0.1:{self._rpc_port}/jsonrpc"
|
||||
|
||||
command = [
|
||||
executable,
|
||||
"--enable-rpc=true",
|
||||
"--rpc-listen-all=false",
|
||||
f"--rpc-listen-port={self._rpc_port}",
|
||||
f"--rpc-secret={self._rpc_secret}",
|
||||
"--check-certificate=true",
|
||||
# Point aria2 at certifi's CA bundle when available so it uses
|
||||
# the same certificate store as Python downloads.
|
||||
*((
|
||||
f"--ca-certificate={ca_cert}",
|
||||
) if (ca_cert := _try_certifi_ca_path()) else ()),
|
||||
"--allow-overwrite=true",
|
||||
"--auto-file-renaming=false",
|
||||
"--file-allocation=none",
|
||||
"--max-concurrent-downloads=5",
|
||||
"--continue=true",
|
||||
"--daemon=false",
|
||||
"--quiet=true",
|
||||
f"--stop-with-process={os.getpid()}",
|
||||
]
|
||||
|
||||
logger.info("Starting aria2 RPC daemon from %s", executable)
|
||||
self._process = await asyncio.create_subprocess_exec(
|
||||
*command,
|
||||
stdout=asyncio.subprocess.DEVNULL,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
await self._wait_until_ready()
|
||||
|
||||
# Drain aria2's stderr in a background task so the pipe buffer
|
||||
# never fills up. If the pipe blocks, aria2 itself freezes and
|
||||
# cannot respond to RPC — this was the root cause of the
|
||||
# "Failed to query aria2 download status" timeout bug.
|
||||
# Must start AFTER _wait_until_ready to avoid a race where the
|
||||
# drain task consumes aria2's early-exit error message before
|
||||
# _wait_until_ready can read it.
|
||||
self._stderr_reader_task = asyncio.create_task(
|
||||
self._drain_stderr()
|
||||
)
|
||||
|
||||
def _resolve_executable(self) -> str:
|
||||
settings = get_settings_manager()
|
||||
configured_path = (settings.get("aria2c_path") or "").strip()
|
||||
candidate = configured_path or "aria2c"
|
||||
|
||||
resolved = shutil.which(candidate)
|
||||
if resolved:
|
||||
return resolved
|
||||
|
||||
if configured_path and os.path.isfile(configured_path) and os.access(
|
||||
configured_path, os.X_OK
|
||||
):
|
||||
return configured_path
|
||||
|
||||
raise Aria2Error(
|
||||
"aria2c executable was not found. Install aria2 or configure aria2c_path."
|
||||
)
|
||||
|
||||
async def _wait_until_ready(self) -> None:
|
||||
assert self._process is not None
|
||||
|
||||
start_time = asyncio.get_running_loop().time()
|
||||
last_error = ""
|
||||
while asyncio.get_running_loop().time() - start_time < 10.0:
|
||||
if self._process.returncode is not None:
|
||||
stderr_output = ""
|
||||
if self._process.stderr is not None:
|
||||
try:
|
||||
stderr_output = (
|
||||
await asyncio.wait_for(self._process.stderr.read(), timeout=0.2)
|
||||
).decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
stderr_output = ""
|
||||
raise Aria2Error(
|
||||
f"aria2 RPC process exited early with code {self._process.returncode}: {stderr_output.strip()}"
|
||||
)
|
||||
|
||||
try:
|
||||
if await self._ping():
|
||||
return
|
||||
except Exception as exc: # pragma: no cover - startup race
|
||||
last_error = str(exc)
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
raise Aria2Error(
|
||||
f"Timed out waiting for aria2 RPC to become ready{': ' + last_error if last_error else ''}"
|
||||
)
|
||||
|
||||
async def _ping(self) -> bool:
|
||||
try:
|
||||
result = await self._rpc_call("aria2.getVersion", [])
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
return isinstance(result, dict)
|
||||
|
||||
async def _rpc_call(self, method: str, params: list[Any]) -> Any:
|
||||
if not self._rpc_url:
|
||||
raise Aria2Error("aria2 RPC endpoint is not initialized")
|
||||
|
||||
session = await self._get_rpc_session()
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": secrets.token_hex(8),
|
||||
"method": method,
|
||||
"params": [f"token:{self._rpc_secret}", *params],
|
||||
}
|
||||
|
||||
async with session.post(self._rpc_url, json=payload) as response:
|
||||
text = await response.text()
|
||||
|
||||
try:
|
||||
body = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
body = None
|
||||
|
||||
if body is None:
|
||||
if response.status != 200:
|
||||
raise Aria2Error(
|
||||
f"aria2 RPC returned status {response.status} with non-JSON body: {text}"
|
||||
)
|
||||
raise Aria2Error(f"Invalid aria2 RPC response: {text}")
|
||||
|
||||
if "error" in body:
|
||||
error = body["error"] or {}
|
||||
code = error.get("code") if isinstance(error, dict) else None
|
||||
message = error.get("message") if isinstance(error, dict) else str(error)
|
||||
logger.error(
|
||||
"aria2 RPC %s failed with HTTP %s, code=%s, message=%s",
|
||||
method,
|
||||
response.status,
|
||||
code,
|
||||
message,
|
||||
)
|
||||
status_message = (
|
||||
f"aria2 RPC {method} failed with status {response.status}: {message}"
|
||||
if response.status != 200
|
||||
else message
|
||||
)
|
||||
raise Aria2Error(status_message or "Unknown aria2 RPC error")
|
||||
|
||||
if response.status != 200:
|
||||
logger.error(
|
||||
"aria2 RPC %s returned unexpected HTTP status %s without error payload: %s",
|
||||
method,
|
||||
response.status,
|
||||
body,
|
||||
)
|
||||
raise Aria2Error(
|
||||
f"aria2 RPC {method} returned unexpected status {response.status}"
|
||||
)
|
||||
|
||||
return body.get("result")
|
||||
|
||||
async def _get_rpc_session(self) -> aiohttp.ClientSession:
|
||||
if self._rpc_session is None or self._rpc_session.closed:
|
||||
async with self._rpc_session_lock:
|
||||
if self._rpc_session is None or self._rpc_session.closed:
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=None, sock_connect=10, sock_read=60
|
||||
)
|
||||
self._rpc_session = aiohttp.ClientSession(timeout=timeout)
|
||||
return self._rpc_session
|
||||
|
||||
@staticmethod
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
sock.listen(1)
|
||||
return int(sock.getsockname()[1])
|
||||
|
||||
|
||||
async def get_aria2_downloader() -> Aria2Downloader:
|
||||
"""Get the singleton aria2 downloader."""
|
||||
|
||||
return await Aria2Downloader.get_instance()
|
||||
108
py/services/aria2_transfer_state.py
Normal file
108
py/services/aria2_transfer_state.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ..utils.cache_paths import get_cache_base_dir
|
||||
|
||||
|
||||
def get_aria2_state_path() -> str:
|
||||
base_dir = get_cache_base_dir(create=True)
|
||||
state_dir = os.path.join(base_dir, "aria2")
|
||||
os.makedirs(state_dir, exist_ok=True)
|
||||
return os.path.join(state_dir, "downloads.json")
|
||||
|
||||
|
||||
class Aria2TransferStateStore:
|
||||
"""Persist aria2 transfer metadata needed for restart recovery."""
|
||||
|
||||
_locks_by_path: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
def __init__(self, state_path: Optional[str] = None) -> None:
|
||||
self._state_path = os.path.abspath(state_path or get_aria2_state_path())
|
||||
self._lock = self._locks_by_path.setdefault(self._state_path, asyncio.Lock())
|
||||
|
||||
def _read_all_unlocked(self) -> Dict[str, Dict[str, Any]]:
|
||||
try:
|
||||
with open(self._state_path, "r", encoding="utf-8") as handle:
|
||||
data = json.load(handle)
|
||||
except FileNotFoundError:
|
||||
return {}
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return {}
|
||||
|
||||
normalized: Dict[str, Dict[str, Any]] = {}
|
||||
for download_id, entry in data.items():
|
||||
if isinstance(download_id, str) and isinstance(entry, dict):
|
||||
normalized[download_id] = entry
|
||||
return normalized
|
||||
|
||||
def _write_all_unlocked(self, data: Dict[str, Dict[str, Any]]) -> None:
|
||||
directory = os.path.dirname(self._state_path)
|
||||
if directory:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
temp_path = f"{self._state_path}.tmp"
|
||||
with open(temp_path, "w", encoding="utf-8") as handle:
|
||||
json.dump(data, handle, ensure_ascii=True, indent=2, sort_keys=True)
|
||||
os.replace(temp_path, self._state_path)
|
||||
|
||||
async def load_all(self) -> Dict[str, Dict[str, Any]]:
|
||||
async with self._lock:
|
||||
return deepcopy(self._read_all_unlocked())
|
||||
|
||||
async def get(self, download_id: str) -> Optional[Dict[str, Any]]:
|
||||
async with self._lock:
|
||||
return deepcopy(self._read_all_unlocked().get(download_id))
|
||||
|
||||
async def upsert(self, download_id: str, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async with self._lock:
|
||||
data = self._read_all_unlocked()
|
||||
current = data.get(download_id, {})
|
||||
current.update(payload)
|
||||
data[download_id] = current
|
||||
self._write_all_unlocked(data)
|
||||
return deepcopy(current)
|
||||
|
||||
async def remove(self, download_id: str) -> None:
|
||||
async with self._lock:
|
||||
data = self._read_all_unlocked()
|
||||
if download_id in data:
|
||||
del data[download_id]
|
||||
self._write_all_unlocked(data)
|
||||
|
||||
async def find_by_save_path(
|
||||
self, save_path: str, *, exclude_download_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
normalized_target = os.path.abspath(save_path)
|
||||
async with self._lock:
|
||||
data = self._read_all_unlocked()
|
||||
for download_id, entry in data.items():
|
||||
if exclude_download_id and download_id == exclude_download_id:
|
||||
continue
|
||||
candidate = entry.get("save_path")
|
||||
if isinstance(candidate, str) and os.path.abspath(candidate) == normalized_target:
|
||||
result = dict(entry)
|
||||
result["download_id"] = download_id
|
||||
return result
|
||||
return None
|
||||
|
||||
async def reassign(self, from_download_id: str, to_download_id: str) -> Optional[Dict[str, Any]]:
|
||||
async with self._lock:
|
||||
data = self._read_all_unlocked()
|
||||
existing = data.get(from_download_id)
|
||||
if existing is None:
|
||||
return None
|
||||
updated = dict(existing)
|
||||
updated["download_id"] = to_download_id
|
||||
data[to_download_id] = updated
|
||||
if from_download_id != to_download_id:
|
||||
data.pop(from_download_id, None)
|
||||
self._write_all_unlocked(data)
|
||||
return deepcopy(updated)
|
||||
139
py/services/auto_tag_service.py
Normal file
139
py/services/auto_tag_service.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
Auto-tag extraction service for model cards.
|
||||
|
||||
Extracts implicit model attributes (HIGH/LOW, I2V/T2V/TI2V, Lightning, Turbo)
|
||||
from filename, base_model, and CivitAI version name — no manual tagging required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Set
|
||||
|
||||
# ── Tag category definitions ──────────────────────────────────────────
|
||||
# Each category maps a display label to a regex pattern.
|
||||
# Patterns are case-insensitive and matched against filename, base_model,
|
||||
# and civitai version name.
|
||||
|
||||
# Use (?<![a-zA-Z0-9]) and (?![a-zA-Z0-9]) instead of \b because
|
||||
# Python's \b treats underscore as a word character, so \bHIGH\b
|
||||
# won't match '_HIGH_' in filenames.
|
||||
_B = r"(?<![a-zA-Z0-9])" # left boundary
|
||||
_E = r"(?![a-zA-Z0-9])" # right boundary
|
||||
|
||||
AUTO_TAG_CATEGORIES: Dict[str, str] = {
|
||||
"HIGH": _B + r"HIGH" + _E,
|
||||
"LOW": _B + r"(?<!F)LOW" + _E,
|
||||
"I2V": _B + r"I2V" + _E,
|
||||
"T2V": _B + r"T2V" + _E,
|
||||
"TI2V": _B + r"TI2V" + _E,
|
||||
"Lightning": _B + r"Lightning" + _E,
|
||||
"Turbo": _B + r"Turbo" + _E,
|
||||
}
|
||||
|
||||
# Tags that belong to the "mode" group (HIGH/LOW)
|
||||
MODE_TAGS = {"HIGH", "LOW"}
|
||||
|
||||
# Tags that belong to the "video mode" group (I2V/T2V/TI2V)
|
||||
VIDEO_MODE_TAGS = {"I2V", "T2V", "TI2V"}
|
||||
|
||||
# Tags that belong to the "speed/optimization" group
|
||||
SPEED_TAGS = {"Lightning", "Turbo"}
|
||||
|
||||
# ── Display category groups (for settings UI) ─────────────────────────
|
||||
|
||||
AUTO_TAG_GROUPS = {
|
||||
"mode": {"HIGH", "LOW"},
|
||||
"video": {"I2V", "T2V", "TI2V"},
|
||||
"speed": {"Lightning", "Turbo"},
|
||||
}
|
||||
|
||||
# Default enabled categories
|
||||
DEFAULT_ENABLED_GROUPS = {"mode", "video"}
|
||||
|
||||
|
||||
def _collect_sources(model_data: Dict) -> List[str]:
|
||||
"""Collect all text sources from model data for tag matching."""
|
||||
sources: List[str] = []
|
||||
|
||||
file_name = model_data.get("file_name", "")
|
||||
if file_name:
|
||||
sources.append(file_name)
|
||||
|
||||
base_model = model_data.get("base_model", "")
|
||||
if base_model:
|
||||
sources.append(base_model)
|
||||
|
||||
civitai = model_data.get("civitai", {})
|
||||
if isinstance(civitai, dict):
|
||||
version_name = civitai.get("name", "")
|
||||
if version_name:
|
||||
sources.append(version_name)
|
||||
|
||||
return sources
|
||||
|
||||
|
||||
def extract_auto_tags(model_data: Dict) -> List[str]:
|
||||
"""Extract auto-detected tags from model metadata.
|
||||
|
||||
Uses a two-layer approach:
|
||||
Layer 1 — Regex-based detection against filename, base_model, and
|
||||
CivitAI version name.
|
||||
Layer 2 — Merge in any user-defined tags that overlap with known
|
||||
auto-tag categories. This provides a manual fallback when
|
||||
auto-detection fails (e.g. "I2V HN" or unlabeled models).
|
||||
|
||||
HIGH/LOW tags are only returned when the base_model indicates a Wan
|
||||
family model — no other model architecture uses this distinction.
|
||||
|
||||
Args:
|
||||
model_data: Model metadata dict with keys:
|
||||
file_name, base_model, civitai (with optional 'name' field),
|
||||
tags (user-defined tag list, used as fallback).
|
||||
|
||||
Returns:
|
||||
Sorted list of unique auto-tag strings (e.g. ["I2V"]).
|
||||
"""
|
||||
sources = _collect_sources(model_data)
|
||||
base_model = model_data.get("base_model", "")
|
||||
is_wan = "wan" in base_model.lower()
|
||||
|
||||
found: Set[str] = set()
|
||||
|
||||
# ── Layer 1: regex-based detection ────────────────────────────
|
||||
if sources:
|
||||
for label, pattern in AUTO_TAG_CATEGORIES.items():
|
||||
# HIGH/LOW are Wan-specific — skip for non-Wan to avoid noise
|
||||
if label in ("HIGH", "LOW"):
|
||||
if not is_wan:
|
||||
continue
|
||||
# Use case-insensitive character class + case-sensitive boundary,
|
||||
# so "HighNoise" (camelCase) matches but "highlight" doesn't.
|
||||
# Boundary: not followed by lowercase letter (= word has ended).
|
||||
ci = "".join(f"[{c.lower()}{c.upper()}]" for c in label)
|
||||
if label == "LOW":
|
||||
regex = re.compile(r"(?<![Ff])" + ci + r"(?![a-z])")
|
||||
else:
|
||||
regex = re.compile(ci + r"(?![a-z])")
|
||||
else:
|
||||
regex = re.compile(pattern, re.IGNORECASE)
|
||||
for source in sources:
|
||||
if regex.search(source):
|
||||
found.add(label)
|
||||
break
|
||||
|
||||
# ── Layer 2: user-defined tags as manual fallback ─────────────
|
||||
# When auto-detection fails (abbreviated names like "Hi"/"Lo",
|
||||
# "I2V HN", or unlabeled models), users can add canonical tags
|
||||
# (HIGH, LOW, I2V, etc.) to the model's regular tags for correct
|
||||
# badge display and filtering. Matching is case-insensitive so
|
||||
# "high"/"High"/"HIGH" all resolve to the canonical label.
|
||||
user_tags = model_data.get("tags")
|
||||
if user_tags:
|
||||
label_map = {label.lower(): label for label in AUTO_TAG_CATEGORIES}
|
||||
for t in user_tags:
|
||||
canonical = label_map.get(t.lower())
|
||||
if canonical:
|
||||
found.add(canonical)
|
||||
|
||||
return sorted(found)
|
||||
423
py/services/backup_service.py
Normal file
423
py/services/backup_service.py
Normal file
@@ -0,0 +1,423 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
import zipfile
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Optional
|
||||
|
||||
from ..utils.cache_paths import CacheType, get_cache_base_dir, get_cache_file_path
|
||||
from ..utils.settings_paths import get_settings_dir
|
||||
from .settings_manager import get_settings_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
BACKUP_MANIFEST_VERSION = 1
|
||||
DEFAULT_BACKUP_RETENTION_COUNT = 5
|
||||
DEFAULT_BACKUP_INTERVAL_SECONDS = 24 * 60 * 60
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BackupEntry:
|
||||
kind: str
|
||||
archive_path: str
|
||||
target_path: str
|
||||
sha256: str
|
||||
size: int
|
||||
mtime: float
|
||||
|
||||
|
||||
class BackupService:
|
||||
"""Create and restore user-state backup archives."""
|
||||
|
||||
_instance: "BackupService | None" = None
|
||||
_instance_lock = asyncio.Lock()
|
||||
|
||||
def __init__(self, *, settings_manager=None, backup_dir: str | None = None) -> None:
|
||||
self._settings = settings_manager or get_settings_manager()
|
||||
self._backup_dir = Path(backup_dir or self._resolve_backup_dir())
|
||||
self._backup_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._lock = asyncio.Lock()
|
||||
self._auto_task: asyncio.Task[None] | None = None
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> "BackupService":
|
||||
async with cls._instance_lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
cls._instance._ensure_auto_snapshot_task()
|
||||
return cls._instance
|
||||
|
||||
@staticmethod
|
||||
def _resolve_backup_dir() -> str:
|
||||
return os.path.join(get_settings_dir(create=True), "backups")
|
||||
|
||||
def get_backup_dir(self) -> str:
|
||||
return str(self._backup_dir)
|
||||
|
||||
def _ensure_auto_snapshot_task(self) -> None:
|
||||
if self._auto_task is not None and not self._auto_task.done():
|
||||
return
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return
|
||||
|
||||
self._auto_task = loop.create_task(self._auto_backup_loop())
|
||||
|
||||
def _get_setting_bool(self, key: str, default: bool) -> bool:
|
||||
try:
|
||||
return bool(self._settings.get(key, default))
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
def _get_setting_int(self, key: str, default: int) -> int:
|
||||
try:
|
||||
value = self._settings.get(key, default)
|
||||
return max(1, int(value))
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
def _settings_file_path(self) -> str:
|
||||
settings_file = getattr(self._settings, "settings_file", None)
|
||||
if settings_file:
|
||||
return str(settings_file)
|
||||
return os.path.join(get_settings_dir(create=True), "settings.json")
|
||||
|
||||
def _download_history_path(self) -> str:
|
||||
base_dir = get_cache_base_dir(create=True)
|
||||
history_dir = os.path.join(base_dir, "download_history")
|
||||
os.makedirs(history_dir, exist_ok=True)
|
||||
return os.path.join(history_dir, "downloaded_versions.sqlite")
|
||||
|
||||
def _model_update_dir(self) -> str:
|
||||
return str(Path(get_cache_file_path(CacheType.MODEL_UPDATE, create_dir=True)).parent)
|
||||
|
||||
def _model_update_targets(self) -> list[tuple[str, str, str]]:
|
||||
"""Return (kind, archive_path, target_path) tuples for backup."""
|
||||
|
||||
targets: list[tuple[str, str, str]] = []
|
||||
|
||||
settings_path = self._settings_file_path()
|
||||
targets.append(("settings", "settings/settings.json", settings_path))
|
||||
|
||||
history_path = self._download_history_path()
|
||||
targets.append(
|
||||
(
|
||||
"download_history",
|
||||
"cache/download_history/downloaded_versions.sqlite",
|
||||
history_path,
|
||||
)
|
||||
)
|
||||
|
||||
symlink_path = get_cache_file_path(CacheType.SYMLINK, create_dir=True)
|
||||
targets.append(
|
||||
(
|
||||
"symlink_map",
|
||||
"cache/symlink/symlink_map.json",
|
||||
symlink_path,
|
||||
)
|
||||
)
|
||||
|
||||
model_update_dir = Path(self._model_update_dir())
|
||||
if model_update_dir.exists():
|
||||
for sqlite_file in sorted(model_update_dir.glob("*.sqlite")):
|
||||
targets.append(
|
||||
(
|
||||
"model_update",
|
||||
f"cache/model_update/{sqlite_file.name}",
|
||||
str(sqlite_file),
|
||||
)
|
||||
)
|
||||
|
||||
stats_path = os.path.join(get_settings_dir(create=True), "stats", "lora_manager_stats.json")
|
||||
if os.path.exists(stats_path):
|
||||
targets.append(
|
||||
(
|
||||
"usage_stats",
|
||||
"stats/lora_manager_stats.json",
|
||||
stats_path,
|
||||
)
|
||||
)
|
||||
|
||||
return targets
|
||||
|
||||
@staticmethod
|
||||
def _hash_file(path: str) -> tuple[str, int, float]:
|
||||
digest = hashlib.sha256()
|
||||
total = 0
|
||||
with open(path, "rb") as handle:
|
||||
for chunk in iter(lambda: handle.read(1024 * 1024), b""):
|
||||
total += len(chunk)
|
||||
digest.update(chunk)
|
||||
mtime = os.path.getmtime(path)
|
||||
return digest.hexdigest(), total, mtime
|
||||
|
||||
def _build_manifest(self, entries: Iterable[BackupEntry], *, snapshot_type: str) -> dict[str, Any]:
|
||||
created_at = datetime.now(timezone.utc).isoformat()
|
||||
active_library = None
|
||||
try:
|
||||
active_library = self._settings.get_active_library_name()
|
||||
except Exception:
|
||||
active_library = None
|
||||
|
||||
return {
|
||||
"manifest_version": BACKUP_MANIFEST_VERSION,
|
||||
"created_at": created_at,
|
||||
"snapshot_type": snapshot_type,
|
||||
"active_library": active_library,
|
||||
"files": [
|
||||
{
|
||||
"kind": entry.kind,
|
||||
"archive_path": entry.archive_path,
|
||||
"target_path": entry.target_path,
|
||||
"sha256": entry.sha256,
|
||||
"size": entry.size,
|
||||
"mtime": entry.mtime,
|
||||
}
|
||||
for entry in entries
|
||||
],
|
||||
}
|
||||
|
||||
def _write_archive(self, archive_path: str, entries: list[BackupEntry], manifest: dict[str, Any]) -> None:
|
||||
with zipfile.ZipFile(
|
||||
archive_path,
|
||||
mode="w",
|
||||
compression=zipfile.ZIP_DEFLATED,
|
||||
compresslevel=6,
|
||||
) as zf:
|
||||
zf.writestr(
|
||||
"manifest.json",
|
||||
json.dumps(manifest, indent=2, ensure_ascii=False).encode("utf-8"),
|
||||
)
|
||||
for entry in entries:
|
||||
zf.write(entry.target_path, arcname=entry.archive_path)
|
||||
|
||||
async def create_snapshot(self, *, snapshot_type: str = "manual", persist: bool = False) -> dict[str, Any]:
|
||||
"""Create a backup archive.
|
||||
|
||||
If ``persist`` is true, the archive is stored in the backup directory
|
||||
and retained according to the configured retention policy.
|
||||
"""
|
||||
|
||||
async with self._lock:
|
||||
raw_targets = self._model_update_targets()
|
||||
entries: list[BackupEntry] = []
|
||||
for kind, archive_path, target_path in raw_targets:
|
||||
if not os.path.exists(target_path):
|
||||
continue
|
||||
sha256, size, mtime = self._hash_file(target_path)
|
||||
entries.append(
|
||||
BackupEntry(
|
||||
kind=kind,
|
||||
archive_path=archive_path,
|
||||
target_path=target_path,
|
||||
sha256=sha256,
|
||||
size=size,
|
||||
mtime=mtime,
|
||||
)
|
||||
)
|
||||
|
||||
if not entries:
|
||||
raise FileNotFoundError("No backupable files were found")
|
||||
|
||||
manifest = self._build_manifest(entries, snapshot_type=snapshot_type)
|
||||
archive_name = self._build_archive_name(snapshot_type=snapshot_type)
|
||||
fd, temp_path = tempfile.mkstemp(suffix=".zip", dir=str(self._backup_dir))
|
||||
os.close(fd)
|
||||
|
||||
try:
|
||||
self._write_archive(temp_path, entries, manifest)
|
||||
if persist:
|
||||
final_path = self._backup_dir / archive_name
|
||||
os.replace(temp_path, final_path)
|
||||
self._prune_snapshots()
|
||||
return {
|
||||
"archive_path": str(final_path),
|
||||
"archive_name": final_path.name,
|
||||
"manifest": manifest,
|
||||
}
|
||||
|
||||
with open(temp_path, "rb") as handle:
|
||||
data = handle.read()
|
||||
return {
|
||||
"archive_name": archive_name,
|
||||
"archive_bytes": data,
|
||||
"manifest": manifest,
|
||||
}
|
||||
finally:
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
os.remove(temp_path)
|
||||
|
||||
def _build_archive_name(self, *, snapshot_type: str) -> str:
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
return f"lora-manager-backup-{timestamp}-{snapshot_type}.zip"
|
||||
|
||||
def _prune_snapshots(self) -> None:
|
||||
retention = self._get_setting_int(
|
||||
"backup_retention_count", DEFAULT_BACKUP_RETENTION_COUNT
|
||||
)
|
||||
archives = sorted(
|
||||
self._backup_dir.glob("lora-manager-backup-*-auto.zip"),
|
||||
key=lambda path: path.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
for path in archives[retention:]:
|
||||
with contextlib.suppress(OSError):
|
||||
path.unlink()
|
||||
|
||||
async def restore_snapshot(self, archive_path: str) -> dict[str, Any]:
|
||||
"""Restore backup contents from a ZIP archive."""
|
||||
|
||||
async with self._lock:
|
||||
try:
|
||||
zf = zipfile.ZipFile(archive_path, mode="r")
|
||||
except zipfile.BadZipFile as exc:
|
||||
raise ValueError("Backup archive is not a valid ZIP file") from exc
|
||||
|
||||
with zf:
|
||||
try:
|
||||
manifest = json.loads(zf.read("manifest.json").decode("utf-8"))
|
||||
except KeyError as exc:
|
||||
raise ValueError("Backup archive is missing manifest.json") from exc
|
||||
|
||||
if not isinstance(manifest, dict):
|
||||
raise ValueError("Backup manifest is invalid")
|
||||
if manifest.get("manifest_version") != BACKUP_MANIFEST_VERSION:
|
||||
raise ValueError("Backup manifest version is not supported")
|
||||
|
||||
files = manifest.get("files", [])
|
||||
if not isinstance(files, list):
|
||||
raise ValueError("Backup manifest file list is invalid")
|
||||
|
||||
extracted_paths: list[tuple[str, str]] = []
|
||||
temp_dir = Path(tempfile.mkdtemp(prefix="lora-manager-restore-"))
|
||||
try:
|
||||
for item in files:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
archive_member = item.get("archive_path")
|
||||
if not isinstance(archive_member, str) or not archive_member:
|
||||
continue
|
||||
archive_member_path = Path(archive_member)
|
||||
if archive_member_path.is_absolute() or ".." in archive_member_path.parts:
|
||||
raise ValueError(f"Invalid archive member path: {archive_member}")
|
||||
|
||||
kind = item.get("kind")
|
||||
target_path = self._resolve_restore_target(kind, archive_member)
|
||||
if target_path is None:
|
||||
continue
|
||||
|
||||
extracted_path = temp_dir / archive_member_path
|
||||
extracted_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with zf.open(archive_member) as source, open(
|
||||
extracted_path, "wb"
|
||||
) as destination:
|
||||
shutil.copyfileobj(source, destination)
|
||||
|
||||
expected_hash = item.get("sha256")
|
||||
if isinstance(expected_hash, str) and expected_hash:
|
||||
actual_hash, _, _ = self._hash_file(str(extracted_path))
|
||||
if actual_hash != expected_hash:
|
||||
raise ValueError(
|
||||
f"Checksum mismatch for {archive_member}"
|
||||
)
|
||||
|
||||
extracted_paths.append((str(extracted_path), target_path))
|
||||
|
||||
for extracted_path, target_path in extracted_paths:
|
||||
os.makedirs(os.path.dirname(target_path), exist_ok=True)
|
||||
os.replace(extracted_path, target_path)
|
||||
finally:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"restored_files": len(extracted_paths),
|
||||
"snapshot_type": manifest.get("snapshot_type"),
|
||||
}
|
||||
|
||||
def _resolve_restore_target(self, kind: Any, archive_member: str) -> str | None:
|
||||
if kind == "settings":
|
||||
return self._settings_file_path()
|
||||
if kind == "download_history":
|
||||
return self._download_history_path()
|
||||
if kind == "symlink_map":
|
||||
return get_cache_file_path(CacheType.SYMLINK, create_dir=True)
|
||||
if kind == "model_update":
|
||||
filename = os.path.basename(archive_member)
|
||||
return str(Path(get_cache_file_path(CacheType.MODEL_UPDATE, create_dir=True)).parent / filename)
|
||||
if kind == "usage_stats":
|
||||
return os.path.join(get_settings_dir(create=True), "stats", "lora_manager_stats.json")
|
||||
return None
|
||||
|
||||
async def create_auto_snapshot_if_due(self) -> Optional[dict[str, Any]]:
|
||||
if not self._get_setting_bool("backup_auto_enabled", True):
|
||||
return None
|
||||
|
||||
latest = self.get_latest_auto_snapshot()
|
||||
now = time.time()
|
||||
if latest and now - latest["mtime"] < DEFAULT_BACKUP_INTERVAL_SECONDS:
|
||||
return None
|
||||
|
||||
return await self.create_snapshot(snapshot_type="auto", persist=True)
|
||||
|
||||
async def _auto_backup_loop(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
await self.create_auto_snapshot_if_due()
|
||||
await asyncio.sleep(DEFAULT_BACKUP_INTERVAL_SECONDS)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.warning("Automatic backup snapshot failed: %s", exc, exc_info=True)
|
||||
await asyncio.sleep(60)
|
||||
|
||||
def get_available_snapshots(self) -> list[dict[str, Any]]:
|
||||
snapshots: list[dict[str, Any]] = []
|
||||
for path in sorted(self._backup_dir.glob("lora-manager-backup-*.zip")):
|
||||
try:
|
||||
stat = path.stat()
|
||||
except OSError:
|
||||
continue
|
||||
snapshots.append(
|
||||
{
|
||||
"name": path.name,
|
||||
"path": str(path),
|
||||
"size": stat.st_size,
|
||||
"mtime": stat.st_mtime,
|
||||
"is_auto": path.name.endswith("-auto.zip"),
|
||||
}
|
||||
)
|
||||
snapshots.sort(key=lambda item: item["mtime"], reverse=True)
|
||||
return snapshots
|
||||
|
||||
def get_latest_auto_snapshot(self) -> Optional[dict[str, Any]]:
|
||||
autos = [snapshot for snapshot in self.get_available_snapshots() if snapshot["is_auto"]]
|
||||
if not autos:
|
||||
return None
|
||||
return autos[0]
|
||||
|
||||
def get_status(self) -> dict[str, Any]:
|
||||
snapshots = self.get_available_snapshots()
|
||||
return {
|
||||
"backupDir": self.get_backup_dir(),
|
||||
"enabled": self._get_setting_bool("backup_auto_enabled", True),
|
||||
"retentionCount": self._get_setting_int(
|
||||
"backup_retention_count", DEFAULT_BACKUP_RETENTION_COUNT
|
||||
),
|
||||
"snapshotCount": len(snapshots),
|
||||
"latestSnapshot": snapshots[0] if snapshots else None,
|
||||
"latestAutoSnapshot": self.get_latest_auto_snapshot(),
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
|
||||
import logging
|
||||
import os
|
||||
@@ -19,6 +20,7 @@ from .model_query import (
|
||||
resolve_sub_type,
|
||||
)
|
||||
from .settings_manager import get_settings_manager
|
||||
from ..utils.civitai_utils import build_civitai_model_page_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -75,6 +77,7 @@ class BaseModelService(ABC):
|
||||
base_models: list = None,
|
||||
model_types: list = None,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
auto_tags: Optional[Dict[str, str]] = None,
|
||||
search_options: dict = None,
|
||||
hash_filters: dict = None,
|
||||
favorites_only: bool = False,
|
||||
@@ -93,9 +96,108 @@ class BaseModelService(ABC):
|
||||
sorted_data = await self._fetch_with_usage_sort(sort_params)
|
||||
else:
|
||||
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
|
||||
# Pre-compute auto_tags for every item — needed for both filtering
|
||||
# and display. Computation is cheap (string regex on 2-3 fields).
|
||||
from .auto_tag_service import extract_auto_tags
|
||||
for item in sorted_data:
|
||||
item["auto_tags"] = extract_auto_tags(item)
|
||||
fetch_duration = time.perf_counter() - t0
|
||||
initial_count = len(sorted_data)
|
||||
|
||||
# Optionally filter by civitai model ID (shows all local versions of a specific model)
|
||||
civitai_model_id = kwargs.get("civitai_model_id")
|
||||
if civitai_model_id is not None:
|
||||
sorted_data = [
|
||||
item for item in sorted_data
|
||||
if self._extract_model_id(item) == civitai_model_id
|
||||
]
|
||||
# VLM mode: always sort by version ID descending (newest version first),
|
||||
# regardless of the current sort_by preference.
|
||||
sorted_data.sort(
|
||||
key=lambda x: self._extract_version_id(x) or 0,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
# Optionally group by civitai modelId, showing only the latest version per model
|
||||
dedup_lost = 0
|
||||
if kwargs.get("group_by_model") and civitai_model_id is None:
|
||||
# Determine whether to further sub-group by base model
|
||||
# When version_grouping is "same_base", versions with different
|
||||
# base models are effectively different groups — the dedup key
|
||||
# needs to include base_model so the version count and VLM flow
|
||||
# stay consistent (card shows correct count for its base model).
|
||||
ufs = self.settings.get("version_grouping", "same_base")
|
||||
group_by_base = ufs == "same_base"
|
||||
|
||||
dedup_map = {} # (modelId [,base_model]) -> (item, version_id)
|
||||
version_counter = {} # same-key -> count
|
||||
standalone = []
|
||||
for item in sorted_data:
|
||||
mid = self._extract_model_id(item)
|
||||
if mid is None:
|
||||
standalone.append(item)
|
||||
continue
|
||||
key = (mid, item.get("base_model") or "") if group_by_base else mid
|
||||
# Count all versions per key
|
||||
version_counter[key] = version_counter.get(key, 0) + 1
|
||||
vid = self._extract_version_id(item) or 0
|
||||
if key not in dedup_map or vid > dedup_map[key][1]:
|
||||
dedup_map[key] = (item, vid)
|
||||
# Attach version_count to each surviving grouped item (shallow copy
|
||||
# to avoid mutating cached dicts — the cache is shared across requests)
|
||||
for key, (item, vid) in dedup_map.items():
|
||||
item = dict(item)
|
||||
item["version_count"] = version_counter[key]
|
||||
dedup_map[key] = (item, vid)
|
||||
dedup_lost = len(sorted_data) - (len(dedup_map) + len(standalone))
|
||||
sorted_data = [entry[0] for entry in dedup_map.values()] + standalone
|
||||
|
||||
# Re-sort by version_count (grouped: after dedup; non-grouped: group internally, sort, expand)
|
||||
if sort_params.key == "versions_count" and civitai_model_id is None:
|
||||
reverse = sort_params.order == "desc"
|
||||
if kwargs.get("group_by_model"):
|
||||
# Grouped mode: items are already dedup'd with version_count attached
|
||||
sorted_data.sort(
|
||||
key=lambda x: (
|
||||
x.get("version_count", 0),
|
||||
(x.get("model_name") or x.get("file_name") or "").lower(),
|
||||
x.get("file_path", "").lower(),
|
||||
),
|
||||
reverse=reverse,
|
||||
)
|
||||
else:
|
||||
# Non-grouped mode: group internally, sort groups by count, expand
|
||||
# Respect the version_grouping setting (same logic as grouped dedup)
|
||||
ufs = self.settings.get("version_grouping", "same_base")
|
||||
group_by_base = ufs == "same_base"
|
||||
|
||||
model_groups: Dict[Any, List[Dict]] = {}
|
||||
ungrouped_standalone: List[Dict] = []
|
||||
for item in sorted_data:
|
||||
mid = self._extract_model_id(item)
|
||||
if mid is None:
|
||||
ungrouped_standalone.append(item)
|
||||
continue
|
||||
key = (mid, item.get("base_model") or "") if group_by_base else mid
|
||||
model_groups.setdefault(key, []).append(item)
|
||||
# Sort versions within each group by version id descending
|
||||
for items in model_groups.values():
|
||||
items.sort(
|
||||
key=lambda x: self._extract_version_id(x) or 0,
|
||||
reverse=True,
|
||||
)
|
||||
# Sort groups by version count
|
||||
sorted_groups = sorted(
|
||||
model_groups.values(),
|
||||
key=lambda items: len(items),
|
||||
reverse=reverse,
|
||||
)
|
||||
# Flatten: grouped items first, standalone items last
|
||||
sorted_data = []
|
||||
for items in sorted_groups:
|
||||
sorted_data.extend(items)
|
||||
sorted_data.extend(ungrouped_standalone)
|
||||
|
||||
t1 = time.perf_counter()
|
||||
if hash_filters:
|
||||
filtered_data = await self._apply_hash_filters(sorted_data, hash_filters)
|
||||
@@ -108,6 +210,7 @@ class BaseModelService(ABC):
|
||||
base_models=base_models,
|
||||
model_types=model_types,
|
||||
tags=tags,
|
||||
auto_tags=auto_tags,
|
||||
favorites_only=favorites_only,
|
||||
search_options=search_options,
|
||||
tag_logic=tag_logic,
|
||||
@@ -163,7 +266,7 @@ class BaseModelService(ABC):
|
||||
overall_duration = time.perf_counter() - overall_start
|
||||
logger.debug(
|
||||
"%s.get_paginated_data took %.3fs (fetch: %.3fs, filter: %.3fs, update_filter: %.3fs, pagination: %.3fs, annotate: %.3fs). "
|
||||
"Counts: initial=%d, post_filter=%d, final=%d",
|
||||
"Counts: initial=%d, dedup=%d, post_filter=%d, final=%d",
|
||||
self.__class__.__name__,
|
||||
overall_duration,
|
||||
fetch_duration,
|
||||
@@ -172,11 +275,63 @@ class BaseModelService(ABC):
|
||||
pagination_duration,
|
||||
annotate_duration,
|
||||
initial_count,
|
||||
dedup_lost,
|
||||
post_filter_count,
|
||||
final_count,
|
||||
)
|
||||
return paginated
|
||||
|
||||
async def get_excluded_paginated_data(
|
||||
self,
|
||||
page: int,
|
||||
page_size: int,
|
||||
sort_by: str = "name",
|
||||
search: str = None,
|
||||
fuzzy_search: bool = False,
|
||||
search_options: dict = None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
"""Get paginated excluded model data."""
|
||||
excluded_paths = list(self.scanner.get_excluded_models())
|
||||
excluded_entries: List[Dict[str, Any]] = []
|
||||
stale_paths: List[str] = []
|
||||
|
||||
for file_path in excluded_paths:
|
||||
if not file_path or not os.path.exists(file_path):
|
||||
stale_paths.append(file_path)
|
||||
continue
|
||||
|
||||
entry = await self._build_excluded_entry(file_path)
|
||||
if entry:
|
||||
excluded_entries.append(entry)
|
||||
else:
|
||||
stale_paths.append(file_path)
|
||||
|
||||
if stale_paths:
|
||||
current_excluded = getattr(self.scanner, "_excluded_models", None)
|
||||
if isinstance(current_excluded, list):
|
||||
stale_set = set(stale_paths)
|
||||
self.scanner._excluded_models = [
|
||||
path for path in current_excluded if path not in stale_set
|
||||
]
|
||||
persist_current_cache = getattr(self.scanner, "_persist_current_cache", None)
|
||||
if callable(persist_current_cache):
|
||||
await persist_current_cache()
|
||||
|
||||
excluded_entries = self._sort_entries(excluded_entries, sort_by)
|
||||
|
||||
if search:
|
||||
excluded_entries = await self._apply_search_filters(
|
||||
excluded_entries,
|
||||
search,
|
||||
fuzzy_search,
|
||||
search_options,
|
||||
)
|
||||
|
||||
paginated = self._paginate(excluded_entries, page, page_size)
|
||||
paginated["items"] = await self._annotate_update_flags(paginated["items"])
|
||||
return paginated
|
||||
|
||||
async def _fetch_with_usage_sort(self, sort_params):
|
||||
"""Fetch data sorted by usage count (desc/asc)."""
|
||||
cache = await self.cache_repository.get_cache()
|
||||
@@ -207,11 +362,71 @@ class BaseModelService(ABC):
|
||||
|
||||
reverse = sort_params.order == "desc"
|
||||
annotated.sort(
|
||||
key=lambda x: (x.get("usage_count", 0), x.get("model_name", "").lower()),
|
||||
key=lambda x: (
|
||||
x.get("usage_count", 0),
|
||||
x.get("model_name", "").lower(),
|
||||
x.get("file_path", "").lower()
|
||||
),
|
||||
reverse=reverse,
|
||||
)
|
||||
return annotated
|
||||
|
||||
def _sort_entries(self, data: List[Dict[str, Any]], sort_by: str) -> List[Dict[str, Any]]:
|
||||
sort_params = self.cache_repository.parse_sort(sort_by)
|
||||
key_name = sort_params.key
|
||||
|
||||
if key_name == "date":
|
||||
key_fn = lambda item: (
|
||||
float(item.get("modified", 0.0) or 0.0),
|
||||
(item.get("model_name") or item.get("file_name") or "").lower(),
|
||||
item.get("file_path", "").lower(),
|
||||
)
|
||||
elif key_name == "size":
|
||||
key_fn = lambda item: (
|
||||
int(item.get("size", 0) or 0),
|
||||
(item.get("model_name") or item.get("file_name") or "").lower(),
|
||||
item.get("file_path", "").lower(),
|
||||
)
|
||||
elif key_name == "usage":
|
||||
key_fn = lambda item: (
|
||||
int(item.get("usage_count", 0) or 0),
|
||||
(item.get("model_name") or item.get("file_name") or "").lower(),
|
||||
item.get("file_path", "").lower(),
|
||||
)
|
||||
else:
|
||||
key_fn = lambda item: (
|
||||
(item.get("model_name") or item.get("file_name") or "").lower(),
|
||||
item.get("file_path", "").lower(),
|
||||
)
|
||||
|
||||
return sorted(data, key=key_fn, reverse=sort_params.order == "desc")
|
||||
|
||||
async def _build_excluded_entry(self, file_path: str) -> Optional[Dict[str, Any]]:
|
||||
root_path = self.scanner._find_root_for_file(file_path)
|
||||
if not root_path:
|
||||
return None
|
||||
|
||||
metadata, should_skip = await MetadataManager.load_metadata(
|
||||
file_path,
|
||||
self.metadata_class,
|
||||
)
|
||||
if should_skip:
|
||||
return None
|
||||
|
||||
if metadata is None:
|
||||
metadata = await self.scanner._create_default_metadata(file_path)
|
||||
if metadata is None:
|
||||
return None
|
||||
|
||||
metadata = self.scanner.adjust_metadata(metadata, file_path, root_path)
|
||||
folder = os.path.dirname(os.path.relpath(file_path, root_path)).replace(
|
||||
os.path.sep, "/"
|
||||
)
|
||||
entry = self.scanner._build_cache_entry(metadata, folder=folder)
|
||||
entry = self.scanner.adjust_cached_entry(entry)
|
||||
entry["exclude"] = True
|
||||
return entry
|
||||
|
||||
async def _apply_hash_filters(
|
||||
self, data: List[Dict], hash_filters: Dict
|
||||
) -> List[Dict]:
|
||||
@@ -241,6 +456,7 @@ class BaseModelService(ABC):
|
||||
base_models: list = None,
|
||||
model_types: list = None,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
auto_tags: Optional[Dict[str, str]] = None,
|
||||
favorites_only: bool = False,
|
||||
search_options: dict = None,
|
||||
tag_logic: str = "any",
|
||||
@@ -254,6 +470,7 @@ class BaseModelService(ABC):
|
||||
base_models=base_models,
|
||||
model_types=model_types,
|
||||
tags=tags,
|
||||
auto_tags=auto_tags,
|
||||
favorites_only=favorites_only,
|
||||
search_options=normalized_options,
|
||||
tag_logic=tag_logic,
|
||||
@@ -373,7 +590,7 @@ class BaseModelService(ABC):
|
||||
if not ordered_ids:
|
||||
return annotated
|
||||
|
||||
strategy_value = self.settings.get("update_flag_strategy")
|
||||
strategy_value = self.settings.get("version_grouping")
|
||||
if isinstance(strategy_value, str) and strategy_value.strip():
|
||||
strategy = strategy_value.strip().lower()
|
||||
else:
|
||||
@@ -383,7 +600,9 @@ class BaseModelService(ABC):
|
||||
# Check user setting for hiding early access updates
|
||||
hide_early_access = False
|
||||
try:
|
||||
hide_early_access = bool(self.settings.get("hide_early_access_updates", False))
|
||||
hide_early_access = bool(
|
||||
self.settings.get("hide_early_access_updates", False)
|
||||
)
|
||||
except Exception:
|
||||
hide_early_access = False
|
||||
|
||||
@@ -413,7 +632,11 @@ class BaseModelService(ABC):
|
||||
bulk_method = getattr(self.update_service, "has_updates_bulk", None)
|
||||
if callable(bulk_method):
|
||||
try:
|
||||
resolved = await bulk_method(self.model_type, ordered_ids, hide_early_access=hide_early_access)
|
||||
resolved = await bulk_method(
|
||||
self.model_type,
|
||||
ordered_ids,
|
||||
hide_early_access=hide_early_access,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to resolve update status in bulk for %s models (%s): %s",
|
||||
@@ -426,7 +649,9 @@ class BaseModelService(ABC):
|
||||
|
||||
if resolved is None:
|
||||
tasks = [
|
||||
self.update_service.has_update(self.model_type, model_id, hide_early_access=hide_early_access)
|
||||
self.update_service.has_update(
|
||||
self.model_type, model_id, hide_early_access=hide_early_access
|
||||
)
|
||||
for model_id in ordered_ids
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
@@ -566,8 +791,12 @@ class BaseModelService(ABC):
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def format_response(self, model_data: Dict) -> Dict:
|
||||
"""Format model data for API response - must be implemented by subclasses"""
|
||||
async def format_response(self, model_data: Dict) -> Optional[Dict]:
|
||||
"""Format model data for API response - must be implemented by subclasses.
|
||||
|
||||
Subclasses should return None for corrupted entries so the handler
|
||||
layer can filter them out. See issue #730.
|
||||
"""
|
||||
pass
|
||||
|
||||
# Common service methods that delegate to scanner
|
||||
@@ -588,13 +817,19 @@ class BaseModelService(ABC):
|
||||
normalized_type = normalize_sub_type(resolve_sub_type(entry))
|
||||
if not normalized_type:
|
||||
continue
|
||||
|
||||
|
||||
# Filter by valid sub-types based on scanner type
|
||||
if self.model_type == "lora" and normalized_type not in VALID_LORA_SUB_TYPES:
|
||||
if (
|
||||
self.model_type == "lora"
|
||||
and normalized_type not in VALID_LORA_SUB_TYPES
|
||||
):
|
||||
continue
|
||||
if self.model_type == "checkpoint" and normalized_type not in VALID_CHECKPOINT_SUB_TYPES:
|
||||
if (
|
||||
self.model_type == "checkpoint"
|
||||
and normalized_type not in VALID_CHECKPOINT_SUB_TYPES
|
||||
):
|
||||
continue
|
||||
|
||||
|
||||
type_counts[normalized_type] = type_counts.get(normalized_type, 0) + 1
|
||||
|
||||
sorted_types = sorted(
|
||||
@@ -734,30 +969,86 @@ class BaseModelService(ABC):
|
||||
"""Get the static preview URL for a model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
name_normalized = model_name.replace("\\", "/")
|
||||
name_no_ext = name_normalized
|
||||
for ext in (".safetensors", ".ckpt", ".pt", ".bin"):
|
||||
if name_no_ext.lower().endswith(ext):
|
||||
name_no_ext = name_no_ext[: -len(ext)]
|
||||
break
|
||||
|
||||
has_path = "/" in name_no_ext
|
||||
basename = os.path.basename(name_no_ext) if has_path else name_no_ext
|
||||
best_fallback = None
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model["file_name"] == model_name:
|
||||
file_name = model.get("file_name", "")
|
||||
folder = model.get("folder", "")
|
||||
file_name_no_ext = file_name
|
||||
for ext in (".safetensors", ".ckpt", ".pt", ".bin"):
|
||||
if file_name_no_ext.lower().endswith(ext):
|
||||
file_name_no_ext = file_name_no_ext[: -len(ext)]
|
||||
break
|
||||
path_name = f"{folder}/{file_name_no_ext}".replace("\\", "/") if folder else file_name_no_ext
|
||||
|
||||
if name_no_ext == file_name_no_ext or name_no_ext == path_name:
|
||||
preview_url = model.get("preview_url")
|
||||
if preview_url:
|
||||
from ..config import config
|
||||
|
||||
return config.get_preview_static_url(preview_url)
|
||||
|
||||
if has_path and file_name_no_ext == basename:
|
||||
if folder and name_no_ext.startswith(folder.replace("\\", "/") + "/"):
|
||||
best_fallback = model
|
||||
elif best_fallback is None:
|
||||
best_fallback = model
|
||||
|
||||
if best_fallback:
|
||||
preview_url = best_fallback.get("preview_url")
|
||||
if preview_url:
|
||||
from ..config import config
|
||||
|
||||
return config.get_preview_static_url(preview_url)
|
||||
|
||||
return "/loras_static/images/no-preview.png"
|
||||
|
||||
async def get_model_civitai_url(self, model_name: str) -> Dict[str, Optional[str]]:
|
||||
"""Get the Civitai URL for a model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
name_normalized = model_name.replace("\\", "/")
|
||||
name_no_ext = name_normalized
|
||||
for ext in (".safetensors", ".ckpt", ".pt", ".bin"):
|
||||
if name_no_ext.lower().endswith(ext):
|
||||
name_no_ext = name_no_ext[: -len(ext)]
|
||||
break
|
||||
|
||||
has_path = "/" in name_no_ext
|
||||
basename = os.path.basename(name_no_ext) if has_path else name_no_ext
|
||||
best_fallback = None
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model["file_name"] == model_name:
|
||||
file_name = model.get("file_name", "")
|
||||
folder = model.get("folder", "")
|
||||
file_name_no_ext = file_name
|
||||
for ext in (".safetensors", ".ckpt", ".pt", ".bin"):
|
||||
if file_name_no_ext.lower().endswith(ext):
|
||||
file_name_no_ext = file_name_no_ext[: -len(ext)]
|
||||
break
|
||||
path_name = f"{folder}/{file_name_no_ext}".replace("\\", "/") if folder else file_name_no_ext
|
||||
|
||||
if name_no_ext == file_name_no_ext or name_no_ext == path_name:
|
||||
civitai_data = model.get("civitai", {})
|
||||
model_id = civitai_data.get("modelId")
|
||||
version_id = civitai_data.get("id")
|
||||
|
||||
if model_id:
|
||||
civitai_url = f"https://civitai.com/models/{model_id}"
|
||||
if version_id:
|
||||
civitai_url += f"?modelVersionId={version_id}"
|
||||
civitai_host = self.settings.get("civitai_host", "civitai.com")
|
||||
civitai_url = build_civitai_model_page_url(
|
||||
model_id,
|
||||
version_id,
|
||||
host=civitai_host,
|
||||
)
|
||||
|
||||
return {
|
||||
"civitai_url": civitai_url,
|
||||
@@ -765,6 +1056,27 @@ class BaseModelService(ABC):
|
||||
"version_id": str(version_id) if version_id else None,
|
||||
}
|
||||
|
||||
if has_path and file_name_no_ext == basename:
|
||||
if folder and name_no_ext.startswith(folder.replace("\\", "/") + "/"):
|
||||
best_fallback = model
|
||||
elif best_fallback is None:
|
||||
best_fallback = model
|
||||
|
||||
if best_fallback:
|
||||
civitai_data = best_fallback.get("civitai", {})
|
||||
model_id = civitai_data.get("modelId")
|
||||
if model_id:
|
||||
version_id = civitai_data.get("id")
|
||||
civitai_host = self.settings.get("civitai_host", "civitai.com")
|
||||
civitai_url = build_civitai_model_page_url(
|
||||
model_id, version_id, host=civitai_host
|
||||
)
|
||||
return {
|
||||
"civitai_url": civitai_url,
|
||||
"model_id": str(model_id),
|
||||
"version_id": str(version_id) if version_id else None,
|
||||
}
|
||||
|
||||
return {"civitai_url": None, "model_id": None, "version_id": None}
|
||||
|
||||
async def get_model_metadata(self, file_path: str) -> Optional[Dict]:
|
||||
@@ -778,6 +1090,17 @@ class BaseModelService(ABC):
|
||||
)
|
||||
if should_skip or metadata is None:
|
||||
return None
|
||||
|
||||
# Prune stale example-image metadata entries whose files no longer
|
||||
# exist on disk (e.g. a user deleted the files manually).
|
||||
from ..utils.example_images_metadata import MetadataUpdater
|
||||
|
||||
was_modified = await MetadataUpdater.prune_stale_example_images(metadata)
|
||||
if was_modified:
|
||||
asyncio.create_task(
|
||||
MetadataManager.save_metadata(file_path, metadata)
|
||||
)
|
||||
|
||||
return self.filter_civitai_data(metadata.to_dict().get("civitai", {}))
|
||||
|
||||
async def get_model_description(self, file_path: str) -> Optional[str]:
|
||||
@@ -807,38 +1130,61 @@ class BaseModelService(ABC):
|
||||
|
||||
return include_terms, exclude_terms
|
||||
|
||||
@staticmethod
|
||||
def _remove_model_extension(path: str) -> str:
|
||||
"""Remove model file extension (.safetensors, .ckpt, .pt, .bin) for cleaner matching."""
|
||||
return re.sub(r"\.(safetensors|ckpt|pt|bin)$", "", path, flags=re.IGNORECASE)
|
||||
|
||||
@staticmethod
|
||||
def _relative_path_matches_tokens(
|
||||
path_lower: str, include_terms: List[str], exclude_terms: List[str]
|
||||
) -> bool:
|
||||
"""Determine whether a relative path string satisfies include/exclude tokens."""
|
||||
if any(term and term in path_lower for term in exclude_terms):
|
||||
"""Determine whether a relative path string satisfies include/exclude tokens.
|
||||
|
||||
Matches against the path without extension to avoid matching .safetensors
|
||||
when searching for 's'.
|
||||
"""
|
||||
# Use path without extension for matching
|
||||
path_for_matching = BaseModelService._remove_model_extension(path_lower)
|
||||
|
||||
if any(term and term in path_for_matching for term in exclude_terms):
|
||||
return False
|
||||
|
||||
for term in include_terms:
|
||||
if term and term not in path_lower:
|
||||
if term and term not in path_for_matching:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _relative_path_sort_key(relative_path: str, include_terms: List[str]) -> tuple:
|
||||
"""Sort paths by how well they satisfy the include tokens."""
|
||||
path_lower = relative_path.lower()
|
||||
"""Sort paths by how well they satisfy the include tokens.
|
||||
|
||||
Sorts based on path without extension for consistent ordering.
|
||||
"""
|
||||
# Use path without extension for sorting
|
||||
path_for_sorting = BaseModelService._remove_model_extension(
|
||||
relative_path.lower()
|
||||
)
|
||||
prefix_hits = sum(
|
||||
1 for term in include_terms if term and path_lower.startswith(term)
|
||||
1 for term in include_terms if term and path_for_sorting.startswith(term)
|
||||
)
|
||||
match_positions = [
|
||||
path_lower.find(term)
|
||||
path_for_sorting.find(term)
|
||||
for term in include_terms
|
||||
if term and term in path_lower
|
||||
if term and term in path_for_sorting
|
||||
]
|
||||
first_match_index = min(match_positions) if match_positions else 0
|
||||
|
||||
return (-prefix_hits, first_match_index, len(relative_path), path_lower)
|
||||
return (
|
||||
-prefix_hits,
|
||||
first_match_index,
|
||||
len(path_for_sorting),
|
||||
path_for_sorting,
|
||||
)
|
||||
|
||||
async def search_relative_paths(
|
||||
self, search_term: str, limit: int = 15
|
||||
self, search_term: str, limit: int = 15, offset: int = 0
|
||||
) -> List[str]:
|
||||
"""Search model relative file paths for autocomplete functionality"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
@@ -849,6 +1195,7 @@ class BaseModelService(ABC):
|
||||
# Get model roots for path calculation
|
||||
model_roots = self.scanner.get_model_roots()
|
||||
|
||||
# Collect all matching paths first (needed for proper sorting and offset)
|
||||
for model in cache.raw_data:
|
||||
file_path = model.get("file_path", "")
|
||||
if not file_path:
|
||||
@@ -877,12 +1224,12 @@ class BaseModelService(ABC):
|
||||
):
|
||||
matching_paths.append(relative_path)
|
||||
|
||||
if len(matching_paths) >= limit * 2: # Get more for better sorting
|
||||
break
|
||||
|
||||
# Sort by relevance (prefix and earliest hits first, then by length and alphabetically)
|
||||
matching_paths.sort(
|
||||
key=lambda relative: self._relative_path_sort_key(relative, include_terms)
|
||||
)
|
||||
|
||||
return matching_paths[:limit]
|
||||
# Apply offset and limit
|
||||
start = min(offset, len(matching_paths))
|
||||
end = min(start + limit, len(matching_paths))
|
||||
return matching_paths[start:end]
|
||||
|
||||
597
py/services/batch_import_service.py
Normal file
597
py/services/batch_import_service.py
Normal file
@@ -0,0 +1,597 @@
|
||||
"""Batch import service for importing multiple images as recipes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from .recipes import (
|
||||
RecipeAnalysisService,
|
||||
RecipePersistenceService,
|
||||
RecipeValidationError,
|
||||
RecipeDownloadError,
|
||||
RecipeNotFoundError,
|
||||
)
|
||||
|
||||
|
||||
class ImportItemType(Enum):
|
||||
"""Type of import item."""
|
||||
|
||||
URL = "url"
|
||||
LOCAL_PATH = "local_path"
|
||||
|
||||
|
||||
class ImportStatus(Enum):
|
||||
"""Status of an individual import item."""
|
||||
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
SKIPPED = "skipped"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchImportItem:
|
||||
"""Represents a single item to import."""
|
||||
|
||||
id: str
|
||||
source: str
|
||||
item_type: ImportItemType
|
||||
status: ImportStatus = ImportStatus.PENDING
|
||||
error_message: Optional[str] = None
|
||||
recipe_name: Optional[str] = None
|
||||
recipe_id: Optional[str] = None
|
||||
duration: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchImportProgress:
|
||||
"""Tracks progress of a batch import operation."""
|
||||
|
||||
operation_id: str
|
||||
total: int
|
||||
completed: int = 0
|
||||
success: int = 0
|
||||
failed: int = 0
|
||||
skipped: int = 0
|
||||
current_item: str = ""
|
||||
status: str = "pending"
|
||||
started_at: float = field(default_factory=time.time)
|
||||
finished_at: Optional[float] = None
|
||||
items: List[BatchImportItem] = field(default_factory=list)
|
||||
tags: List[str] = field(default_factory=list)
|
||||
skip_no_metadata: bool = False
|
||||
skip_duplicates: bool = False
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"operation_id": self.operation_id,
|
||||
"total": self.total,
|
||||
"completed": self.completed,
|
||||
"success": self.success,
|
||||
"failed": self.failed,
|
||||
"skipped": self.skipped,
|
||||
"current_item": self.current_item,
|
||||
"status": self.status,
|
||||
"started_at": self.started_at,
|
||||
"finished_at": self.finished_at,
|
||||
"progress_percent": round((self.completed / self.total) * 100, 1)
|
||||
if self.total > 0
|
||||
else 0,
|
||||
"items": [
|
||||
{
|
||||
"id": item.id,
|
||||
"source": item.source,
|
||||
"item_type": item.item_type.value,
|
||||
"status": item.status.value,
|
||||
"error_message": item.error_message,
|
||||
"recipe_name": item.recipe_name,
|
||||
"recipe_id": item.recipe_id,
|
||||
"duration": item.duration,
|
||||
}
|
||||
for item in self.items
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class AdaptiveConcurrencyController:
|
||||
"""Adjusts concurrency based on task performance."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
min_concurrency: int = 1,
|
||||
max_concurrency: int = 5,
|
||||
initial_concurrency: int = 3,
|
||||
) -> None:
|
||||
self.min_concurrency = min_concurrency
|
||||
self.max_concurrency = max_concurrency
|
||||
self.current_concurrency = initial_concurrency
|
||||
self._task_durations: List[float] = []
|
||||
self._recent_errors = 0
|
||||
self._recent_successes = 0
|
||||
|
||||
def record_result(self, duration: float, success: bool) -> None:
|
||||
self._task_durations.append(duration)
|
||||
if len(self._task_durations) > 10:
|
||||
self._task_durations.pop(0)
|
||||
|
||||
if success:
|
||||
self._recent_successes += 1
|
||||
if duration < 1.0 and self.current_concurrency < self.max_concurrency:
|
||||
self.current_concurrency = min(
|
||||
self.current_concurrency + 1, self.max_concurrency
|
||||
)
|
||||
elif duration > 10.0 and self.current_concurrency > self.min_concurrency:
|
||||
self.current_concurrency = max(
|
||||
self.current_concurrency - 1, self.min_concurrency
|
||||
)
|
||||
else:
|
||||
self._recent_errors += 1
|
||||
if self.current_concurrency > self.min_concurrency:
|
||||
self.current_concurrency = max(
|
||||
self.current_concurrency - 1, self.min_concurrency
|
||||
)
|
||||
|
||||
def reset_counters(self) -> None:
|
||||
self._recent_errors = 0
|
||||
self._recent_successes = 0
|
||||
|
||||
def get_semaphore(self) -> asyncio.Semaphore:
|
||||
return asyncio.Semaphore(self.current_concurrency)
|
||||
|
||||
|
||||
class BatchImportService:
|
||||
"""Service for batch importing images as recipes."""
|
||||
|
||||
SUPPORTED_EXTENSIONS: Set[str] = {".jpg", ".jpeg", ".png", ".webp", ".gif", ".bmp"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
analysis_service: RecipeAnalysisService,
|
||||
persistence_service: RecipePersistenceService,
|
||||
ws_manager: Any,
|
||||
logger: logging.Logger,
|
||||
) -> None:
|
||||
self._analysis_service = analysis_service
|
||||
self._persistence_service = persistence_service
|
||||
self._ws_manager = ws_manager
|
||||
self._logger = logger
|
||||
self._active_operations: Dict[str, BatchImportProgress] = {}
|
||||
self._cancellation_flags: Dict[str, bool] = {}
|
||||
self._concurrency_controller = AdaptiveConcurrencyController()
|
||||
|
||||
def is_import_running(self, operation_id: Optional[str] = None) -> bool:
|
||||
if operation_id:
|
||||
progress = self._active_operations.get(operation_id)
|
||||
return progress is not None and progress.status in ("pending", "running")
|
||||
return any(
|
||||
p.status in ("pending", "running") for p in self._active_operations.values()
|
||||
)
|
||||
|
||||
def get_progress(self, operation_id: str) -> Optional[BatchImportProgress]:
|
||||
return self._active_operations.get(operation_id)
|
||||
|
||||
def cancel_import(self, operation_id: str) -> bool:
|
||||
if operation_id in self._active_operations:
|
||||
self._cancellation_flags[operation_id] = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def _validate_url(self, url: str) -> bool:
|
||||
import re
|
||||
|
||||
url_pattern = re.compile(
|
||||
r"^https?://"
|
||||
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|"
|
||||
r"localhost|"
|
||||
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"
|
||||
r"(?::\d+)?"
|
||||
r"(?:/?|[/?]\S+)$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
return url_pattern.match(url) is not None
|
||||
|
||||
def _validate_local_path(self, path: str) -> bool:
|
||||
try:
|
||||
normalized = os.path.normpath(path)
|
||||
if not os.path.isabs(normalized):
|
||||
return False
|
||||
if ".." in normalized:
|
||||
return False
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _is_duplicate_source(
|
||||
self,
|
||||
source: str,
|
||||
item_type: ImportItemType,
|
||||
recipe_scanner: Any,
|
||||
) -> bool:
|
||||
try:
|
||||
cache = recipe_scanner.get_cached_data_sync()
|
||||
if not cache:
|
||||
return False
|
||||
|
||||
for recipe in getattr(cache, "raw_data", []):
|
||||
source_path = recipe.get("source_path")
|
||||
if source_path and source_path == source:
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
self._logger.warning("Failed to check for duplicates", exc_info=True)
|
||||
return False
|
||||
|
||||
async def start_batch_import(
|
||||
self,
|
||||
*,
|
||||
recipe_scanner_getter: Callable[[], Any],
|
||||
civitai_client_getter: Callable[[], Any],
|
||||
items: List[Dict[str, str]],
|
||||
tags: Optional[List[str]] = None,
|
||||
skip_no_metadata: bool = False,
|
||||
skip_duplicates: bool = False,
|
||||
) -> str:
|
||||
operation_id = str(uuid.uuid4())
|
||||
|
||||
import_items = []
|
||||
for idx, item in enumerate(items):
|
||||
source = item.get("source", "")
|
||||
item_type_str = item.get("type", "url")
|
||||
|
||||
if item_type_str == "url" or source.startswith(("http://", "https://")):
|
||||
item_type = ImportItemType.URL
|
||||
else:
|
||||
item_type = ImportItemType.LOCAL_PATH
|
||||
|
||||
batch_import_item = BatchImportItem(
|
||||
id=f"{operation_id}_{idx}",
|
||||
source=source,
|
||||
item_type=item_type,
|
||||
)
|
||||
import_items.append(batch_import_item)
|
||||
|
||||
progress = BatchImportProgress(
|
||||
operation_id=operation_id,
|
||||
total=len(import_items),
|
||||
items=import_items,
|
||||
tags=tags or [],
|
||||
skip_no_metadata=skip_no_metadata,
|
||||
skip_duplicates=skip_duplicates,
|
||||
)
|
||||
|
||||
self._active_operations[operation_id] = progress
|
||||
self._cancellation_flags[operation_id] = False
|
||||
|
||||
asyncio.create_task(
|
||||
self._run_batch_import(
|
||||
operation_id=operation_id,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
)
|
||||
)
|
||||
|
||||
return operation_id
|
||||
|
||||
async def start_directory_import(
|
||||
self,
|
||||
*,
|
||||
recipe_scanner_getter: Callable[[], Any],
|
||||
civitai_client_getter: Callable[[], Any],
|
||||
directory: str,
|
||||
recursive: bool = True,
|
||||
tags: Optional[List[str]] = None,
|
||||
skip_no_metadata: bool = False,
|
||||
skip_duplicates: bool = False,
|
||||
) -> str:
|
||||
image_paths = await self._discover_images(directory, recursive)
|
||||
|
||||
items = [{"source": path, "type": "local_path"} for path in image_paths]
|
||||
|
||||
return await self.start_batch_import(
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
items=items,
|
||||
tags=tags,
|
||||
skip_no_metadata=skip_no_metadata,
|
||||
skip_duplicates=skip_duplicates,
|
||||
)
|
||||
|
||||
async def _discover_images(
|
||||
self,
|
||||
directory: str,
|
||||
recursive: bool = True,
|
||||
) -> List[str]:
|
||||
if not os.path.isdir(directory):
|
||||
raise RecipeValidationError(f"Directory not found: {directory}")
|
||||
|
||||
image_paths: List[str] = []
|
||||
|
||||
if recursive:
|
||||
for root, _, files in os.walk(directory):
|
||||
for filename in files:
|
||||
if self._is_supported_image(filename):
|
||||
image_paths.append(os.path.join(root, filename))
|
||||
else:
|
||||
for filename in os.listdir(directory):
|
||||
filepath = os.path.join(directory, filename)
|
||||
if os.path.isfile(filepath) and self._is_supported_image(filename):
|
||||
image_paths.append(filepath)
|
||||
|
||||
return sorted(image_paths)
|
||||
|
||||
def _is_supported_image(self, filename: str) -> bool:
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
return ext in self.SUPPORTED_EXTENSIONS
|
||||
|
||||
async def _run_batch_import(
|
||||
self,
|
||||
*,
|
||||
operation_id: str,
|
||||
recipe_scanner_getter: Callable[[], Any],
|
||||
civitai_client_getter: Callable[[], Any],
|
||||
) -> None:
|
||||
progress = self._active_operations.get(operation_id)
|
||||
if not progress:
|
||||
return
|
||||
|
||||
progress.status = "running"
|
||||
await self._broadcast_progress(progress)
|
||||
|
||||
self._concurrency_controller = AdaptiveConcurrencyController()
|
||||
|
||||
async def process_item(item: BatchImportItem) -> None:
|
||||
if self._cancellation_flags.get(operation_id, False):
|
||||
return
|
||||
|
||||
progress.current_item = (
|
||||
os.path.basename(item.source)
|
||||
if item.item_type == ImportItemType.LOCAL_PATH
|
||||
else item.source[:50]
|
||||
)
|
||||
item.status = ImportStatus.PROCESSING
|
||||
await self._broadcast_progress(progress)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = await self._import_single_item(
|
||||
item=item,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
tags=progress.tags,
|
||||
skip_no_metadata=progress.skip_no_metadata,
|
||||
skip_duplicates=progress.skip_duplicates,
|
||||
semaphore=self._concurrency_controller.get_semaphore(),
|
||||
)
|
||||
|
||||
duration = time.time() - start_time
|
||||
item.duration = duration
|
||||
self._concurrency_controller.record_result(
|
||||
duration, result.get("success", False)
|
||||
)
|
||||
|
||||
if result.get("success"):
|
||||
item.status = ImportStatus.SUCCESS
|
||||
item.recipe_name = result.get("recipe_name")
|
||||
item.recipe_id = result.get("recipe_id")
|
||||
progress.success += 1
|
||||
elif result.get("skipped"):
|
||||
item.status = ImportStatus.SKIPPED
|
||||
item.error_message = result.get("error")
|
||||
progress.skipped += 1
|
||||
else:
|
||||
item.status = ImportStatus.FAILED
|
||||
item.error_message = result.get("error")
|
||||
progress.failed += 1
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error importing {item.source}: {e}")
|
||||
item.status = ImportStatus.FAILED
|
||||
item.error_message = str(e)
|
||||
item.duration = time.time() - start_time
|
||||
progress.failed += 1
|
||||
self._concurrency_controller.record_result(item.duration, False)
|
||||
|
||||
progress.completed += 1
|
||||
await self._broadcast_progress(progress)
|
||||
|
||||
tasks = [process_item(item) for item in progress.items]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
if self._cancellation_flags.get(operation_id, False):
|
||||
progress.status = "cancelled"
|
||||
else:
|
||||
progress.status = "completed"
|
||||
|
||||
progress.finished_at = time.time()
|
||||
progress.current_item = ""
|
||||
await self._broadcast_progress(progress)
|
||||
|
||||
await asyncio.sleep(5)
|
||||
self._cleanup_operation(operation_id)
|
||||
|
||||
async def _import_single_item(
|
||||
self,
|
||||
*,
|
||||
item: BatchImportItem,
|
||||
recipe_scanner_getter: Callable[[], Any],
|
||||
civitai_client_getter: Callable[[], Any],
|
||||
tags: List[str],
|
||||
skip_no_metadata: bool,
|
||||
skip_duplicates: bool,
|
||||
semaphore: asyncio.Semaphore,
|
||||
) -> Dict[str, Any]:
|
||||
async with semaphore:
|
||||
recipe_scanner = recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
return {"success": False, "error": "Recipe scanner unavailable"}
|
||||
|
||||
try:
|
||||
if item.item_type == ImportItemType.URL:
|
||||
if not self._validate_url(item.source):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid URL format: {item.source}",
|
||||
}
|
||||
|
||||
if skip_duplicates:
|
||||
if self._is_duplicate_source(
|
||||
item.source, item.item_type, recipe_scanner
|
||||
):
|
||||
return {
|
||||
"success": False,
|
||||
"skipped": True,
|
||||
"error": "Duplicate source URL",
|
||||
}
|
||||
|
||||
civitai_client = civitai_client_getter()
|
||||
analysis_result = await self._analysis_service.analyze_remote_image(
|
||||
url=item.source,
|
||||
recipe_scanner=recipe_scanner,
|
||||
civitai_client=civitai_client,
|
||||
)
|
||||
else:
|
||||
if not self._validate_local_path(item.source):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid or unsafe path: {item.source}",
|
||||
}
|
||||
|
||||
if not os.path.exists(item.source):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"File not found: {item.source}",
|
||||
}
|
||||
|
||||
if skip_duplicates:
|
||||
if self._is_duplicate_source(
|
||||
item.source, item.item_type, recipe_scanner
|
||||
):
|
||||
return {
|
||||
"success": False,
|
||||
"skipped": True,
|
||||
"error": "Duplicate source path",
|
||||
}
|
||||
|
||||
analysis_result = await self._analysis_service.analyze_local_image(
|
||||
file_path=item.source,
|
||||
recipe_scanner=recipe_scanner,
|
||||
)
|
||||
|
||||
payload = analysis_result.payload
|
||||
|
||||
if payload.get("error"):
|
||||
if skip_no_metadata and "No metadata" in payload.get("error", ""):
|
||||
return {
|
||||
"success": False,
|
||||
"skipped": True,
|
||||
"error": payload["error"],
|
||||
}
|
||||
return {"success": False, "error": payload["error"]}
|
||||
|
||||
loras = payload.get("loras", [])
|
||||
if not loras:
|
||||
if skip_no_metadata:
|
||||
return {
|
||||
"success": False,
|
||||
"skipped": True,
|
||||
"error": "No LoRAs found in image",
|
||||
}
|
||||
# When skip_no_metadata is False, allow importing images without LoRAs
|
||||
# Continue with empty loras list
|
||||
|
||||
recipe_name = self._generate_recipe_name(item, payload)
|
||||
all_tags = list(set(tags + (payload.get("tags", []) or [])))
|
||||
|
||||
metadata = {
|
||||
"base_model": payload.get("base_model", ""),
|
||||
"loras": loras,
|
||||
"gen_params": payload.get("gen_params", {}),
|
||||
"source_path": item.source,
|
||||
}
|
||||
|
||||
if payload.get("checkpoint"):
|
||||
metadata["checkpoint"] = payload["checkpoint"]
|
||||
|
||||
nsfw = payload.get("preview_nsfw_level")
|
||||
if isinstance(nsfw, int) and nsfw > 0:
|
||||
metadata["preview_nsfw_level"] = nsfw
|
||||
|
||||
image_bytes = None
|
||||
image_base64 = payload.get("image_base64")
|
||||
|
||||
if item.item_type == ImportItemType.LOCAL_PATH:
|
||||
with open(item.source, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
image_base64 = None
|
||||
|
||||
save_result = await self._persistence_service.save_recipe(
|
||||
recipe_scanner=recipe_scanner,
|
||||
image_bytes=image_bytes,
|
||||
image_base64=image_base64,
|
||||
name=recipe_name,
|
||||
tags=all_tags,
|
||||
metadata=metadata,
|
||||
extension=payload.get("extension"),
|
||||
)
|
||||
|
||||
if save_result.status == 200:
|
||||
return {
|
||||
"success": True,
|
||||
"recipe_name": recipe_name,
|
||||
"recipe_id": save_result.payload.get("id"),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": save_result.payload.get(
|
||||
"error", "Failed to save recipe"
|
||||
),
|
||||
}
|
||||
|
||||
except RecipeValidationError as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
except RecipeDownloadError as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
except RecipeNotFoundError as e:
|
||||
return {"success": False, "skipped": True, "error": str(e)}
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"Unexpected error importing {item.source}: {e}", exc_info=True
|
||||
)
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
def _generate_recipe_name(
|
||||
self, item: BatchImportItem, payload: Dict[str, Any]
|
||||
) -> str:
|
||||
if item.item_type == ImportItemType.LOCAL_PATH:
|
||||
base_name = os.path.splitext(os.path.basename(item.source))[0]
|
||||
return base_name[:100]
|
||||
else:
|
||||
loras = payload.get("loras", [])
|
||||
if loras:
|
||||
first_lora = loras[0].get("name", "Recipe")
|
||||
return f"Import - {first_lora}"[:100]
|
||||
return f"Imported Recipe {item.id[:8]}"
|
||||
|
||||
async def _broadcast_progress(self, progress: BatchImportProgress) -> None:
|
||||
await self._ws_manager.broadcast(
|
||||
{
|
||||
"type": "batch_import_progress",
|
||||
**progress.to_dict(),
|
||||
}
|
||||
)
|
||||
|
||||
def _cleanup_operation(self, operation_id: str) -> None:
|
||||
if operation_id in self._cancellation_flags:
|
||||
del self._cancellation_flags[operation_id]
|
||||
@@ -58,6 +58,7 @@ class CacheEntryValidator:
|
||||
'preview_nsfw_level': (0, False),
|
||||
'notes': ('', False),
|
||||
'usage_tips': ('', False),
|
||||
'hash_status': ('completed', False),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -90,13 +91,31 @@ class CacheEntryValidator:
|
||||
|
||||
errors: List[str] = []
|
||||
repaired = False
|
||||
|
||||
# If auto_repair is on, we work on a copy. If not, we still need a safe way to check fields.
|
||||
working_entry = dict(entry) if auto_repair else entry
|
||||
|
||||
# Determine effective hash_status for validation logic
|
||||
hash_status = entry.get('hash_status')
|
||||
if hash_status is None:
|
||||
if auto_repair:
|
||||
working_entry['hash_status'] = 'completed'
|
||||
repaired = True
|
||||
hash_status = 'completed'
|
||||
|
||||
for field_name, (default_value, is_required) in cls.CORE_FIELDS.items():
|
||||
value = working_entry.get(field_name)
|
||||
# Get current value from the original entry to avoid side effects during validation
|
||||
value = entry.get(field_name)
|
||||
|
||||
# Check if field is missing or None
|
||||
if value is None:
|
||||
# Special case: sha256 can be None/empty if hash_status is pending
|
||||
if field_name == 'sha256' and hash_status == 'pending':
|
||||
if auto_repair:
|
||||
working_entry[field_name] = ''
|
||||
repaired = True
|
||||
continue
|
||||
|
||||
if is_required:
|
||||
errors.append(f"Required field '{field_name}' is missing or None")
|
||||
if auto_repair:
|
||||
@@ -107,6 +126,10 @@ class CacheEntryValidator:
|
||||
# Validate field type and value
|
||||
field_error = cls._validate_field(field_name, value, default_value)
|
||||
if field_error:
|
||||
# Special case: allow empty string for sha256 if pending
|
||||
if field_name == 'sha256' and hash_status == 'pending' and value == '':
|
||||
continue
|
||||
|
||||
errors.append(field_error)
|
||||
if auto_repair:
|
||||
working_entry[field_name] = cls._get_default_copy(default_value)
|
||||
@@ -127,7 +150,7 @@ class CacheEntryValidator:
|
||||
# Special validation: sha256 must not be empty for required field
|
||||
# BUT allow empty sha256 when hash_status is pending (lazy hash calculation)
|
||||
sha256 = working_entry.get('sha256', '')
|
||||
hash_status = working_entry.get('hash_status', 'completed')
|
||||
# Use the effective hash_status we determined earlier
|
||||
if not sha256 or (isinstance(sha256, str) and not sha256.strip()):
|
||||
# Allow empty sha256 for lazy hash calculation (checkpoints)
|
||||
if hash_status != 'pending':
|
||||
@@ -144,8 +167,13 @@ class CacheEntryValidator:
|
||||
if isinstance(sha256, str):
|
||||
normalized_sha = sha256.lower().strip()
|
||||
if normalized_sha != sha256:
|
||||
working_entry['sha256'] = normalized_sha
|
||||
repaired = True
|
||||
if auto_repair:
|
||||
working_entry['sha256'] = normalized_sha
|
||||
repaired = True
|
||||
else:
|
||||
# If not auto-repairing, we don't consider case difference as a "critical error"
|
||||
# that invalidates the entry, but we also don't mark it repaired.
|
||||
pass
|
||||
|
||||
# Determine if entry is valid
|
||||
# Entry is valid if no critical required field errors remain after repair
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -13,22 +14,38 @@ from .model_hash_index import ModelHashIndex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CheckpointScanner(ModelScanner):
|
||||
"""Service for scanning and managing checkpoint files"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'}
|
||||
file_extensions = {
|
||||
".ckpt",
|
||||
".pt",
|
||||
".pt2",
|
||||
".bin",
|
||||
".pth",
|
||||
".safetensors",
|
||||
".pkl",
|
||||
".sft",
|
||||
".gguf",
|
||||
}
|
||||
super().__init__(
|
||||
model_type="checkpoint",
|
||||
model_class=CheckpointMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
hash_index=ModelHashIndex(),
|
||||
)
|
||||
if not hasattr(self, "_hash_calculation_lock"):
|
||||
self._hash_calculation_lock = asyncio.Lock()
|
||||
self._hash_calculation_tasks: dict[str, asyncio.Task[Optional[str]]] = {}
|
||||
|
||||
async def _create_default_metadata(self, file_path: str) -> Optional[CheckpointMetadata]:
|
||||
async def _create_default_metadata(
|
||||
self, file_path: str
|
||||
) -> Optional[CheckpointMetadata]:
|
||||
"""Create default metadata for checkpoint without calculating hash (lazy hash).
|
||||
|
||||
|
||||
Checkpoints are typically large (10GB+), so we skip hash calculation during initial
|
||||
scanning to improve startup performance. Hash will be calculated on-demand when
|
||||
fetching metadata from Civitai.
|
||||
@@ -38,13 +55,13 @@ class CheckpointScanner(ModelScanner):
|
||||
if not os.path.exists(real_path):
|
||||
logger.error(f"File not found: {file_path}")
|
||||
return None
|
||||
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
dir_path = os.path.dirname(file_path)
|
||||
|
||||
|
||||
# Find preview image
|
||||
preview_url = find_preview_file(base_name, dir_path)
|
||||
|
||||
|
||||
# Create metadata WITHOUT calculating hash
|
||||
metadata = CheckpointMetadata(
|
||||
file_name=base_name,
|
||||
@@ -59,70 +76,133 @@ class CheckpointScanner(ModelScanner):
|
||||
modelDescription="",
|
||||
sub_type="checkpoint",
|
||||
from_civitai=False, # Mark as local model since no hash yet
|
||||
hash_status="pending" # Mark hash as pending
|
||||
hash_status="pending", # Mark hash as pending
|
||||
)
|
||||
|
||||
|
||||
# Save the created metadata
|
||||
logger.info(f"Creating checkpoint metadata (hash pending) for {file_path}")
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating default checkpoint metadata for {file_path}: {e}")
|
||||
logger.error(
|
||||
f"Error creating default checkpoint metadata for {file_path}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
|
||||
"""Calculate hash for a checkpoint on-demand.
|
||||
|
||||
"""Calculate hash for a checkpoint on-demand with per-file singleflight.
|
||||
|
||||
Args:
|
||||
file_path: Path to the model file
|
||||
|
||||
|
||||
Returns:
|
||||
SHA256 hash string, or None if calculation failed
|
||||
"""
|
||||
from ..utils.file_utils import calculate_sha256
|
||||
|
||||
try:
|
||||
real_path = os.path.realpath(file_path)
|
||||
if not os.path.exists(real_path):
|
||||
logger.error(f"File not found for hash calculation: {file_path}")
|
||||
return None
|
||||
|
||||
|
||||
metadata, _ = await MetadataManager.load_metadata(
|
||||
file_path, self.model_class
|
||||
)
|
||||
if (
|
||||
metadata is not None
|
||||
and metadata.hash_status == "completed"
|
||||
and metadata.sha256
|
||||
):
|
||||
return metadata.sha256
|
||||
|
||||
async with self._hash_calculation_lock:
|
||||
metadata, _ = await MetadataManager.load_metadata(
|
||||
file_path, self.model_class
|
||||
)
|
||||
if (
|
||||
metadata is not None
|
||||
and metadata.hash_status == "completed"
|
||||
and metadata.sha256
|
||||
):
|
||||
return metadata.sha256
|
||||
|
||||
task = self._hash_calculation_tasks.get(real_path)
|
||||
if task is None:
|
||||
task = asyncio.create_task(
|
||||
self._run_hash_calculation_task(file_path, real_path)
|
||||
)
|
||||
self._hash_calculation_tasks[real_path] = task
|
||||
|
||||
return await asyncio.shield(task)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating hash for {file_path}: {e}")
|
||||
return None
|
||||
|
||||
async def _run_hash_calculation_task(
|
||||
self, file_path: str, real_path: str
|
||||
) -> Optional[str]:
|
||||
"""Run a hash calculation task and remove it from the in-flight map."""
|
||||
try:
|
||||
return await self._calculate_hash_for_model_uncached(file_path, real_path)
|
||||
finally:
|
||||
task = asyncio.current_task()
|
||||
async with self._hash_calculation_lock:
|
||||
if self._hash_calculation_tasks.get(real_path) is task:
|
||||
del self._hash_calculation_tasks[real_path]
|
||||
|
||||
async def _calculate_hash_for_model_uncached(
|
||||
self, file_path: str, real_path: str
|
||||
) -> Optional[str]:
|
||||
"""Calculate hash for a checkpoint without checking in-flight tasks."""
|
||||
from ..utils.file_utils import calculate_sha256
|
||||
|
||||
try:
|
||||
# Load current metadata
|
||||
metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class)
|
||||
metadata, should_skip = await MetadataManager.load_metadata(
|
||||
file_path, self.model_class
|
||||
)
|
||||
if metadata is None:
|
||||
logger.error(f"No metadata found for {file_path}")
|
||||
return None
|
||||
|
||||
if should_skip:
|
||||
logger.error(f"Invalid metadata found for {file_path}")
|
||||
return None
|
||||
created_metadata = await self._create_default_metadata(file_path)
|
||||
if created_metadata is None:
|
||||
logger.error(f"No metadata found for {file_path}")
|
||||
return None
|
||||
metadata = created_metadata
|
||||
|
||||
# Check if hash is already calculated
|
||||
if metadata.hash_status == "completed" and metadata.sha256:
|
||||
return metadata.sha256
|
||||
|
||||
|
||||
# Update status to calculating
|
||||
metadata.hash_status = "calculating"
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
|
||||
# Calculate hash
|
||||
logger.info(f"Calculating hash for checkpoint: {file_path}")
|
||||
sha256 = await calculate_sha256(real_path)
|
||||
|
||||
|
||||
# Update metadata with hash
|
||||
metadata.sha256 = sha256
|
||||
metadata.hash_status = "completed"
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
|
||||
# Update hash index
|
||||
self._hash_index.add_entry(sha256.lower(), file_path)
|
||||
|
||||
|
||||
logger.info(f"Hash calculated for checkpoint: {file_path}")
|
||||
return sha256
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating hash for {file_path}: {e}")
|
||||
# Update status to failed
|
||||
try:
|
||||
metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class)
|
||||
metadata, _ = await MetadataManager.load_metadata(
|
||||
file_path, self.model_class
|
||||
)
|
||||
if metadata:
|
||||
metadata.hash_status = "failed"
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
@@ -130,43 +210,46 @@ class CheckpointScanner(ModelScanner):
|
||||
pass
|
||||
return None
|
||||
|
||||
async def calculate_all_pending_hashes(self, progress_callback=None) -> Dict[str, int]:
|
||||
async def calculate_all_pending_hashes(
|
||||
self, progress_callback=None
|
||||
) -> Dict[str, int]:
|
||||
"""Calculate hashes for all checkpoints with pending hash status.
|
||||
|
||||
|
||||
If cache is not initialized, scans filesystem directly for metadata files
|
||||
with hash_status != 'completed'.
|
||||
|
||||
|
||||
Args:
|
||||
progress_callback: Optional callback(progress, total, current_file)
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with 'completed', 'failed', 'total' counts
|
||||
"""
|
||||
# Try to get from cache first
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
|
||||
if cache and cache.raw_data:
|
||||
# Use cache if available
|
||||
pending_models = [
|
||||
item for item in cache.raw_data
|
||||
if item.get('hash_status') != 'completed' or not item.get('sha256')
|
||||
item
|
||||
for item in cache.raw_data
|
||||
if item.get("hash_status") != "completed" or not item.get("sha256")
|
||||
]
|
||||
else:
|
||||
# Cache not initialized, scan filesystem directly
|
||||
pending_models = await self._find_pending_models_from_filesystem()
|
||||
|
||||
|
||||
if not pending_models:
|
||||
return {'completed': 0, 'failed': 0, 'total': 0}
|
||||
|
||||
return {"completed": 0, "failed": 0, "total": 0}
|
||||
|
||||
total = len(pending_models)
|
||||
completed = 0
|
||||
failed = 0
|
||||
|
||||
|
||||
for i, model_data in enumerate(pending_models):
|
||||
file_path = model_data.get('file_path')
|
||||
file_path = model_data.get("file_path")
|
||||
if not file_path:
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
sha256 = await self.calculate_hash_for_model(file_path)
|
||||
if sha256:
|
||||
@@ -176,77 +259,102 @@ class CheckpointScanner(ModelScanner):
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating hash for {file_path}: {e}")
|
||||
failed += 1
|
||||
|
||||
|
||||
if progress_callback:
|
||||
try:
|
||||
await progress_callback(i + 1, total, file_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
'completed': completed,
|
||||
'failed': failed,
|
||||
'total': total
|
||||
}
|
||||
|
||||
|
||||
return {"completed": completed, "failed": failed, "total": total}
|
||||
|
||||
async def _find_pending_models_from_filesystem(self) -> List[Dict[str, Any]]:
|
||||
"""Scan filesystem for checkpoint metadata files with pending hash status."""
|
||||
pending_models = []
|
||||
|
||||
|
||||
for root_path in self.get_model_roots():
|
||||
if not os.path.exists(root_path):
|
||||
continue
|
||||
|
||||
|
||||
for dirpath, _dirnames, filenames in os.walk(root_path):
|
||||
for filename in filenames:
|
||||
if not filename.endswith('.metadata.json'):
|
||||
if not filename.endswith(".metadata.json"):
|
||||
continue
|
||||
|
||||
|
||||
metadata_path = os.path.join(dirpath, filename)
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
with open(metadata_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
||||
# Check if hash is pending
|
||||
hash_status = data.get('hash_status', 'completed')
|
||||
sha256 = data.get('sha256', '')
|
||||
|
||||
if hash_status != 'completed' or not sha256:
|
||||
hash_status = data.get("hash_status", "completed")
|
||||
sha256 = data.get("sha256", "")
|
||||
|
||||
if hash_status != "completed" or not sha256:
|
||||
# Find corresponding model file
|
||||
model_name = filename.replace('.metadata.json', '')
|
||||
model_name = filename.replace(".metadata.json", "")
|
||||
model_path = None
|
||||
|
||||
|
||||
# Look for model file with matching name
|
||||
for ext in self.file_extensions:
|
||||
potential_path = os.path.join(dirpath, model_name + ext)
|
||||
if os.path.exists(potential_path):
|
||||
model_path = potential_path
|
||||
break
|
||||
|
||||
|
||||
if model_path:
|
||||
pending_models.append({
|
||||
'file_path': model_path.replace(os.sep, '/'),
|
||||
'hash_status': hash_status,
|
||||
'sha256': sha256,
|
||||
**{k: v for k, v in data.items() if k not in ['file_path', 'hash_status', 'sha256']}
|
||||
})
|
||||
pending_models.append(
|
||||
{
|
||||
"file_path": model_path.replace(os.sep, "/"),
|
||||
"hash_status": hash_status,
|
||||
"sha256": sha256,
|
||||
**{
|
||||
k: v
|
||||
for k, v in data.items()
|
||||
if k
|
||||
not in [
|
||||
"file_path",
|
||||
"hash_status",
|
||||
"sha256",
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
logger.debug(f"Error reading metadata file {metadata_path}: {e}")
|
||||
logger.debug(
|
||||
f"Error reading metadata file {metadata_path}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
return pending_models
|
||||
|
||||
def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]:
|
||||
"""Resolve the sub-type based on the root path."""
|
||||
"""Resolve the sub-type based on the root path.
|
||||
|
||||
Checks both standard ComfyUI paths and LoRA Manager's extra folder paths.
|
||||
"""
|
||||
if not root_path:
|
||||
return None
|
||||
|
||||
# Check standard ComfyUI checkpoint paths
|
||||
if config.checkpoints_roots and root_path in config.checkpoints_roots:
|
||||
return "checkpoint"
|
||||
|
||||
# Check extra checkpoint paths
|
||||
if (
|
||||
config.extra_checkpoints_roots
|
||||
and root_path in config.extra_checkpoints_roots
|
||||
):
|
||||
return "checkpoint"
|
||||
|
||||
# Check standard ComfyUI unet paths
|
||||
if config.unet_roots and root_path in config.unet_roots:
|
||||
return "diffusion_model"
|
||||
|
||||
# Check extra unet paths
|
||||
if config.extra_unet_roots and root_path in config.extra_unet_roots:
|
||||
return "diffusion_model"
|
||||
|
||||
return None
|
||||
|
||||
def adjust_metadata(self, metadata, file_path, root_path):
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from .auto_tag_service import extract_auto_tags
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
|
||||
@@ -20,20 +21,37 @@ class CheckpointService(BaseModelService):
|
||||
"""
|
||||
super().__init__("checkpoint", scanner, CheckpointMetadata, update_service=update_service)
|
||||
|
||||
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
||||
"""Format Checkpoint data for API response"""
|
||||
async def format_response(self, checkpoint_data: Dict) -> Optional[Dict]:
|
||||
"""Format Checkpoint data for API response.
|
||||
|
||||
Returns None when the entry is missing critical fields (corrupted cache
|
||||
row), so the handler layer can filter it out. See issue #730.
|
||||
"""
|
||||
# Guard against corrupted cache entries missing critical fields
|
||||
file_path = checkpoint_data.get("file_path")
|
||||
if not file_path or not isinstance(file_path, str):
|
||||
logger.warning(
|
||||
"Skipping corrupted checkpoint entry (missing file_path): %s",
|
||||
checkpoint_data.get("file_name", "<unknown>"),
|
||||
)
|
||||
return None
|
||||
|
||||
# Get sub_type from cache entry (new canonical field)
|
||||
sub_type = checkpoint_data.get("sub_type", "checkpoint")
|
||||
|
||||
|
||||
file_name = checkpoint_data.get("file_name") or ""
|
||||
model_name = checkpoint_data.get("model_name") or file_name
|
||||
folder = checkpoint_data.get("folder") or ""
|
||||
|
||||
return {
|
||||
"model_name": checkpoint_data["model_name"],
|
||||
"file_name": checkpoint_data["file_name"],
|
||||
"model_name": model_name,
|
||||
"file_name": file_name,
|
||||
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
|
||||
"base_model": checkpoint_data.get("base_model", ""),
|
||||
"folder": checkpoint_data["folder"],
|
||||
"folder": folder,
|
||||
"sha256": checkpoint_data.get("sha256", ""),
|
||||
"file_path": checkpoint_data["file_path"].replace(os.sep, "/"),
|
||||
"file_path": file_path.replace(os.sep, "/"),
|
||||
"file_size": checkpoint_data.get("size", 0),
|
||||
"modified": checkpoint_data.get("modified", ""),
|
||||
"tags": checkpoint_data.get("tags", []),
|
||||
@@ -42,9 +60,13 @@ class CheckpointService(BaseModelService):
|
||||
"notes": checkpoint_data.get("notes", ""),
|
||||
"sub_type": sub_type,
|
||||
"favorite": checkpoint_data.get("favorite", False),
|
||||
"exclude": bool(checkpoint_data.get("exclude", False)),
|
||||
"update_available": bool(checkpoint_data.get("update_available", False)),
|
||||
"skip_metadata_refresh": bool(checkpoint_data.get("skip_metadata_refresh", False)),
|
||||
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
||||
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True),
|
||||
"auto_tags": checkpoint_data.get("auto_tags") or extract_auto_tags(checkpoint_data),
|
||||
"version_count": checkpoint_data.get("version_count"),
|
||||
"hf_url": checkpoint_data.get("hf_url", ""),
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
|
||||
@@ -186,6 +186,22 @@ class CivArchiveClient:
|
||||
if "metadata" in file_data:
|
||||
transformed["metadata"] = file_data["metadata"]
|
||||
|
||||
# Infer metadata.format from filename extension
|
||||
name = transformed.get("name")
|
||||
if name and isinstance(name, str):
|
||||
lower_name = name.lower()
|
||||
if lower_name.endswith(".safetensors"):
|
||||
inferred_format = "SafeTensor"
|
||||
elif lower_name.endswith(".ckpt"):
|
||||
inferred_format = "PickleTensor"
|
||||
else:
|
||||
inferred_format = None
|
||||
if inferred_format:
|
||||
if "metadata" not in transformed:
|
||||
transformed["metadata"] = {}
|
||||
if isinstance(transformed["metadata"], dict):
|
||||
transformed["metadata"].setdefault("format", inferred_format)
|
||||
|
||||
if file_data.get("modelVersionId") is not None:
|
||||
transformed["modelVersionId"] = file_data.get("modelVersionId")
|
||||
elif file_data.get("model_version_id") is not None:
|
||||
@@ -213,6 +229,20 @@ class CivArchiveClient:
|
||||
for file_data in candidates:
|
||||
if isinstance(file_data, dict):
|
||||
transformed_files.append(self._transform_file_entry(file_data))
|
||||
|
||||
# Sort: .safetensors first, .ckpt second, others last
|
||||
# so the backend fallback (no file_params) prefers safetensors
|
||||
def _sort_key(f: Dict) -> int:
|
||||
fname = f.get("name") or ""
|
||||
if isinstance(fname, str):
|
||||
lower = fname.lower()
|
||||
if lower.endswith(".safetensors"):
|
||||
return 0
|
||||
elif lower.endswith(".ckpt"):
|
||||
return 1
|
||||
return 2
|
||||
|
||||
transformed_files.sort(key=_sort_key)
|
||||
return transformed_files
|
||||
|
||||
def _transform_version(
|
||||
@@ -297,7 +327,7 @@ class CivArchiveClient:
|
||||
if resolved:
|
||||
return resolved, None
|
||||
|
||||
logger.error("Error fetching version of CivArchive model by hash %s", model_hash[:10])
|
||||
logger.debug("Error fetching version of CivArchive model by hash %s", model_hash[:10])
|
||||
return None, "No version data found"
|
||||
|
||||
except RateLimitError:
|
||||
@@ -387,7 +417,7 @@ class CivArchiveClient:
|
||||
|
||||
if version_id is not None:
|
||||
raw_id = version_data.get("id")
|
||||
if raw_id != version_id:
|
||||
if raw_id is not None and str(raw_id) != str(version_id):
|
||||
logger.warning(
|
||||
"Requested version %s doesn't match default version %s for model %s",
|
||||
version_id,
|
||||
|
||||
438
py/services/civitai_base_model_service.py
Normal file
438
py/services/civitai_base_model_service.py
Normal file
@@ -0,0 +1,438 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from ..utils.constants import SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS
|
||||
from .downloader import get_downloader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CivitaiBaseModelService:
|
||||
"""Service for fetching and managing Civitai base models.
|
||||
|
||||
This service provides:
|
||||
- Fetching base models from Civitai API
|
||||
- Caching with TTL (7 days default)
|
||||
- Merging hardcoded and remote base models
|
||||
- Generating abbreviations for new/unknown models
|
||||
"""
|
||||
|
||||
_instance: Optional[CivitaiBaseModelService] = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
# Default TTL for cache in seconds (7 days)
|
||||
DEFAULT_CACHE_TTL = 7 * 24 * 60 * 60
|
||||
|
||||
# Civitai API endpoint for enums
|
||||
CIVITAI_ENUMS_URL = "https://civitai.red/api/v1/enums"
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> CivitaiBaseModelService:
|
||||
"""Get singleton instance of the service."""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the service."""
|
||||
if hasattr(self, "_initialized"):
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
# Cache storage
|
||||
self._cache: Optional[Dict[str, Any]] = None
|
||||
self._cache_timestamp: Optional[datetime] = None
|
||||
self._cache_ttl = self.DEFAULT_CACHE_TTL
|
||||
|
||||
# Hardcoded models for fallback
|
||||
self._hardcoded_models = set(SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS)
|
||||
|
||||
logger.info("CivitaiBaseModelService initialized")
|
||||
|
||||
async def get_base_models(self, force_refresh: bool = False) -> Dict[str, Any]:
|
||||
"""Get merged base models (hardcoded + remote).
|
||||
|
||||
Args:
|
||||
force_refresh: If True, fetch from API regardless of cache state.
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- models: List of merged base model names
|
||||
- source: 'cache', 'api', or 'fallback'
|
||||
- last_updated: ISO timestamp of last successful API fetch
|
||||
- hardcoded_count: Number of hardcoded models
|
||||
- remote_count: Number of remote models
|
||||
- merged_count: Total unique models
|
||||
"""
|
||||
# Check if cache is valid
|
||||
if not force_refresh and self._is_cache_valid():
|
||||
logger.debug("Returning cached base models")
|
||||
return self._build_response("cache")
|
||||
|
||||
# Try to fetch from API
|
||||
try:
|
||||
remote_models = await self._fetch_from_civitai()
|
||||
if remote_models:
|
||||
self._update_cache(remote_models)
|
||||
return self._build_response("api")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch base models from Civitai: {e}")
|
||||
|
||||
# Fallback to hardcoded models
|
||||
return self._build_response("fallback")
|
||||
|
||||
async def refresh_cache(self) -> Dict[str, Any]:
|
||||
"""Force refresh the cache from Civitai API.
|
||||
|
||||
Returns:
|
||||
Response dict same as get_base_models()
|
||||
"""
|
||||
return await self.get_base_models(force_refresh=True)
|
||||
|
||||
def get_cache_status(self) -> Dict[str, Any]:
|
||||
"""Get current cache status.
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- has_cache: Whether cache exists
|
||||
- last_updated: ISO timestamp or None
|
||||
- is_expired: Whether cache is expired
|
||||
- ttl_seconds: TTL in seconds
|
||||
- age_seconds: Age of cache in seconds (if exists)
|
||||
"""
|
||||
if self._cache is None or self._cache_timestamp is None:
|
||||
return {
|
||||
"has_cache": False,
|
||||
"last_updated": None,
|
||||
"is_expired": True,
|
||||
"ttl_seconds": self._cache_ttl,
|
||||
"age_seconds": None,
|
||||
}
|
||||
|
||||
age = (datetime.now(timezone.utc) - self._cache_timestamp).total_seconds()
|
||||
return {
|
||||
"has_cache": True,
|
||||
"last_updated": self._cache_timestamp.isoformat(),
|
||||
"is_expired": age > self._cache_ttl,
|
||||
"ttl_seconds": self._cache_ttl,
|
||||
"age_seconds": int(age),
|
||||
}
|
||||
|
||||
def generate_abbreviation(self, model_name: str) -> str:
|
||||
"""Generate abbreviation for a base model name.
|
||||
|
||||
Algorithm:
|
||||
1. Extract version patterns (e.g., "2.5" from "Wan Video 2.5")
|
||||
2. Extract main acronym (e.g., "SD" from "SD 1.5")
|
||||
3. Handle special cases (Flux, Wan, etc.)
|
||||
4. Fallback to first letters of words (max 4 chars)
|
||||
|
||||
Args:
|
||||
model_name: Full base model name
|
||||
|
||||
Returns:
|
||||
Generated abbreviation (max 4 characters)
|
||||
"""
|
||||
if not model_name or not isinstance(model_name, str):
|
||||
return "OTH"
|
||||
|
||||
name = model_name.strip()
|
||||
if not name:
|
||||
return "OTH"
|
||||
|
||||
# Check if it's already in hardcoded abbreviations
|
||||
# This is a simplified check - in practice you'd have a mapping
|
||||
lower_name = name.lower()
|
||||
|
||||
# Special cases
|
||||
special_cases = {
|
||||
"sd 1.4": "SD1",
|
||||
"sd 1.5": "SD1",
|
||||
"sd 1.5 lcm": "SD1",
|
||||
"sd 1.5 hyper": "SD1",
|
||||
"sd 2.0": "SD2",
|
||||
"sd 2.1": "SD2",
|
||||
"sd 3": "SD3",
|
||||
"sd 3.5": "SD3",
|
||||
"sd 3.5 medium": "SD3",
|
||||
"sd 3.5 large": "SD3",
|
||||
"sd 3.5 large turbo": "SD3",
|
||||
"sdxl 1.0": "XL",
|
||||
"sdxl lightning": "XL",
|
||||
"sdxl hyper": "XL",
|
||||
"flux.1 d": "F1D",
|
||||
"flux.1 s": "F1S",
|
||||
"flux.1 krea": "F1KR",
|
||||
"flux.1 kontext": "F1KX",
|
||||
"flux.2 d": "F2D",
|
||||
"flux.2 klein 9b": "FK9",
|
||||
"flux.2 klein 9b-base": "FK9B",
|
||||
"flux.2 klein 4b": "FK4",
|
||||
"flux.2 klein 4b-base": "FK4B",
|
||||
"auraflow": "AF",
|
||||
"chroma": "CHR",
|
||||
"pixart a": "PXA",
|
||||
"pixart e": "PXE",
|
||||
"hunyuan 1": "HY",
|
||||
"hunyuan video": "HYV",
|
||||
"lumina": "L",
|
||||
"kolors": "KLR",
|
||||
"noobai": "NAI",
|
||||
"illustrious": "IL",
|
||||
"pony": "PONY",
|
||||
"pony v7": "PNY7",
|
||||
"hidream": "HID",
|
||||
"qwen": "QWEN",
|
||||
"zimageturbo": "ZIT",
|
||||
"zimagebase": "ZIB",
|
||||
"anima": "ANI",
|
||||
"ernie": "ERNI",
|
||||
"ernie turbo": "ETRB",
|
||||
"nucleus": "NUCL",
|
||||
"krea 2": "KR2",
|
||||
"svd": "SVD",
|
||||
"ltxv": "LTXV",
|
||||
"ltxv2": "LTV2",
|
||||
"ltxv 2.3": "LTX",
|
||||
"cogvideox": "CVX",
|
||||
"mochi": "MCHI",
|
||||
"wan video": "WAN",
|
||||
"wan video 1.3b t2v": "WAN",
|
||||
"wan video 14b t2v": "WAN",
|
||||
"wan video 14b i2v 480p": "WAN",
|
||||
"wan video 14b i2v 720p": "WAN",
|
||||
"wan video 2.2 ti2v-5b": "WAN",
|
||||
"wan video 2.2 t2v-a14b": "WAN",
|
||||
"wan video 2.2 i2v-a14b": "WAN",
|
||||
"wan video 2.5 t2v": "WAN",
|
||||
"wan video 2.5 i2v": "WAN",
|
||||
}
|
||||
|
||||
if lower_name in special_cases:
|
||||
return special_cases[lower_name]
|
||||
|
||||
# Try to extract acronym from version pattern
|
||||
# e.g., "Model Name 2.5" -> "MN25"
|
||||
version_match = re.search(r"(\d+(?:\.\d+)?)", name)
|
||||
version = version_match.group(1) if version_match else ""
|
||||
|
||||
# Remove version and common words
|
||||
words = re.sub(r"\d+(?:\.\d+)?", "", name)
|
||||
words = re.sub(
|
||||
r"\b(model|video|diffusion|checkpoint|textualinversion)\b",
|
||||
"",
|
||||
words,
|
||||
flags=re.I,
|
||||
)
|
||||
words = words.strip()
|
||||
|
||||
# Get first letters of remaining words
|
||||
tokens = re.findall(r"[A-Za-z]+", words)
|
||||
if tokens:
|
||||
# Build abbreviation from first letters
|
||||
abbrev = "".join(token[0].upper() for token in tokens)
|
||||
# Add version if present
|
||||
if version:
|
||||
# Clean version (remove dots for abbreviation)
|
||||
version_clean = version.replace(".", "")
|
||||
abbrev = abbrev[: 4 - len(version_clean)] + version_clean
|
||||
return abbrev[:4]
|
||||
|
||||
# Final fallback: just take first 4 alphanumeric chars
|
||||
alphanumeric = re.sub(r"[^A-Za-z0-9]", "", name)
|
||||
if alphanumeric:
|
||||
return alphanumeric[:4].upper()
|
||||
|
||||
return "OTH"
|
||||
|
||||
async def _fetch_from_civitai(self) -> Optional[Set[str]]:
|
||||
"""Fetch base models from Civitai API.
|
||||
|
||||
Returns:
|
||||
Set of base model names, or None if failed
|
||||
"""
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
success, result = await downloader.make_request(
|
||||
"GET",
|
||||
self.CIVITAI_ENUMS_URL,
|
||||
use_auth=False, # enums endpoint doesn't require auth
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.warning(f"Failed to fetch enums from Civitai: {result}")
|
||||
return None
|
||||
|
||||
if isinstance(result, str):
|
||||
data = json.loads(result)
|
||||
else:
|
||||
data = result
|
||||
|
||||
# Extract base models from response
|
||||
base_models = set()
|
||||
|
||||
# Use ActiveBaseModel if available (recommended active models)
|
||||
if "ActiveBaseModel" in data:
|
||||
base_models.update(data["ActiveBaseModel"])
|
||||
logger.info(f"Fetched {len(base_models)} models from ActiveBaseModel")
|
||||
# Fallback to full BaseModel list
|
||||
elif "BaseModel" in data:
|
||||
base_models.update(data["BaseModel"])
|
||||
logger.info(f"Fetched {len(base_models)} models from BaseModel")
|
||||
else:
|
||||
logger.warning("No base model data found in Civitai response")
|
||||
return None
|
||||
|
||||
return base_models
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching from Civitai: {e}")
|
||||
return None
|
||||
|
||||
def _update_cache(self, remote_models: Set[str]) -> None:
|
||||
"""Update internal cache with fetched models.
|
||||
|
||||
Args:
|
||||
remote_models: Set of base model names from API
|
||||
"""
|
||||
self._cache = {
|
||||
"remote_models": sorted(remote_models),
|
||||
"hardcoded_models": sorted(self._hardcoded_models),
|
||||
}
|
||||
self._cache_timestamp = datetime.now(timezone.utc)
|
||||
logger.info(f"Cache updated with {len(remote_models)} remote models")
|
||||
|
||||
def _is_cache_valid(self) -> bool:
|
||||
"""Check if current cache is valid (not expired).
|
||||
|
||||
Returns:
|
||||
True if cache exists and is not expired
|
||||
"""
|
||||
if self._cache is None or self._cache_timestamp is None:
|
||||
return False
|
||||
|
||||
age = (datetime.now(timezone.utc) - self._cache_timestamp).total_seconds()
|
||||
return age <= self._cache_ttl
|
||||
|
||||
def _build_response(self, source: str) -> Dict[str, Any]:
|
||||
"""Build response dictionary.
|
||||
|
||||
Args:
|
||||
source: 'cache', 'api', or 'fallback'
|
||||
|
||||
Returns:
|
||||
Response dictionary
|
||||
"""
|
||||
if source == "fallback" or self._cache is None:
|
||||
# Use only hardcoded models
|
||||
merged = sorted(self._hardcoded_models)
|
||||
return {
|
||||
"models": merged,
|
||||
"source": source,
|
||||
"last_updated": None,
|
||||
"hardcoded_count": len(self._hardcoded_models),
|
||||
"remote_count": 0,
|
||||
"merged_count": len(merged),
|
||||
}
|
||||
|
||||
# Merge hardcoded and remote models
|
||||
remote_set = set(self._cache.get("remote_models", []))
|
||||
merged = sorted(self._hardcoded_models | remote_set)
|
||||
|
||||
return {
|
||||
"models": merged,
|
||||
"source": source,
|
||||
"last_updated": self._cache_timestamp.isoformat()
|
||||
if self._cache_timestamp
|
||||
else None,
|
||||
"hardcoded_count": len(self._hardcoded_models),
|
||||
"remote_count": len(remote_set),
|
||||
"merged_count": len(merged),
|
||||
}
|
||||
|
||||
def get_model_categories(self) -> Dict[str, List[str]]:
|
||||
"""Get categorized base models.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping category names to lists of model names
|
||||
"""
|
||||
# Define category patterns
|
||||
categories = {
|
||||
"Stable Diffusion 1.x": ["SD 1.4", "SD 1.5", "SD 1.5 LCM", "SD 1.5 Hyper"],
|
||||
"Stable Diffusion 2.x": ["SD 2.0", "SD 2.1"],
|
||||
"Stable Diffusion 3.x": [
|
||||
"SD 3",
|
||||
"SD 3.5",
|
||||
"SD 3.5 Medium",
|
||||
"SD 3.5 Large",
|
||||
"SD 3.5 Large Turbo",
|
||||
],
|
||||
"SDXL": ["SDXL 1.0", "SDXL Lightning", "SDXL Hyper"],
|
||||
"Flux Models": [
|
||||
"Flux.1 D",
|
||||
"Flux.1 S",
|
||||
"Flux.1 Krea",
|
||||
"Flux.1 Kontext",
|
||||
"Flux.2 D",
|
||||
"Flux.2 Klein 9B",
|
||||
"Flux.2 Klein 9B-base",
|
||||
"Flux.2 Klein 4B",
|
||||
"Flux.2 Klein 4B-base",
|
||||
],
|
||||
"Video Models": [
|
||||
"SVD",
|
||||
"LTXV",
|
||||
"LTXV2",
|
||||
"LTXV 2.3",
|
||||
"CogVideoX",
|
||||
"Mochi",
|
||||
"Hunyuan Video",
|
||||
"Wan Video",
|
||||
"Wan Video 1.3B t2v",
|
||||
"Wan Video 14B t2v",
|
||||
"Wan Video 14B i2v 480p",
|
||||
"Wan Video 14B i2v 720p",
|
||||
"Wan Video 2.2 TI2V-5B",
|
||||
"Wan Video 2.2 T2V-A14B",
|
||||
"Wan Video 2.2 I2V-A14B",
|
||||
"Wan Video 2.5 T2V",
|
||||
"Wan Video 2.5 I2V",
|
||||
],
|
||||
"Other Models": [
|
||||
"Illustrious",
|
||||
"Pony",
|
||||
"Pony V7",
|
||||
"HiDream",
|
||||
"Qwen",
|
||||
"AuraFlow",
|
||||
"Chroma",
|
||||
"ZImageTurbo",
|
||||
"ZImageBase",
|
||||
"PixArt a",
|
||||
"PixArt E",
|
||||
"Hunyuan 1",
|
||||
"Lumina",
|
||||
"Kolors",
|
||||
"NoobAI",
|
||||
"Anima",
|
||||
"Ernie",
|
||||
"Ernie Turbo",
|
||||
"Nucleus",
|
||||
"Krea 2",
|
||||
],
|
||||
}
|
||||
|
||||
return categories
|
||||
|
||||
|
||||
# Convenience function for getting the singleton instance
|
||||
async def get_civitai_base_model_service() -> CivitaiBaseModelService:
|
||||
"""Get the singleton instance of CivitaiBaseModelService."""
|
||||
return await CivitaiBaseModelService.get_instance()
|
||||
@@ -2,7 +2,13 @@ import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Optional, Dict, Tuple, List, Sequence
|
||||
from .connectivity_guard import (
|
||||
OFFLINE_FRIENDLY_MESSAGE,
|
||||
is_expected_offline_error,
|
||||
is_offline_cooldown_error,
|
||||
)
|
||||
from .model_metadata_provider import (
|
||||
CivitaiModelMetadataProvider,
|
||||
ModelMetadataProviderManager,
|
||||
@@ -39,7 +45,18 @@ class CivitaiClient:
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
self.base_url = "https://civitai.com/api/v1"
|
||||
self.base_url = "https://civitai.red/api/v1"
|
||||
# In-memory cache to avoid redundant get_model_version_info calls
|
||||
# within the same import/scan flow. Only successful results are cached.
|
||||
# Uses OrderedDict with LRU eviction at MAX_CACHE_ENTRIES to prevent
|
||||
# unbounded growth in long-running server processes.
|
||||
self._version_info_cache: OrderedDict[
|
||||
str, Tuple[Optional[Dict], Optional[str]]
|
||||
] = OrderedDict()
|
||||
self._MAX_CACHE_ENTRIES = 500
|
||||
|
||||
def _build_image_info_url(self, image_id: str) -> str:
|
||||
return f"{self.base_url}/images?imageId={image_id}&nsfw=X&withMeta=true"
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
@@ -49,20 +66,57 @@ class CivitaiClient:
|
||||
use_auth: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[bool, Dict | str]:
|
||||
"""Wrapper around downloader.make_request that surfaces rate limits."""
|
||||
"""Wrapper around downloader.make_request that surfaces rate limits,
|
||||
with retry for transient server errors (5xx, Cloudflare 524, network flakiness)."""
|
||||
|
||||
downloader = await get_downloader()
|
||||
success, result = await downloader.make_request(
|
||||
method,
|
||||
url,
|
||||
use_auth=use_auth,
|
||||
**kwargs,
|
||||
)
|
||||
if not success and isinstance(result, RateLimitError):
|
||||
if result.provider is None:
|
||||
result.provider = "civitai_api"
|
||||
raise result
|
||||
return success, result
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
downloader = await get_downloader()
|
||||
success, result = await downloader.make_request(
|
||||
method,
|
||||
url,
|
||||
use_auth=use_auth,
|
||||
**kwargs,
|
||||
)
|
||||
if success:
|
||||
return True, result
|
||||
|
||||
if isinstance(result, RateLimitError):
|
||||
if result.provider is None:
|
||||
result.provider = "civitai_api"
|
||||
raise result
|
||||
|
||||
if is_offline_cooldown_error(result):
|
||||
return False, OFFLINE_FRIENDLY_MESSAGE
|
||||
|
||||
# Transient server error — retry with exponential backoff
|
||||
if self._is_transient_server_error(str(result)):
|
||||
if attempt < max_retries - 1:
|
||||
wait = 2**attempt # 1s, 2s, 4s
|
||||
logger.info(
|
||||
"Transient error on %s %s, retrying in %ds "
|
||||
"(attempt %d/%d): %s",
|
||||
method,
|
||||
url,
|
||||
wait,
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
result,
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
logger.warning(
|
||||
"All %d retries exhausted for %s %s: %s",
|
||||
max_retries,
|
||||
method,
|
||||
url,
|
||||
result,
|
||||
)
|
||||
return False, result
|
||||
|
||||
return False, result
|
||||
|
||||
return False, "Unexpected error in _make_request"
|
||||
|
||||
@staticmethod
|
||||
def _remove_comfy_metadata(model_version: Optional[Dict]) -> None:
|
||||
@@ -121,6 +175,8 @@ class CivitaiClient:
|
||||
)
|
||||
if not success:
|
||||
message = str(version)
|
||||
if is_expected_offline_error(message):
|
||||
return None, OFFLINE_FRIENDLY_MESSAGE
|
||||
if "not found" in message.lower():
|
||||
return None, "Model not found"
|
||||
|
||||
@@ -161,6 +217,9 @@ class CivitaiClient:
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
if is_expected_offline_error(str(e)):
|
||||
logger.debug("Preview download skipped due to offline state.")
|
||||
return False
|
||||
logger.error(f"Download Error: {str(e)}")
|
||||
return False
|
||||
|
||||
@@ -186,11 +245,36 @@ class CivitaiClient:
|
||||
|
||||
return _from_value(payload)
|
||||
|
||||
@staticmethod
|
||||
def _is_transient_server_error(message: str) -> bool:
|
||||
"""Return True when the message indicates a transient upstream failure.
|
||||
|
||||
Recognises Cloudflare 524, generic 5xx, and connectivity-level flakiness
|
||||
that should not be treated as a permanent failure.
|
||||
"""
|
||||
normalized = message.lower()
|
||||
if "status 5" in normalized or "status 524" in normalized:
|
||||
return True
|
||||
if any(
|
||||
keyword in normalized
|
||||
for keyword in (
|
||||
"connection refused",
|
||||
"connection reset",
|
||||
"temporary failure",
|
||||
"name resolution",
|
||||
"connection closed",
|
||||
)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Get all versions of a model with local availability info"""
|
||||
try:
|
||||
success, result = await self._make_request(
|
||||
"GET", f"{self.base_url}/models/{model_id}", use_auth=True
|
||||
"GET",
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True,
|
||||
)
|
||||
if success:
|
||||
# Also return model type along with versions
|
||||
@@ -202,7 +286,17 @@ class CivitaiClient:
|
||||
message = self._extract_error_message(result)
|
||||
if message and "not found" in message.lower():
|
||||
raise ResourceNotFoundError(f"Resource not found for model {model_id}")
|
||||
if is_expected_offline_error(message):
|
||||
logger.info("Civitai request skipped: %s", OFFLINE_FRIENDLY_MESSAGE)
|
||||
return None
|
||||
if message:
|
||||
if self._is_transient_server_error(message):
|
||||
logger.info(
|
||||
"Transient server error for model %s: %s",
|
||||
model_id,
|
||||
message,
|
||||
)
|
||||
return None
|
||||
raise RuntimeError(message)
|
||||
return None
|
||||
except RateLimitError:
|
||||
@@ -237,7 +331,7 @@ class CivitaiClient:
|
||||
"GET",
|
||||
f"{self.base_url}/models",
|
||||
use_auth=True,
|
||||
params={"ids": query},
|
||||
params={"ids": query, "nsfw": "true"},
|
||||
)
|
||||
if not success:
|
||||
return None
|
||||
@@ -316,6 +410,25 @@ class CivitaiClient:
|
||||
return None
|
||||
|
||||
target_version = self._select_target_version(model_data, model_id, version_id)
|
||||
|
||||
# If modelVersions is empty (e.g. CivitAI cache lag for newly published
|
||||
# models) but a specific version_id is known, fall back to fetching the
|
||||
# version directly via the individual model-versions endpoint, then
|
||||
# enrich it with the model-level data we already have.
|
||||
if target_version is None and version_id is not None:
|
||||
logger.info(
|
||||
"modelVersions empty for model %s; falling back to direct "
|
||||
"version lookup for %s",
|
||||
model_id,
|
||||
version_id,
|
||||
)
|
||||
version = await self._fetch_version_by_id(version_id)
|
||||
if version:
|
||||
self._enrich_version_with_model_data(version, model_data)
|
||||
self._remove_comfy_metadata(version)
|
||||
return version
|
||||
return None
|
||||
|
||||
if target_version is None:
|
||||
return None
|
||||
|
||||
@@ -346,10 +459,14 @@ class CivitaiClient:
|
||||
|
||||
async def _fetch_model_data(self, model_id: int) -> Optional[Dict]:
|
||||
success, data = await self._make_request(
|
||||
"GET", f"{self.base_url}/models/{model_id}", use_auth=True
|
||||
"GET",
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True,
|
||||
)
|
||||
if success:
|
||||
return data
|
||||
if is_expected_offline_error(data):
|
||||
return None
|
||||
logger.warning(f"Failed to fetch model data for model {model_id}")
|
||||
return None
|
||||
|
||||
@@ -358,10 +475,14 @@ class CivitaiClient:
|
||||
return None
|
||||
|
||||
success, version = await self._make_request(
|
||||
"GET", f"{self.base_url}/model-versions/{version_id}", use_auth=True
|
||||
"GET",
|
||||
f"{self.base_url}/model-versions/{version_id}",
|
||||
use_auth=True,
|
||||
)
|
||||
if success:
|
||||
return version
|
||||
if is_expected_offline_error(version):
|
||||
return None
|
||||
|
||||
logger.warning(f"Failed to fetch version by id {version_id}")
|
||||
return None
|
||||
@@ -371,10 +492,14 @@ class CivitaiClient:
|
||||
return None
|
||||
|
||||
success, version = await self._make_request(
|
||||
"GET", f"{self.base_url}/model-versions/by-hash/{model_hash}", use_auth=True
|
||||
"GET",
|
||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||
use_auth=True,
|
||||
)
|
||||
if success:
|
||||
return version
|
||||
if is_expected_offline_error(version):
|
||||
return None
|
||||
|
||||
logger.warning(f"Failed to fetch version by hash {model_hash}")
|
||||
return None
|
||||
@@ -450,20 +575,33 @@ class CivitaiClient:
|
||||
- The model version data or None if not found
|
||||
- An error message if there was an error, or None on success
|
||||
"""
|
||||
# In-memory cache avoids redundant API calls within the same
|
||||
# import/scan flow (e.g. _resolve_base_model_from_checkpoint
|
||||
# followed by _resolve_and_populate_checkpoint with the same id).
|
||||
if version_id in self._version_info_cache:
|
||||
logger.debug("Cache hit for model version info: %s", version_id)
|
||||
self._version_info_cache.move_to_end(version_id) # LRU bump
|
||||
return self._version_info_cache[version_id]
|
||||
|
||||
try:
|
||||
url = f"{self.base_url}/model-versions/{version_id}"
|
||||
|
||||
logger.debug(f"Resolving DNS for model version info: {url}")
|
||||
logger.debug("Resolving Civitai model version info: %s", url)
|
||||
success, result = await self._make_request("GET", url, use_auth=True)
|
||||
|
||||
if success:
|
||||
logger.debug(
|
||||
f"Successfully fetched model version info for: {version_id}"
|
||||
)
|
||||
logger.debug("Successfully fetched model version info for: %s", version_id)
|
||||
self._remove_comfy_metadata(result)
|
||||
self._version_info_cache[version_id] = (result, None)
|
||||
self._version_info_cache.move_to_end(version_id)
|
||||
# Evict oldest entry when over capacity
|
||||
if len(self._version_info_cache) > self._MAX_CACHE_ENTRIES:
|
||||
self._version_info_cache.popitem(last=False)
|
||||
return result, None
|
||||
|
||||
# Handle specific error cases
|
||||
if is_expected_offline_error(result):
|
||||
return None, OFFLINE_FRIENDLY_MESSAGE
|
||||
if "not found" in str(result):
|
||||
error_msg = f"Model not found"
|
||||
logger.warning(f"Model version not found: {version_id} - {error_msg}")
|
||||
@@ -479,47 +617,149 @@ class CivitaiClient:
|
||||
logger.error(error_msg)
|
||||
return None, error_msg
|
||||
|
||||
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
||||
async def get_image_info(
|
||||
self, image_id: str, source_url: str | None = None
|
||||
) -> Optional[Dict]:
|
||||
"""Fetch image information from Civitai API
|
||||
|
||||
Args:
|
||||
image_id: The Civitai image ID
|
||||
source_url: Original image page URL. Accepted for caller compatibility;
|
||||
API requests always target ``civitai.red``.
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: The image data or None if not found
|
||||
"""
|
||||
try:
|
||||
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
|
||||
|
||||
logger.debug(f"Fetching image info for ID: {image_id}")
|
||||
requested_id = int(image_id)
|
||||
url = self._build_image_info_url(image_id)
|
||||
success, result = await self._make_request("GET", url, use_auth=True)
|
||||
|
||||
if success:
|
||||
if result and "items" in result and len(result["items"]) > 0:
|
||||
logger.debug(f"Successfully fetched image info for ID: {image_id}")
|
||||
return result["items"][0]
|
||||
logger.warning(f"No image found with ID: {image_id}")
|
||||
if not success:
|
||||
if is_expected_offline_error(result):
|
||||
return None
|
||||
if self._is_transient_server_error(str(result)):
|
||||
logger.info(
|
||||
"Transient server error fetching image info for ID %s: %s",
|
||||
image_id,
|
||||
result,
|
||||
)
|
||||
return None
|
||||
logger.error(
|
||||
"Failed to fetch image info for ID %s from civitai.red: %s",
|
||||
image_id,
|
||||
result,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.error(f"Failed to fetch image info for ID: {image_id}: {result}")
|
||||
if result and "items" in result and isinstance(result["items"], list):
|
||||
items = result["items"]
|
||||
|
||||
for item in items:
|
||||
if isinstance(item, dict) and item.get("id") == requested_id:
|
||||
logger.debug(
|
||||
"Successfully fetched image info for ID %s from civitai.red",
|
||||
image_id,
|
||||
)
|
||||
return item
|
||||
|
||||
returned_ids = [
|
||||
item.get("id")
|
||||
for item in items
|
||||
if isinstance(item, dict) and "id" in item
|
||||
]
|
||||
|
||||
logger.warning(
|
||||
"CivitAI API returned no matching image for requested ID %s from civitai.red. Returned %d item(s) with IDs: %s. This may indicate the image was deleted, hidden, or there is a database lag.",
|
||||
image_id,
|
||||
len(items),
|
||||
returned_ids,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.warning("No image found with ID: %s", image_id)
|
||||
return None
|
||||
except RateLimitError:
|
||||
raise
|
||||
except ValueError as e:
|
||||
error_msg = f"Invalid image ID format: {image_id}"
|
||||
logger.error(error_msg)
|
||||
return None
|
||||
except Exception as e:
|
||||
error_msg = f"Error fetching image info: {e}"
|
||||
logger.error(error_msg)
|
||||
return None
|
||||
|
||||
async def get_model_versions_by_hashes(
|
||||
self, hashes: List[str]
|
||||
) -> Optional[List[Dict]]:
|
||||
"""Fetch full version details for up to 100 SHA256 hashes via the batch endpoint.
|
||||
|
||||
Uses POST /api/v1/model-versions/by-hash which returns full version
|
||||
details including ``usageControl`` and ``earlyAccessEndsAt`` that are
|
||||
not available from the model-level API.
|
||||
|
||||
Args:
|
||||
hashes: List of SHA256 hashes (max 100 per batch; auto-split).
|
||||
|
||||
Returns:
|
||||
List of version dicts or None on failure.
|
||||
"""
|
||||
if not hashes:
|
||||
return []
|
||||
|
||||
BATCH_SIZE = 100
|
||||
all_versions: List[Dict] = []
|
||||
|
||||
for start in range(0, len(hashes), BATCH_SIZE):
|
||||
batch = hashes[start : start + BATCH_SIZE]
|
||||
try:
|
||||
success, result = await self._make_request(
|
||||
"POST",
|
||||
f"{self.base_url}/model-versions/by-hash",
|
||||
use_auth=True,
|
||||
json=batch,
|
||||
)
|
||||
if not success:
|
||||
logger.warning(
|
||||
"Batch by-hash request failed for %d hashes: %s",
|
||||
len(batch),
|
||||
result,
|
||||
)
|
||||
continue
|
||||
|
||||
if isinstance(result, list):
|
||||
all_versions.extend(result)
|
||||
else:
|
||||
logger.debug(
|
||||
"Unexpected by-hash response type: %s", type(result)
|
||||
)
|
||||
except RateLimitError:
|
||||
raise
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error(
|
||||
"Error fetching model versions by hashes: %s", exc
|
||||
)
|
||||
|
||||
return all_versions if all_versions else None
|
||||
|
||||
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||
"""Fetch all models for a specific Civitai user."""
|
||||
if not username:
|
||||
return None
|
||||
|
||||
try:
|
||||
url = f"{self.base_url}/models?username={username}"
|
||||
success, result = await self._make_request("GET", url, use_auth=True)
|
||||
success, result = await self._make_request(
|
||||
"GET",
|
||||
f"{self.base_url}/models",
|
||||
use_auth=True,
|
||||
params={"username": username, "nsfw": "true"},
|
||||
)
|
||||
|
||||
if not success:
|
||||
if is_expected_offline_error(result):
|
||||
logger.info("User model fetch skipped: %s", OFFLINE_FRIENDLY_MESSAGE)
|
||||
return None
|
||||
logger.error("Failed to fetch models for %s: %s", username, result)
|
||||
return None
|
||||
|
||||
|
||||
204
py/services/connectivity_guard.py
Normal file
204
py/services/connectivity_guard.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""In-memory connectivity guard to suppress repeated network retries when offline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import errno
|
||||
import logging
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OFFLINE_COOLDOWN_ERROR = "offline_cooldown"
|
||||
OFFLINE_FRIENDLY_MESSAGE = "Network offline, will retry automatically later"
|
||||
|
||||
|
||||
def is_offline_cooldown_error(value: Any) -> bool:
|
||||
"""Return True when a response payload represents guard short-circuit."""
|
||||
return isinstance(value, str) and value == OFFLINE_COOLDOWN_ERROR
|
||||
|
||||
|
||||
def is_expected_offline_error(value: Any) -> bool:
|
||||
"""Return True when payload is an expected offline-related result."""
|
||||
if is_offline_cooldown_error(value):
|
||||
return True
|
||||
if not isinstance(value, str):
|
||||
return False
|
||||
normalized = value.lower()
|
||||
return "network offline" in normalized or "offline" in normalized
|
||||
|
||||
|
||||
class ConnectivityGuard:
|
||||
"""Tracks network failures and gates outbound requests during cooldown."""
|
||||
|
||||
_instance: "ConnectivityGuard | None" = None
|
||||
_instance_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> "ConnectivityGuard":
|
||||
async with cls._instance_lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if hasattr(self, "_initialized"):
|
||||
return
|
||||
self._initialized = True
|
||||
self._default_destination = "__global__"
|
||||
self._destination_states: dict[str, _DestinationState] = {
|
||||
self._default_destination: _DestinationState()
|
||||
}
|
||||
self.base_backoff_seconds = 30
|
||||
self.max_backoff_seconds = 300
|
||||
self.failure_threshold = 3
|
||||
|
||||
@property
|
||||
def online(self) -> bool:
|
||||
return self._state_for_destination(None).online
|
||||
|
||||
@online.setter
|
||||
def online(self, value: bool) -> None:
|
||||
self._state_for_destination(None).online = value
|
||||
|
||||
@property
|
||||
def failure_count(self) -> int:
|
||||
return self._state_for_destination(None).failure_count
|
||||
|
||||
@failure_count.setter
|
||||
def failure_count(self, value: int) -> None:
|
||||
self._state_for_destination(None).failure_count = value
|
||||
|
||||
@property
|
||||
def cooldown_until(self) -> datetime | None:
|
||||
return self._state_for_destination(None).cooldown_until
|
||||
|
||||
@cooldown_until.setter
|
||||
def cooldown_until(self, value: datetime | None) -> None:
|
||||
self._state_for_destination(None).cooldown_until = value
|
||||
|
||||
def _now(self) -> datetime:
|
||||
return datetime.now()
|
||||
|
||||
def _normalize_destination(self, destination: str | None) -> str:
|
||||
if destination is None or not destination.strip():
|
||||
return self._default_destination
|
||||
return destination.lower().strip()
|
||||
|
||||
def _state_for_destination(self, destination: str | None) -> "_DestinationState":
|
||||
destination_key = self._normalize_destination(destination)
|
||||
if destination_key not in self._destination_states:
|
||||
self._destination_states[destination_key] = _DestinationState()
|
||||
return self._destination_states[destination_key]
|
||||
|
||||
def in_cooldown(self, destination: str | None = None) -> bool:
|
||||
state = self._state_for_destination(destination)
|
||||
if state.cooldown_until is None:
|
||||
return False
|
||||
return self._now() < state.cooldown_until
|
||||
|
||||
def cooldown_remaining_seconds(self, destination: str | None = None) -> float:
|
||||
state = self._state_for_destination(destination)
|
||||
if state.cooldown_until is None:
|
||||
return 0.0
|
||||
return max(0.0, (state.cooldown_until - self._now()).total_seconds())
|
||||
|
||||
def should_block_request(self, destination: str | None = None) -> bool:
|
||||
return self.in_cooldown(destination)
|
||||
|
||||
def register_success(self, destination: str | None = None) -> None:
|
||||
destination_key = self._normalize_destination(destination)
|
||||
state = self._state_for_destination(destination_key)
|
||||
was_offline = (not state.online) or state.cooldown_until is not None
|
||||
state.online = True
|
||||
state.failure_count = 0
|
||||
state.cooldown_until = None
|
||||
if was_offline:
|
||||
logger.info(
|
||||
"Connectivity restored for destination '%s'; requests resumed.",
|
||||
destination_key,
|
||||
)
|
||||
|
||||
def register_network_failure(
|
||||
self, exc: Exception, destination: str | None = None
|
||||
) -> None:
|
||||
destination_key = self._normalize_destination(destination)
|
||||
state = self._state_for_destination(destination_key)
|
||||
state.online = False
|
||||
state.failure_count += 1
|
||||
|
||||
if state.failure_count < self.failure_threshold:
|
||||
logger.debug(
|
||||
"Network failure tracked for destination '%s' (%d/%d): %s",
|
||||
destination_key,
|
||||
state.failure_count,
|
||||
self.failure_threshold,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
retry_step = state.failure_count - self.failure_threshold
|
||||
backoff = min(
|
||||
self.max_backoff_seconds,
|
||||
self.base_backoff_seconds * (2**retry_step),
|
||||
)
|
||||
should_log_warning = not self.in_cooldown(destination_key)
|
||||
state.cooldown_until = self._now() + timedelta(seconds=backoff)
|
||||
|
||||
if should_log_warning:
|
||||
logger.warning(
|
||||
"Connectivity offline for destination '%s'; enter cooldown for %ss after %d network failures.",
|
||||
destination_key,
|
||||
int(backoff),
|
||||
state.failure_count,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Cooldown still active for destination '%s'; failure_count=%d, backoff=%ss.",
|
||||
destination_key,
|
||||
state.failure_count,
|
||||
int(backoff),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_network_unreachable_error(exc: Exception) -> bool:
|
||||
"""Return whether the exception should count as connectivity failure."""
|
||||
if isinstance(exc, asyncio.CancelledError):
|
||||
return False
|
||||
|
||||
if isinstance(
|
||||
exc,
|
||||
(
|
||||
asyncio.TimeoutError,
|
||||
TimeoutError,
|
||||
ConnectionRefusedError,
|
||||
socket.gaierror,
|
||||
aiohttp.ServerTimeoutError,
|
||||
aiohttp.ConnectionTimeoutError,
|
||||
aiohttp.ClientConnectorError,
|
||||
aiohttp.ClientConnectionError,
|
||||
),
|
||||
):
|
||||
return True
|
||||
|
||||
if isinstance(exc, OSError) and exc.errno in {
|
||||
errno.ENETUNREACH,
|
||||
errno.EHOSTUNREACH,
|
||||
errno.ETIMEDOUT,
|
||||
errno.ECONNREFUSED,
|
||||
}:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DestinationState:
|
||||
online: bool = True
|
||||
failure_count: int = 0
|
||||
cooldown_until: datetime | None = None
|
||||
@@ -7,11 +7,13 @@ with category filtering and enriched results including post counts.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_EMBEDDED_COMMAND_PATTERN = re.compile(r"\s/\w")
|
||||
class CustomWordsService:
|
||||
"""Service for autocomplete via TagFTSIndex.
|
||||
|
||||
@@ -49,6 +51,7 @@ class CustomWordsService:
|
||||
if self._tag_index is None:
|
||||
try:
|
||||
from .tag_fts_index import get_tag_fts_index
|
||||
|
||||
self._tag_index = get_tag_fts_index()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize TagFTSIndex: {e}")
|
||||
@@ -59,14 +62,16 @@ class CustomWordsService:
|
||||
self,
|
||||
search_term: str,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
categories: Optional[List[int]] = None,
|
||||
enriched: bool = False
|
||||
enriched: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search tags using TagFTSIndex with category filtering.
|
||||
|
||||
Args:
|
||||
search_term: The search term to match against.
|
||||
limit: Maximum number of results to return.
|
||||
offset: Number of results to skip.
|
||||
categories: Optional list of category IDs to filter by.
|
||||
enriched: If True, always return enriched results with category
|
||||
and post_count (default behavior now).
|
||||
@@ -74,10 +79,28 @@ class CustomWordsService:
|
||||
Returns:
|
||||
List of dicts with tag_name, category, and post_count.
|
||||
"""
|
||||
normalized_search = search_term.strip()
|
||||
if not normalized_search:
|
||||
return []
|
||||
|
||||
# Prompt widgets should only send the active token, but guard against
|
||||
# accidental full-prompt queries reaching the FTS path.
|
||||
if (
|
||||
"__" in normalized_search
|
||||
or "," in normalized_search
|
||||
or ">" in normalized_search
|
||||
or "\n" in normalized_search
|
||||
or "\r" in normalized_search
|
||||
or _EMBEDDED_COMMAND_PATTERN.search(normalized_search)
|
||||
):
|
||||
logger.debug("Skipping prompt-like custom words query: %s", normalized_search)
|
||||
return []
|
||||
|
||||
tag_index = self._get_tag_index()
|
||||
if tag_index is not None:
|
||||
results = tag_index.search(search_term, categories=categories, limit=limit)
|
||||
return results
|
||||
return tag_index.search(
|
||||
normalized_search, categories=categories, limit=limit, offset=offset
|
||||
)
|
||||
|
||||
logger.debug("TagFTSIndex not available, returning empty results")
|
||||
return []
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user