mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 05:32:12 -03:00
Compare commits
1362 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f09224152a | ||
|
|
df93670598 | ||
|
|
073fb3a94a | ||
|
|
53c4165d82 | ||
|
|
8cd4550189 | ||
|
|
2b2e4fefab | ||
|
|
5f93648297 | ||
|
|
8a628f0bd0 | ||
|
|
b67c8598d6 | ||
|
|
0254c9d0e9 | ||
|
|
ecb512995c | ||
|
|
f8b9fa9b20 | ||
|
|
5d4917c8d9 | ||
|
|
a50309c22e | ||
|
|
f5020e081f | ||
|
|
3c0bfcb226 | ||
|
|
9198a23ba9 | ||
|
|
02bac7edfb | ||
|
|
ea1d1a49c9 | ||
|
|
9a789f8f08 | ||
|
|
1971881537 | ||
|
|
4eb46a8d3e | ||
|
|
36f28b3c65 | ||
|
|
2452cc4df1 | ||
|
|
eda1ce9743 | ||
|
|
e24621a0af | ||
|
|
7173a2b9d6 | ||
|
|
d540b21aac | ||
|
|
9952721e76 | ||
|
|
26e4895807 | ||
|
|
c533a8e7bf | ||
|
|
dc820a456f | ||
|
|
07721af87c | ||
|
|
5093c30c06 | ||
|
|
8c77080ae6 | ||
|
|
bcf72c6bcc | ||
|
|
3849f7eef9 | ||
|
|
7eced1e3e9 | ||
|
|
51b5261f40 | ||
|
|
963f6b1383 | ||
|
|
b75baa1d1a | ||
|
|
6d95e93378 | ||
|
|
7117e0c33e | ||
|
|
d261474f3a | ||
|
|
c09d67d2e4 | ||
|
|
1427dc8e38 | ||
|
|
77a7b90dc7 | ||
|
|
e9d55fe146 | ||
|
|
57f369a6de | ||
|
|
059ebeead7 | ||
|
|
831a9da9d7 | ||
|
|
6000e08640 | ||
|
|
3edc65c106 | ||
|
|
655157434e | ||
|
|
3661b11b70 | ||
|
|
0e73db0669 | ||
|
|
8158441a92 | ||
|
|
5600471093 | ||
|
|
354cf03bbc | ||
|
|
645b7c247d | ||
|
|
5f25a29303 | ||
|
|
906d00106d | ||
|
|
7850131969 | ||
|
|
3d5ec4a9f1 | ||
|
|
1cdbb9a851 | ||
|
|
e224be4b88 | ||
|
|
b9d3a4afce | ||
|
|
aa4aa1a613 | ||
|
|
cc8e1c5049 | ||
|
|
41e649415a | ||
|
|
c8f770a86b | ||
|
|
29bb85359e | ||
|
|
4557da8b63 | ||
|
|
09b75de25b | ||
|
|
415fc5720c | ||
|
|
4dd8ce778e | ||
|
|
f81ff2efe9 | ||
|
|
837bb17b08 | ||
|
|
5ee93a27ee | ||
|
|
2e6aa5fe9f | ||
|
|
c14e066f8f | ||
|
|
c09100c22e | ||
|
|
839ed3bda3 | ||
|
|
1f627774c1 | ||
|
|
3b842355c2 | ||
|
|
dd27411ebf | ||
|
|
388ff7f5b4 | ||
|
|
f76343f389 | ||
|
|
ce5a1ae3d0 | ||
|
|
1d40d7400f | ||
|
|
1bb5d0b072 | ||
|
|
c3932538e1 | ||
|
|
a68141adf4 | ||
|
|
fb8ba4c076 | ||
|
|
4ed3bd9039 | ||
|
|
ba6e2eadba | ||
|
|
1c16392367 | ||
|
|
035ad4b473 | ||
|
|
a7ee883227 | ||
|
|
ddf9e33961 | ||
|
|
4301b3455f | ||
|
|
3d6bb432c4 | ||
|
|
6c03aa1430 | ||
|
|
5376fd8724 | ||
|
|
6dea9a76bc | ||
|
|
d73903e82e | ||
|
|
4862419b61 | ||
|
|
e6e7df7454 | ||
|
|
30f9e3e2ec | ||
|
|
707d0cb8a4 | ||
|
|
56ea7594ce | ||
|
|
389e46c251 | ||
|
|
6db17e682a | ||
|
|
94e0308a12 | ||
|
|
1f9f821576 | ||
|
|
57933dfba6 | ||
|
|
c50bee7757 | ||
|
|
4e3ee843f9 | ||
|
|
7e40f6fcb9 | ||
|
|
7976956b6b | ||
|
|
adce5293d5 | ||
|
|
c2db5eb6df | ||
|
|
f958ecdf18 | ||
|
|
ef0bcc6cf1 | ||
|
|
285428ad3a | ||
|
|
ee18cff3d9 | ||
|
|
1be3235564 | ||
|
|
a92883509a | ||
|
|
ce42d83ce9 | ||
|
|
077cf7b574 | ||
|
|
b99d78bda6 | ||
|
|
39586f4a20 | ||
|
|
4ef750b206 | ||
|
|
9d3d93823d | ||
|
|
45c1113b72 | ||
|
|
e10717dcda | ||
|
|
315ab6f70b | ||
|
|
cf4d654c4b | ||
|
|
569c829709 | ||
|
|
de05b59f29 | ||
|
|
70a282a6c0 | ||
|
|
b10bcf7e78 | ||
|
|
5fb10263f3 | ||
|
|
9e76c9783e | ||
|
|
7770976513 | ||
|
|
dc1f7ab6fe | ||
|
|
32b1d6c561 | ||
|
|
5264e49f2a | ||
|
|
ce3adaf831 | ||
|
|
e2f3e57f5c | ||
|
|
5c2349ff42 | ||
|
|
50eee8c373 | ||
|
|
f89b792535 | ||
|
|
6d0ea2841c | ||
|
|
98678a8698 | ||
|
|
5326fa2970 | ||
|
|
90547670a2 | ||
|
|
4753206c52 | ||
|
|
613aa3b1c3 | ||
|
|
a6b704d4b4 | ||
|
|
227d06c736 | ||
|
|
8508763831 | ||
|
|
136d3153fa | ||
|
|
49bdf77040 | ||
|
|
f4dcd89835 | ||
|
|
139e915711 | ||
|
|
22eda58074 | ||
|
|
fb91cf4df2 | ||
|
|
e0332571da | ||
|
|
2d4bc47746 | ||
|
|
38e766484e | ||
|
|
b5ee4a6408 | ||
|
|
7892df21ec | ||
|
|
188fe407b6 | ||
|
|
600afdcd92 | ||
|
|
994fa4bd43 | ||
|
|
51098f2829 | ||
|
|
795b9e8418 | ||
|
|
9ca2b9dd56 | ||
|
|
d77b6d78b7 | ||
|
|
427e7a36d5 | ||
|
|
c90306cc9b | ||
|
|
5fe0660c64 | ||
|
|
2abb5bf122 | ||
|
|
bb65527469 | ||
|
|
d9a6db3359 | ||
|
|
58cafdb713 | ||
|
|
0594e278b6 | ||
|
|
807425f12a | ||
|
|
aa4b1ccc25 | ||
|
|
58255ec28b | ||
|
|
d62b84693d | ||
|
|
df75c7e68d | ||
|
|
c5c7fdf54f | ||
|
|
49e0deeff3 | ||
|
|
0c20701bef | ||
|
|
faa26651dd | ||
|
|
2eae8a7729 | ||
|
|
dde2b2a960 | ||
|
|
4a9089d3dd | ||
|
|
3244a5f1a1 | ||
|
|
449c1e9d10 | ||
|
|
d0aa916683 | ||
|
|
13433f8cd2 | ||
|
|
8d336320c0 | ||
|
|
d945c58d51 | ||
|
|
acaf122346 | ||
|
|
713759b411 | ||
|
|
c5175bb870 | ||
|
|
e63ef8d031 | ||
|
|
e043537241 | ||
|
|
46126f9950 | ||
|
|
f4eb916914 | ||
|
|
49b9b7a5ea | ||
|
|
9b1a9ee071 | ||
|
|
0b8f137a1b | ||
|
|
6148a12301 | ||
|
|
fadbf21b4f | ||
|
|
c38a06937d | ||
|
|
1a34403b0e | ||
|
|
e4d58d0f60 | ||
|
|
4e4ea85cc3 | ||
|
|
f7a856349a | ||
|
|
15edd7a42c | ||
|
|
46243a236d | ||
|
|
6f382e587a | ||
|
|
bf3d706bf4 | ||
|
|
cdf21e813c | ||
|
|
10f5588e4a | ||
|
|
0ecbdf6f39 | ||
|
|
61101a7ad0 | ||
|
|
6d9be814a5 | ||
|
|
52bf93e430 | ||
|
|
00fade756c | ||
|
|
3c0feb23ba | ||
|
|
3627840fe9 | ||
|
|
bbdc1bba87 | ||
|
|
21a1bc1a01 | ||
|
|
0968698804 | ||
|
|
a5b2e9b0bf | ||
|
|
5a6ff444b9 | ||
|
|
3bb240d3c1 | ||
|
|
ee0d241c75 | ||
|
|
321ff72953 | ||
|
|
412f1e62a1 | ||
|
|
8901b32a55 | ||
|
|
8ab6cc72ad | ||
|
|
52e671638b | ||
|
|
a3070f8d82 | ||
|
|
3fde474583 | ||
|
|
1454991d6d | ||
|
|
4398851bb9 | ||
|
|
5173aa6c20 | ||
|
|
3d98572a62 | ||
|
|
c48095d9c6 | ||
|
|
1e4d1b8f15 | ||
|
|
8c037465ba | ||
|
|
055c1ca0d4 | ||
|
|
27370df93a | ||
|
|
60d23aa238 | ||
|
|
5e441d9c4f | ||
|
|
eb76468280 | ||
|
|
01bbaa31a8 | ||
|
|
bddf023dc4 | ||
|
|
8e69a247ed | ||
|
|
97141b01e1 | ||
|
|
acf610ddff | ||
|
|
a9a6f66035 | ||
|
|
0040863a03 | ||
|
|
4ab86b4ae2 | ||
|
|
b32b4b4042 | ||
|
|
4e552dcf3e | ||
|
|
8f4c02efdc | ||
|
|
b77c596f3a | ||
|
|
181f0b5626 | ||
|
|
480e5d966f | ||
|
|
e8636b949d | ||
|
|
8ea369db47 | ||
|
|
ec9b37eb53 | ||
|
|
b0847f6b87 | ||
|
|
84d10b1f3b | ||
|
|
4fdc97d062 | ||
|
|
5fe5e7ea54 | ||
|
|
7be1a2bd65 | ||
|
|
87842385c6 | ||
|
|
1dc189eb39 | ||
|
|
6120922204 | ||
|
|
ddb30dbb17 | ||
|
|
1e8bd88e28 | ||
|
|
c3a66ecf28 | ||
|
|
1f60160e8b | ||
|
|
7d560bf07a | ||
|
|
47da9949d9 | ||
|
|
68c0a5ba71 | ||
|
|
1aa81c803b | ||
|
|
8f5e134d3e | ||
|
|
ef03a2a917 | ||
|
|
e275968553 | ||
|
|
76d3aa2b5b | ||
|
|
c9a65c7347 | ||
|
|
f542ade628 | ||
|
|
d2c2bfbe6a | ||
|
|
2b6910bd55 | ||
|
|
b1dd733493 | ||
|
|
5dcf0a1e48 | ||
|
|
cf357b57fc | ||
|
|
4e1773833f | ||
|
|
8cf762ffd3 | ||
|
|
d997eaa429 | ||
|
|
8e51f0f19f | ||
|
|
f0e246b4ac | ||
|
|
a232997a79 | ||
|
|
08a449db99 | ||
|
|
0c023c9888 | ||
|
|
0ad92d00b3 | ||
|
|
a726cbea1e | ||
|
|
c53fa8692b | ||
|
|
3118f3b43c | ||
|
|
9199950b74 | ||
|
|
4c7e31687b | ||
|
|
75e207b520 | ||
|
|
631289b75e | ||
|
|
1b958d0a5d | ||
|
|
35fdf9020d | ||
|
|
45926b1dca | ||
|
|
686ba5024d | ||
|
|
cf375c7c86 | ||
|
|
5e53d76f44 | ||
|
|
7757f72859 | ||
|
|
c8cc584049 | ||
|
|
2cdd269bba | ||
|
|
d2d97ae5bb | ||
|
|
d08d77c555 | ||
|
|
92f8d2139a | ||
|
|
50f2c2dfe6 | ||
|
|
3539c453d3 | ||
|
|
1631122f95 | ||
|
|
8fcb979544 | ||
|
|
8a5af0b7f3 | ||
|
|
cb1f08d556 | ||
|
|
1150267765 | ||
|
|
5c1252548d | ||
|
|
3c7cdf5db8 | ||
|
|
9ac4203b1c | ||
|
|
d0800510db | ||
|
|
f8ba551cc4 | ||
|
|
413444500e | ||
|
|
e21d5835ec | ||
|
|
f2f354e478 | ||
|
|
b195d4569c | ||
|
|
3b77fed72d | ||
|
|
fc64e97f92 | ||
|
|
1da0434454 | ||
|
|
cf2fe40612 | ||
|
|
8f46433ff7 | ||
|
|
f3be3ae269 | ||
|
|
cfec5447d3 | ||
|
|
2d36b461cf | ||
|
|
5e23e4b13d | ||
|
|
badae2e8b3 | ||
|
|
9e64531de6 | ||
|
|
fdec8d283c | ||
|
|
9abedbf7cb | ||
|
|
66004c1cdc | ||
|
|
5b564cd8a3 | ||
|
|
2e79970e6e | ||
|
|
67c82ba6ea | ||
|
|
98425f37b8 | ||
|
|
9d22dd3465 | ||
|
|
837138db49 | ||
|
|
d43d992362 | ||
|
|
16b611cb7e | ||
|
|
8dde2d5e0d | ||
|
|
22b0b2bd24 | ||
|
|
056f727bfd | ||
|
|
0aa6c53c1f | ||
|
|
d9b0660611 | ||
|
|
d01666f4e2 | ||
|
|
51bee87cd0 | ||
|
|
3041b443e5 | ||
|
|
d95e6c939b | ||
|
|
fd38c63b35 | ||
|
|
b69c24ae14 | ||
|
|
65a0c00e33 | ||
|
|
b12a5ef133 | ||
|
|
9e1b92c26e | ||
|
|
3922aec36e | ||
|
|
41cca8e56d | ||
|
|
2d37a7341a | ||
|
|
40e3c6134c | ||
|
|
edddd47a1e | ||
|
|
4ea6f38645 | ||
|
|
40d998a026 | ||
|
|
3af8f151ac | ||
|
|
e066fa6873 | ||
|
|
6bd94269d4 | ||
|
|
c90edec18a | ||
|
|
cbb302614c | ||
|
|
c54611a11b | ||
|
|
88f249649a | ||
|
|
fe9fbdb93c | ||
|
|
28bc966b76 | ||
|
|
77bbf85b52 | ||
|
|
3b1990e97a | ||
|
|
375b5a49f3 | ||
|
|
392c157cb5 | ||
|
|
6f5bf4b582 | ||
|
|
2e3f48ebb7 | ||
|
|
e4a2c518bb | ||
|
|
f19fb68b4c | ||
|
|
9121c12a2c | ||
|
|
d0fe28cfe2 | ||
|
|
656e3e43be | ||
|
|
c2c1772371 | ||
|
|
88d5caf642 | ||
|
|
1684978693 | ||
|
|
8e4927600f | ||
|
|
4d72dc57e7 | ||
|
|
e7316b3389 | ||
|
|
e17b374606 | ||
|
|
141f83065f | ||
|
|
6381dbafc1 | ||
|
|
fc9db4510f | ||
|
|
66abf736c9 | ||
|
|
af713470c1 | ||
|
|
93a51d2bcb | ||
|
|
3f3e06de8a | ||
|
|
7315aac9d8 | ||
|
|
d933308a6f | ||
|
|
3baf93dcc5 | ||
|
|
6ba14bd8fe | ||
|
|
7499570766 | ||
|
|
003ee55a75 | ||
|
|
b0cc42ef1f | ||
|
|
23679ec3f5 | ||
|
|
da52e5b9dd | ||
|
|
c4e357793f | ||
|
|
6c3424029c | ||
|
|
dd9e6a5b69 | ||
|
|
095320ef72 | ||
|
|
35f7674bcd | ||
|
|
26b36c123d | ||
|
|
c85e694c1d | ||
|
|
ec05282db6 | ||
|
|
3d6f9b226f | ||
|
|
eda6df4a5d | ||
|
|
d504f89f6a | ||
|
|
14c468f2a2 | ||
|
|
2a99b0e46f | ||
|
|
ae8914f5c8 | ||
|
|
0c9f8971ce | ||
|
|
d7a75ea4e5 | ||
|
|
3ad8d8b17c | ||
|
|
39225dc204 | ||
|
|
4fb69f7d89 | ||
|
|
0890c6ad24 | ||
|
|
dd81809589 | ||
|
|
f0672beb46 | ||
|
|
cc5301e710 | ||
|
|
9d5ec43c4e | ||
|
|
6d41211b07 | ||
|
|
d58b61eed5 | ||
|
|
4b53d98bfc | ||
|
|
f51f354e48 | ||
|
|
59d027181d | ||
|
|
0d0988c090 | ||
|
|
dc2de50924 | ||
|
|
12c88835f2 | ||
|
|
6f4453aaf3 | ||
|
|
4b4b8fe3c1 | ||
|
|
49e7c2e9f5 | ||
|
|
4653c273e3 | ||
|
|
ae145de2f2 | ||
|
|
dde7cf71c6 | ||
|
|
219cd242db | ||
|
|
e5b712c082 | ||
|
|
4d2c60d59b | ||
|
|
1d2c1b114b | ||
|
|
2bde936d05 | ||
|
|
cd3e32bf4b | ||
|
|
454536d631 | ||
|
|
656f1755fd | ||
|
|
8aa76ce5c1 | ||
|
|
49fa37f00d | ||
|
|
9f83548cf3 | ||
|
|
6054d95e85 | ||
|
|
8c9bb35824 | ||
|
|
3eacf9558a | ||
|
|
fee37172b4 | ||
|
|
e128c80eb1 | ||
|
|
5cc735ed57 | ||
|
|
43fcce6361 | ||
|
|
49b7126278 | ||
|
|
679cfb5c69 | ||
|
|
50616bc680 | ||
|
|
aaad270822 | ||
|
|
bd10280736 | ||
|
|
d477050239 | ||
|
|
85f79cd8d1 | ||
|
|
613cd81152 | ||
|
|
e0aba6c49a | ||
|
|
d78bcf2494 | ||
|
|
f7cffd2eba | ||
|
|
0d0b91aa80 | ||
|
|
42872e6d2d | ||
|
|
b91f06405d | ||
|
|
dac4c688d6 | ||
|
|
097a68ad18 | ||
|
|
4a98710db0 | ||
|
|
d033a374dd | ||
|
|
6aa23fe36a | ||
|
|
3220cfb79c | ||
|
|
b92e7aa446 | ||
|
|
c3b9c73541 | ||
|
|
81c6672880 | ||
|
|
08baf884d3 | ||
|
|
1c4096f3d5 | ||
|
|
66a3f3f59a | ||
|
|
624df1328b | ||
|
|
c063854b51 | ||
|
|
8cf99dd928 | ||
|
|
c07e885725 | ||
|
|
21772feadd | ||
|
|
2d00cfdd31 | ||
|
|
49e03d658b | ||
|
|
fec85bcc08 | ||
|
|
0e93a6bcb0 | ||
|
|
7e20f738fb | ||
|
|
24090e6077 | ||
|
|
1022b07f64 | ||
|
|
4faf912c6f | ||
|
|
56e4b24b07 | ||
|
|
12295d2fdc | ||
|
|
6261f7d18d | ||
|
|
9e1a2e3bb7 | ||
|
|
40cbb2155c | ||
|
|
a8d7070832 | ||
|
|
ab7266f3a4 | ||
|
|
3053b13fcb | ||
|
|
f3544b3471 | ||
|
|
1610048974 | ||
|
|
fc6f1bf95b | ||
|
|
67b274c1b2 | ||
|
|
fb0d6b5641 | ||
|
|
d30fbeb286 | ||
|
|
46e430ebbb | ||
|
|
bc4cd45fcb | ||
|
|
bdc86ddf15 | ||
|
|
ded17c1479 | ||
|
|
933e2fc01d | ||
|
|
1cddeee264 | ||
|
|
183c000080 | ||
|
|
adf7b6d4b2 | ||
|
|
0566d50346 | ||
|
|
4275dc3003 | ||
|
|
30956aeefc | ||
|
|
64e1dd3dd6 | ||
|
|
0dc4b6f728 | ||
|
|
86074c87d7 | ||
|
|
6f9245df01 | ||
|
|
4540e47055 | ||
|
|
4bb8981e78 | ||
|
|
c49be91aa0 | ||
|
|
2b847039d4 | ||
|
|
1147725fd7 | ||
|
|
26891e12a4 | ||
|
|
2f7e44a76f | ||
|
|
9366d3d2d0 | ||
|
|
6b606a5cc8 | ||
|
|
e5339c178a | ||
|
|
1a76f74482 | ||
|
|
13f13eb095 | ||
|
|
125fdecd61 | ||
|
|
d05076d258 | ||
|
|
00b77581fc | ||
|
|
897787d17c | ||
|
|
d5a280cf2b | ||
|
|
a0c2d9b5ad | ||
|
|
e713bd1ca2 | ||
|
|
a3c28c1003 | ||
|
|
f4b7c9a138 | ||
|
|
6b860b5f29 | ||
|
|
37dfcd6abd | ||
|
|
bc2fca3a4f | ||
|
|
f8ef159656 | ||
|
|
b2b8a9d37e | ||
|
|
15ae4031b7 | ||
|
|
688976ce3b | ||
|
|
a548af01dc | ||
|
|
0dd52eceb3 | ||
|
|
b8c6cf4ac1 | ||
|
|
beb8ff1dd1 | ||
|
|
6a8f0867d9 | ||
|
|
51ad1c9a33 | ||
|
|
34872eb612 | ||
|
|
8b4e3128ff | ||
|
|
c66cbc800b | ||
|
|
21941521a0 | ||
|
|
0d33884052 | ||
|
|
415df49377 | ||
|
|
f5f45002c7 | ||
|
|
1edf7126bb | ||
|
|
a1a55a1002 | ||
|
|
45f5cb46bd | ||
|
|
1b5e608a27 | ||
|
|
a7df8ae15c | ||
|
|
47ce0d0fe2 | ||
|
|
b220e288d0 | ||
|
|
1fc8b45b68 | ||
|
|
62f06302f0 | ||
|
|
3e5cb223f3 | ||
|
|
4ee5b7481c | ||
|
|
e104b78c01 | ||
|
|
ba1ac58721 | ||
|
|
a4fbeb6295 | ||
|
|
68f8871403 | ||
|
|
6fd74952b7 | ||
|
|
1ea468cfc4 | ||
|
|
14721c265f | ||
|
|
821827a375 | ||
|
|
9ba3e2c204 | ||
|
|
d287883671 | ||
|
|
ead34818db | ||
|
|
a060010b96 | ||
|
|
76a92ac847 | ||
|
|
74bc490383 | ||
|
|
510d476323 | ||
|
|
1e7257fd53 | ||
|
|
4ff1f51b1c | ||
|
|
74507cef05 | ||
|
|
c23ab04d90 | ||
|
|
d50dde6cf6 | ||
|
|
fcb1fb39be | ||
|
|
b0ef74f802 | ||
|
|
f332aef41d | ||
|
|
1f91a3da8e | ||
|
|
16840c321d | ||
|
|
c109e392ad | ||
|
|
5e69671366 | ||
|
|
52d23d9b75 | ||
|
|
4c4e6d7a7b | ||
|
|
03b6e78705 | ||
|
|
24c01141d7 | ||
|
|
6dc2811af4 | ||
|
|
e6425dce32 | ||
|
|
95e2ff5f1e | ||
|
|
92ac487128 | ||
|
|
3250fa89cb | ||
|
|
7475de366b | ||
|
|
affb507b37 | ||
|
|
3320b80150 | ||
|
|
fb2b69b787 | ||
|
|
29a05f6533 | ||
|
|
9fa3fac973 | ||
|
|
904b0d104a | ||
|
|
1d31dae110 | ||
|
|
476ecb7423 | ||
|
|
4eb67cf6da | ||
|
|
a5a9f7ed83 | ||
|
|
c0b029e228 | ||
|
|
9bebcc9a4b | ||
|
|
ac7d23011c | ||
|
|
491e09b7b5 | ||
|
|
192bc237bf | ||
|
|
f041f4a114 | ||
|
|
2546580377 | ||
|
|
8fbf2ab56d | ||
|
|
ea727aad2e | ||
|
|
5520aecbba | ||
|
|
6b738a4769 | ||
|
|
903a8050b3 | ||
|
|
31b032429d | ||
|
|
2bcf341f04 | ||
|
|
ca6f45b359 | ||
|
|
2a67cec16b | ||
|
|
1800afe31b | ||
|
|
8c6311355d | ||
|
|
91801dff85 | ||
|
|
be594133f0 | ||
|
|
8a538d117e | ||
|
|
8d9118cbee | ||
|
|
b67464ea13 | ||
|
|
33334da0bb | ||
|
|
40ce2baa7b | ||
|
|
1134466cc0 | ||
|
|
92341111ad | ||
|
|
4956d6781f | ||
|
|
63562240c4 | ||
|
|
84d801cf14 | ||
|
|
b56fe4ca68 | ||
|
|
6c83c65e02 | ||
|
|
a83f020fcc | ||
|
|
7f9a3bf272 | ||
|
|
f80e266d02 | ||
|
|
7bef562541 | ||
|
|
b2428f607c | ||
|
|
8303196b57 | ||
|
|
987b8c8742 | ||
|
|
e60a579b85 | ||
|
|
be8edafed0 | ||
|
|
a258a18fa4 | ||
|
|
59010ca431 | ||
|
|
75f3764e6c | ||
|
|
867ffd1163 | ||
|
|
6acccbbb94 | ||
|
|
b2c4efab45 | ||
|
|
408a435b71 | ||
|
|
36d3cd93d5 | ||
|
|
b36fea002e | ||
|
|
52acbd954a | ||
|
|
f6709a55c3 | ||
|
|
7b374d747b | ||
|
|
fd480a9360 | ||
|
|
ec8b228867 | ||
|
|
401200050b | ||
|
|
29160bd6e5 | ||
|
|
3c9e402bc0 | ||
|
|
ff4d0f0208 | ||
|
|
f82908221c | ||
|
|
4246908f2e | ||
|
|
f64597afd2 | ||
|
|
975ff2672d | ||
|
|
e90ba31784 | ||
|
|
a4074c93bc | ||
|
|
7a8b7598c7 | ||
|
|
cd0d832f14 | ||
|
|
5b0becaaf2 | ||
|
|
9817bac2fe | ||
|
|
f6bd48cfcd | ||
|
|
01843b8f2b | ||
|
|
94ed81de5e | ||
|
|
0700b8f399 | ||
|
|
d62cff9841 | ||
|
|
083f4805b2 | ||
|
|
8e5bfd379e | ||
|
|
2366f143d8 | ||
|
|
e997f5bc1b | ||
|
|
842beec7cc | ||
|
|
d2268fc9e0 | ||
|
|
a98e26139f | ||
|
|
522a3ea88b | ||
|
|
d7949fbc30 | ||
|
|
6df083a1d5 | ||
|
|
4dc80e7f6e | ||
|
|
c2a8508513 | ||
|
|
159193ef43 | ||
|
|
1f37ffb105 | ||
|
|
919fed05c5 | ||
|
|
1814f83bee | ||
|
|
1823840456 | ||
|
|
623c28bfc3 | ||
|
|
3079131337 | ||
|
|
a34ade0120 | ||
|
|
e9ada70088 | ||
|
|
597cc48248 | ||
|
|
ec3f857ef1 | ||
|
|
383b4de539 | ||
|
|
1bf9326604 | ||
|
|
d9f5459d46 | ||
|
|
e45a1b1e19 | ||
|
|
331ad8f644 | ||
|
|
52fa88b04c | ||
|
|
8895a64d24 | ||
|
|
fdec535559 | ||
|
|
6c5559ae2d | ||
|
|
9f54622b17 | ||
|
|
03b6f4b378 | ||
|
|
af4cbe2332 | ||
|
|
141f72963a | ||
|
|
3d3c66e12f | ||
|
|
ee84571bdb | ||
|
|
6500936aad | ||
|
|
32d2b6c013 | ||
|
|
05df40977d | ||
|
|
5d7a1dcde5 | ||
|
|
9c45d9db6c | ||
|
|
ca692ed0f2 | ||
|
|
af499565d3 | ||
|
|
fe2d7e3a9e | ||
|
|
9f69822221 | ||
|
|
bb43f047c2 | ||
|
|
2356662492 | ||
|
|
1624a45093 | ||
|
|
dcb9983786 | ||
|
|
83d1828905 | ||
|
|
6a281cf3ee | ||
|
|
ed1cd39a6c | ||
|
|
dda19b3920 | ||
|
|
25139ca922 | ||
|
|
3cd57a582c | ||
|
|
d3903ac655 | ||
|
|
199e374318 | ||
|
|
8375c1413d | ||
|
|
9e268cf016 | ||
|
|
112b3abc26 | ||
|
|
a8331a2357 | ||
|
|
52e3ad08c1 | ||
|
|
8d01d04ef0 | ||
|
|
a141384907 | ||
|
|
b8aa7184bd | ||
|
|
e4195f874d | ||
|
|
d04deff5ca | ||
|
|
20ce0778a0 | ||
|
|
5a0b3470f1 | ||
|
|
a920921570 | ||
|
|
286f4ff384 | ||
|
|
71ddfafa98 | ||
|
|
b7e3e53697 | ||
|
|
16df548b77 | ||
|
|
425c33ae00 | ||
|
|
c9289ed2dc | ||
|
|
96517cbdef | ||
|
|
b03420faac | ||
|
|
65a1aa7ca2 | ||
|
|
3a92e8eaf9 | ||
|
|
a8dc50d64a | ||
|
|
3397cc7d8d | ||
|
|
c3e8131b24 | ||
|
|
f8ca8584ae | ||
|
|
3050bbe260 | ||
|
|
e1dda2795a | ||
|
|
6d8408e626 | ||
|
|
0906271aa9 | ||
|
|
4c33c9d256 | ||
|
|
fa9c78209f | ||
|
|
6678ec8a60 | ||
|
|
854e467c12 | ||
|
|
e6b94c7b21 | ||
|
|
2c6f9d8602 | ||
|
|
c74033b9c0 | ||
|
|
d2b21d27bb | ||
|
|
215272469f | ||
|
|
f7d05ab0f1 | ||
|
|
6f2ad2be77 | ||
|
|
66575c719a | ||
|
|
677a239d53 | ||
|
|
3b96bfe5af | ||
|
|
83be5cfa64 | ||
|
|
6b834c2362 | ||
|
|
7abfc49e08 | ||
|
|
65d5f50088 | ||
|
|
4f1f4ffe3d | ||
|
|
b0c2027a1c | ||
|
|
33c83358b0 | ||
|
|
31223f0526 | ||
|
|
92daadb92c | ||
|
|
fae2e274fd | ||
|
|
342a722991 | ||
|
|
65ec6aacb7 | ||
|
|
9387470c69 | ||
|
|
31f6edf8f0 | ||
|
|
487b062175 | ||
|
|
d8e13de096 | ||
|
|
e8a30088ef | ||
|
|
bf7b07ba74 | ||
|
|
28fe3e7b7a | ||
|
|
c0eff2bb5e | ||
|
|
848c1741fe | ||
|
|
1370b8e8c1 | ||
|
|
82a068e610 | ||
|
|
32f42bafaa | ||
|
|
4081b7f022 | ||
|
|
a5808193a6 | ||
|
|
854ca322c1 | ||
|
|
c1d9b5137a | ||
|
|
f33d5745b3 | ||
|
|
d89c2ca128 | ||
|
|
835584cc85 | ||
|
|
b2ffbe3a68 | ||
|
|
defcc79e6c | ||
|
|
c06d9f84f0 | ||
|
|
fe57a8e156 | ||
|
|
b77105795a | ||
|
|
e2df5fcf27 | ||
|
|
836a64e728 | ||
|
|
08ba0c9f42 | ||
|
|
6fcc6a5299 | ||
|
|
6dd58248c6 | ||
|
|
2786801b71 | ||
|
|
ea29cbeb7a | ||
|
|
3cf9121a8c | ||
|
|
381bd3938a | ||
|
|
e4ce384023 | ||
|
|
12d1857b13 | ||
|
|
0d9003dea4 | ||
|
|
1a3751acfa | ||
|
|
c5a3af2399 | ||
|
|
ea8a64fafc | ||
|
|
981e367bf1 | ||
|
|
a3d6e62035 | ||
|
|
7f205cdcc8 | ||
|
|
e587189880 | ||
|
|
206c1bd69f | ||
|
|
a7d9255c2c | ||
|
|
08265a85ec | ||
|
|
1ed5630464 | ||
|
|
c784615f11 | ||
|
|
26d51b1190 | ||
|
|
d83fad6abc | ||
|
|
692796db46 | ||
|
|
f15c6f33f9 | ||
|
|
dda9eb4d7c | ||
|
|
6f3aeb61e7 | ||
|
|
d6145e633f | ||
|
|
07014d98ce | ||
|
|
e8ccdabe6c | ||
|
|
cf9fd2d5c2 | ||
|
|
bf9aa9356b | ||
|
|
68d00ce289 | ||
|
|
5288021e4f | ||
|
|
4d38add291 | ||
|
|
804808da4a | ||
|
|
298a95432d | ||
|
|
a834fc4b30 | ||
|
|
2c6c9542dd | ||
|
|
a9a7f4c8ec | ||
|
|
ea9370443d | ||
|
|
c2e00b240e | ||
|
|
a2b81ea099 | ||
|
|
ee609e8eac | ||
|
|
e04ef671e9 | ||
|
|
0184dfd7eb | ||
|
|
eccfa0ca54 | ||
|
|
6d3feb4bef | ||
|
|
29d2b5ee4b | ||
|
|
c82fabb67f | ||
|
|
fcfc868e57 | ||
|
|
67b403f8ca | ||
|
|
de06c6b2f6 | ||
|
|
fa444dfb8a | ||
|
|
124002a472 | ||
|
|
0c883433c1 | ||
|
|
bcf3b2cf55 | ||
|
|
357c4e9c08 | ||
|
|
9edfc68e91 | ||
|
|
8c06cb3e80 | ||
|
|
144fa0a6d4 | ||
|
|
25d5a1541e | ||
|
|
a579d36389 | ||
|
|
d766dac341 | ||
|
|
b15ef1bbc6 | ||
|
|
3e52e00597 | ||
|
|
f749dd0d52 | ||
|
|
48a8a42108 | ||
|
|
db7f57a5a4 | ||
|
|
556381b983 | ||
|
|
158d7d5898 | ||
|
|
18844da95d | ||
|
|
7e0df4d718 | ||
|
|
0dbb76e8c8 | ||
|
|
f73b3422a6 | ||
|
|
bd95e802ec | ||
|
|
5de16a78c5 | ||
|
|
6f8e09fcde | ||
|
|
f54d480f03 | ||
|
|
e68b213fb3 | ||
|
|
132334d500 | ||
|
|
a6f04c6d7e | ||
|
|
854e8bf356 | ||
|
|
6ff883d2d3 | ||
|
|
849b97afba | ||
|
|
1bd2635864 | ||
|
|
79ab0f7b6c | ||
|
|
79011bd257 | ||
|
|
c692713ffb | ||
|
|
df9b554ce1 | ||
|
|
277a8e4682 | ||
|
|
acb52dba09 | ||
|
|
8f10765254 | ||
|
|
0653f59473 | ||
|
|
7a4b5a4667 | ||
|
|
49c4a4068b | ||
|
|
40ad590046 | ||
|
|
30374ae3e6 | ||
|
|
ab22d16bad | ||
|
|
971cd56a4a | ||
|
|
d7cb546c5f | ||
|
|
9d8b7344cd | ||
|
|
2d4f6ae7ce | ||
|
|
d9126807b0 | ||
|
|
cad5fb3fba | ||
|
|
afe23ad6b7 | ||
|
|
fc4327087b | ||
|
|
71762d788f | ||
|
|
6472e00fb0 | ||
|
|
4043846767 | ||
|
|
d3b2bc962c | ||
|
|
54f7b64821 | ||
|
|
82a2a6e669 | ||
|
|
6376d60af5 | ||
|
|
b1e2e3831f | ||
|
|
5de1c8aa82 | ||
|
|
63dc5c2bdb | ||
|
|
7f2d1670a0 | ||
|
|
53c8c337fc | ||
|
|
5b4ec1b2a2 | ||
|
|
64dd2ed141 | ||
|
|
eb57e04e95 | ||
|
|
ae905c8630 | ||
|
|
c157e794f0 | ||
|
|
ed9bae6f6a | ||
|
|
9fe1ce19ad | ||
|
|
6148236cbd | ||
|
|
2471eb518a | ||
|
|
8931b41c76 | ||
|
|
7f523f167d | ||
|
|
446b6d6158 | ||
|
|
2ee057e19b | ||
|
|
afc810f21f | ||
|
|
357052a903 | ||
|
|
39d6d8d04a | ||
|
|
888896c0c0 | ||
|
|
ceee482ecc | ||
|
|
d0ed1213d8 | ||
|
|
f6ef428008 | ||
|
|
e726c4f442 | ||
|
|
402318e586 | ||
|
|
b198cc2a6e | ||
|
|
c3dd4da11b | ||
|
|
ba2e42b06e | ||
|
|
fa0902dc74 | ||
|
|
8fcb6083dc | ||
|
|
1ef88140e3 | ||
|
|
aa34c4c84c | ||
|
|
32d12bb334 | ||
|
|
1b2a02cb1a | ||
|
|
2ff11a16c4 | ||
|
|
441af82dbd | ||
|
|
e09c09af6f | ||
|
|
3721fe226f | ||
|
|
8ace0e11cf | ||
|
|
5e249b0b59 | ||
|
|
4889955ecf | ||
|
|
d840fd53da | ||
|
|
a61819cdb3 | ||
|
|
e986fbb5fb | ||
|
|
8f4d575ec8 | ||
|
|
605a06317b | ||
|
|
a7304ccf47 | ||
|
|
374e2bd4b9 | ||
|
|
09a3246ddb | ||
|
|
a615603866 | ||
|
|
1ca05808e1 | ||
|
|
5febc2a805 | ||
|
|
3c047bee58 | ||
|
|
022c6c157a | ||
|
|
fa587d5678 | ||
|
|
afa5a42f5a | ||
|
|
71df8ba3e2 | ||
|
|
8764998e8c | ||
|
|
2cb4f3aac8 | ||
|
|
1ccaf33aac | ||
|
|
cb0a8e0413 | ||
|
|
8674168df4 | ||
|
|
2221653801 | ||
|
|
78bcdcef5d | ||
|
|
672fbe2ac0 | ||
|
|
56a5970b44 | ||
|
|
a66cef7cfe | ||
|
|
c0b1c2e099 | ||
|
|
9e553bb87b | ||
|
|
f966514bc7 | ||
|
|
dc0a49f96d | ||
|
|
65c783c024 | ||
|
|
6395836fbb | ||
|
|
a7207084ef | ||
|
|
27ef1f1e71 | ||
|
|
68fdb14cd6 | ||
|
|
c2af282a85 | ||
|
|
92d48335cb | ||
|
|
78cac2edc2 | ||
|
|
26d105c439 | ||
|
|
7fec107b98 | ||
|
|
eb01ad3af9 | ||
|
|
e0d9880b32 | ||
|
|
e81e96f0ab | ||
|
|
06d5bd259c | ||
|
|
14238b8d62 | ||
|
|
3b51886927 | ||
|
|
a295ff2e06 | ||
|
|
18cdaabf5e | ||
|
|
787e37b7c6 | ||
|
|
4e5c8b2dd0 | ||
|
|
d8ddacde38 | ||
|
|
bb1e42f0d3 | ||
|
|
923669c495 | ||
|
|
7a4139544c | ||
|
|
4d6ea0236b | ||
|
|
e872a06f22 | ||
|
|
647bda2160 | ||
|
|
c1e93d23f3 | ||
|
|
c96550cc68 | ||
|
|
b1015ecdc5 | ||
|
|
f1b928a037 | ||
|
|
16c312c90b | ||
|
|
110ffd0118 | ||
|
|
35ad872419 | ||
|
|
9b943cf2b8 | ||
|
|
9d1b357e64 | ||
|
|
9fc2fb4d17 | ||
|
|
641fa8a3d9 | ||
|
|
add9269706 | ||
|
|
1a01c4a344 | ||
|
|
b4e7feed06 | ||
|
|
4b96c650eb | ||
|
|
107aef3785 | ||
|
|
b49807824f | ||
|
|
e5ef2ef8b5 | ||
|
|
88779ed56c | ||
|
|
8b59fb6adc | ||
|
|
7945647b0b | ||
|
|
2d39b84806 | ||
|
|
e151a19fcf | ||
|
|
99d2ba26b9 | ||
|
|
396924f4cc | ||
|
|
7545312229 | ||
|
|
26f9779fbf | ||
|
|
0bd62eef3a | ||
|
|
e06d15f508 | ||
|
|
aa1ee96bc9 | ||
|
|
355c73512d | ||
|
|
0daf9d92ff | ||
|
|
37de26ce25 | ||
|
|
0eaef7e7a0 | ||
|
|
8063cee3cd | ||
|
|
cbb25b4ac0 | ||
|
|
c62206a157 | ||
|
|
09832141d0 | ||
|
|
bf8e121a10 | ||
|
|
68568073ec | ||
|
|
ec36524c35 | ||
|
|
67acd9fd2c | ||
|
|
f7be5c8d25 | ||
|
|
ceacac75e0 | ||
|
|
bae66f94e8 | ||
|
|
ddf132bd78 | ||
|
|
afb012029f | ||
|
|
651e14c8c3 | ||
|
|
e7c626eb5f | ||
|
|
a0b0d40a19 | ||
|
|
42e3ab9e27 | ||
|
|
6e5f333364 | ||
|
|
f33a9abe60 | ||
|
|
7f1bbdd615 | ||
|
|
d3bf8eaceb | ||
|
|
b9c9d602de | ||
|
|
b25fbd6e24 | ||
|
|
6052608a4e | ||
|
|
a073b82751 | ||
|
|
8250acdfb5 | ||
|
|
8e1f73a34e | ||
|
|
50704bc882 | ||
|
|
35d34e3513 | ||
|
|
ea834f3de6 | ||
|
|
11aedde72f | ||
|
|
488654abc8 | ||
|
|
da1be0dc65 | ||
|
|
d0c728a339 | ||
|
|
66c66c4d9b | ||
|
|
4882721387 | ||
|
|
06a8850c0c | ||
|
|
370aa06c67 | ||
|
|
c9fa0564e7 | ||
|
|
2ba7a0ceba | ||
|
|
276aedfbb9 | ||
|
|
c193c75674 | ||
|
|
a562ba3746 | ||
|
|
2fedd572ff | ||
|
|
db0b49c427 | ||
|
|
03a6f8111c | ||
|
|
925ad7b3e0 | ||
|
|
bf793d5b8b | ||
|
|
64a906ca5e | ||
|
|
99b36442bb | ||
|
|
3c5164d510 | ||
|
|
ec4b5a4d45 | ||
|
|
78e1901779 | ||
|
|
cb539314de | ||
|
|
c7627fe0de | ||
|
|
84bfad7ce5 | ||
|
|
3e06938b05 | ||
|
|
4f712fec14 | ||
|
|
c5c9659c76 | ||
|
|
d6e175c1f1 | ||
|
|
88088e1071 | ||
|
|
958ddbca86 | ||
|
|
6670fd28f4 | ||
|
|
1e59c31de3 | ||
|
|
c966dbbbbc | ||
|
|
af8f5ba04e | ||
|
|
b741ed0b3b | ||
|
|
01ba3c14f8 | ||
|
|
d13b1a83ad | ||
|
|
303477db70 | ||
|
|
311e89e9e7 | ||
|
|
8546cfe714 | ||
|
|
e6f4d84b9a | ||
|
|
ce7e422169 | ||
|
|
e5aec80984 | ||
|
|
6d97817390 | ||
|
|
d516f22159 | ||
|
|
e918c18ca2 | ||
|
|
5dd8d905fa | ||
|
|
1121d1ee6c | ||
|
|
4793f096af | ||
|
|
7b5b4ce082 | ||
|
|
fa08c9c3e4 | ||
|
|
d0d5eb956a | ||
|
|
969f949330 | ||
|
|
9169bbd04d | ||
|
|
99463ad01c | ||
|
|
f1d6b0feda | ||
|
|
e33da50278 | ||
|
|
4034eb3221 | ||
|
|
75a95f0109 | ||
|
|
92fdc16fe6 | ||
|
|
23fa2995c8 | ||
|
|
59aefdff77 | ||
|
|
e92ab9e3cc | ||
|
|
e3bf1f763c | ||
|
|
1c6e9d0b69 | ||
|
|
bfd4eb3e11 | ||
|
|
c9f902a8af | ||
|
|
0b67510ec9 | ||
|
|
b5cd320e8b | ||
|
|
deb25b4987 | ||
|
|
4612da264a | ||
|
|
59b67e1e10 | ||
|
|
5fad936b27 | ||
|
|
e376a45dea | ||
|
|
fd593bb61d | ||
|
|
71b97d5974 | ||
|
|
2b405ae164 | ||
|
|
2fe4736b69 | ||
|
|
184f8ca6cf | ||
|
|
1ff2019dde | ||
|
|
a3d8261686 | ||
|
|
7d0600976e | ||
|
|
e1e6e4f3dc | ||
|
|
fba2853773 | ||
|
|
48df7e1078 | ||
|
|
235dcd5fa6 | ||
|
|
2027db7411 | ||
|
|
611dd33c75 | ||
|
|
ec1c92a714 | ||
|
|
6ac78156ac | ||
|
|
e94b74e92d | ||
|
|
2bbec47f63 | ||
|
|
b5ddf4c953 | ||
|
|
44be75aeef | ||
|
|
2c03759b5d | ||
|
|
2e3da03723 | ||
|
|
6e96fbcda7 | ||
|
|
d1fd5b7f27 | ||
|
|
9dbcc105e7 | ||
|
|
5cd5a82ddc | ||
|
|
88c1892dc9 | ||
|
|
3c1b181675 | ||
|
|
6777dc16ca | ||
|
|
3833647dfe | ||
|
|
b6c47f0cce | ||
|
|
d308c7ac60 | ||
|
|
947c757aa5 | ||
|
|
5ee5bd7d36 | ||
|
|
d9c4ae92cd | ||
|
|
e1efff19f0 | ||
|
|
61f723a1f5 | ||
|
|
b32756932b | ||
|
|
cb5e64d26b | ||
|
|
f36febf10a | ||
|
|
26d9a9caa6 | ||
|
|
cb876cf77e | ||
|
|
4789711910 | ||
|
|
4064980505 | ||
|
|
f9b8f2d22c | ||
|
|
6a95aadc53 | ||
|
|
f9f08f082d | ||
|
|
0817901bef | ||
|
|
ac22172e53 | ||
|
|
fd87fbf31e | ||
|
|
554be0908f | ||
|
|
eaec4e5f13 | ||
|
|
0e7ba27a7d | ||
|
|
c551f5c23b | ||
|
|
5159657ae5 | ||
|
|
d35db7df72 | ||
|
|
2b5399c559 | ||
|
|
9e61bbbd8e | ||
|
|
7ce5857cd5 | ||
|
|
38fbae99fd | ||
|
|
b0a9d44b0c | ||
|
|
b4e22cd375 | ||
|
|
9bc92736a7 | ||
|
|
111b34d05c | ||
|
|
07d9599a2f | ||
|
|
d8194f211d | ||
|
|
51a6374c33 | ||
|
|
aa6c6035b6 | ||
|
|
44b4a7ffbb | ||
|
|
e5bb018d22 | ||
|
|
79b8a6536e | ||
|
|
3de31cd06a | ||
|
|
c579b54d40 | ||
|
|
0a52575e8b | ||
|
|
23c9a98f66 | ||
|
|
796fc33b5b | ||
|
|
dc4c11ddd2 | ||
|
|
d389e4d5d4 | ||
|
|
8cb78ad931 | ||
|
|
85f987d15c | ||
|
|
b12079e0f6 | ||
|
|
dcf5c6167a | ||
|
|
b395d3f487 | ||
|
|
37662cad10 | ||
|
|
aa1673063d | ||
|
|
f51f49eb60 | ||
|
|
54c9bac961 | ||
|
|
e70fd73bdd | ||
|
|
9bb9e7b64d | ||
|
|
f64c03543a | ||
|
|
51374de1a1 | ||
|
|
afcc12f263 | ||
|
|
88c5482366 | ||
|
|
bbf7295c32 | ||
|
|
ca5e23e68c | ||
|
|
eadb1487ae | ||
|
|
1faa70fc77 | ||
|
|
30d7c007de | ||
|
|
f54f6a4402 | ||
|
|
7b41cdec65 | ||
|
|
fb6a652a57 | ||
|
|
ea34d753c1 | ||
|
|
2bc46e708e | ||
|
|
96e3b5b7b3 | ||
|
|
fafbafa5e1 | ||
|
|
be8605d8c6 | ||
|
|
061660d47a | ||
|
|
2ed6dbb344 | ||
|
|
4766b45746 | ||
|
|
0734252e98 | ||
|
|
91b4827c1d | ||
|
|
df6d56ce66 | ||
|
|
f0203c96ab | ||
|
|
bccabe40c0 | ||
|
|
c2f599b4ff | ||
|
|
5fd069d70d | ||
|
|
32d34d1748 | ||
|
|
18eb605605 | ||
|
|
4fdc88e9e1 | ||
|
|
4c69d8d3a8 | ||
|
|
d4b2dd0ec1 | ||
|
|
181f78421b | ||
|
|
8ed38527d0 | ||
|
|
c4c926070d | ||
|
|
ed87411e0d | ||
|
|
4ec2a448ab | ||
|
|
73d01da94e | ||
|
|
df8e02157a | ||
|
|
6e513ed32a | ||
|
|
325ef6327d | ||
|
|
46700e5ad0 | ||
|
|
d1e21fa345 |
5
.github/FUNDING.yml
vendored
Normal file
5
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
ko_fi: pixelpawsai
|
||||
patreon: PixelPawsAI
|
||||
custom: ['paypal.me/pixelpawsai', 'https://afdian.com/a/pixelpawsai']
|
||||
1
.github/copilot-instructions.md
vendored
Normal file
1
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1 @@
|
||||
Always use English for comments.
|
||||
69
.github/workflows/backend-tests.yml
vendored
Normal file
69
.github/workflows/backend-tests.yml
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
name: Backend Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- master
|
||||
paths:
|
||||
- 'py/**'
|
||||
- 'standalone.py'
|
||||
- 'tests/**'
|
||||
- 'requirements.txt'
|
||||
- 'requirements-dev.txt'
|
||||
- 'pyproject.toml'
|
||||
- 'pytest.ini'
|
||||
- '.github/workflows/backend-tests.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'py/**'
|
||||
- 'standalone.py'
|
||||
- 'tests/**'
|
||||
- 'requirements.txt'
|
||||
- 'requirements-dev.txt'
|
||||
- 'pyproject.toml'
|
||||
- 'pytest.ini'
|
||||
- '.github/workflows/backend-tests.yml'
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
name: Run pytest with coverage
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python 3.11
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
requirements.txt
|
||||
requirements-dev.txt
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements-dev.txt
|
||||
|
||||
- name: Run pytest with coverage
|
||||
env:
|
||||
COVERAGE_FILE: coverage/backend/.coverage
|
||||
run: |
|
||||
mkdir -p coverage/backend
|
||||
python -m pytest \
|
||||
--cov=py \
|
||||
--cov=standalone \
|
||||
--cov-report=term-missing \
|
||||
--cov-report=xml:coverage/backend/coverage.xml \
|
||||
--cov-report=html:coverage/backend/html \
|
||||
--cov-report=json:coverage/backend/coverage.json
|
||||
|
||||
- name: Upload coverage artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: backend-coverage
|
||||
path: coverage/backend
|
||||
if-no-files-found: warn
|
||||
52
.github/workflows/frontend-tests.yml
vendored
Normal file
52
.github/workflows/frontend-tests.yml
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
name: Frontend Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- master
|
||||
paths:
|
||||
- 'package.json'
|
||||
- 'package-lock.json'
|
||||
- 'vitest.config.js'
|
||||
- 'tests/frontend/**'
|
||||
- 'static/js/**'
|
||||
- 'scripts/run_frontend_coverage.js'
|
||||
- '.github/workflows/frontend-tests.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'package.json'
|
||||
- 'package-lock.json'
|
||||
- 'vitest.config.js'
|
||||
- 'tests/frontend/**'
|
||||
- 'static/js/**'
|
||||
- 'scripts/run_frontend_coverage.js'
|
||||
- '.github/workflows/frontend-tests.yml'
|
||||
|
||||
jobs:
|
||||
vitest:
|
||||
name: Run Vitest with coverage
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Use Node.js 20
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 20
|
||||
cache: 'npm'
|
||||
|
||||
- name: Install dependencies
|
||||
run: npm ci
|
||||
|
||||
- name: Run frontend tests with coverage
|
||||
run: npm run test:coverage
|
||||
|
||||
- name: Upload coverage artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: frontend-coverage
|
||||
path: coverage/frontend
|
||||
if-no-files-found: warn
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -1,5 +1,12 @@
|
||||
__pycache__/
|
||||
settings.json
|
||||
path_mappings.yaml
|
||||
output/*
|
||||
py/run_test.py
|
||||
.vscode/
|
||||
cache/
|
||||
civitai/
|
||||
node_modules/
|
||||
coverage/
|
||||
.coverage
|
||||
model_cache/
|
||||
|
||||
22
AGENTS.md
Normal file
22
AGENTS.md
Normal file
@@ -0,0 +1,22 @@
|
||||
# Repository Guidelines
|
||||
|
||||
## Project Structure & Module Organization
|
||||
ComfyUI LoRA Manager pairs a Python backend with browser-side widgets. Backend modules live in <code>py/</code> with HTTP entry points in <code>py/routes/</code>, feature logic in <code>py/services/</code>, shared helpers in <code>py/utils/</code>, and custom nodes in <code>py/nodes/</code>. UI scripts extend ComfyUI from <code>web/comfyui/</code>, while deploy-ready assets remain in <code>static/</code> and <code>templates/</code>. Localization files live in <code>locales/</code>, example workflows in <code>example_workflows/</code>, and interim tests such as <code>test_i18n.py</code> sit beside their source until a dedicated <code>tests/</code> tree lands.
|
||||
|
||||
## Build, Test, and Development Commands
|
||||
- <code>pip install -r requirements.txt</code> installs backend dependencies.
|
||||
- <code>python standalone.py --port 8188</code> launches the standalone server for iterative development.
|
||||
- <code>python -m pytest test_i18n.py</code> runs the current regression suite; target new files explicitly, e.g. <code>python -m pytest tests/test_recipes.py</code>.
|
||||
- <code>python scripts/sync_translation_keys.py</code> synchronizes locale keys after UI string updates.
|
||||
|
||||
## Coding Style & Naming Conventions
|
||||
Follow PEP 8 with four-space indentation and descriptive snake_case file and function names such as <code>settings_manager.py</code>. Classes stay PascalCase, constants in UPPER_SNAKE_CASE, and loggers retrieved via <code>logging.getLogger(__name__)</code>. Prefer explicit type hints and docstrings on public APIs. JavaScript under <code>web/comfyui/</code> uses ES modules with camelCase helpers and the <code>_widget.js</code> suffix for UI components.
|
||||
|
||||
## Testing Guidelines
|
||||
Pytest powers backend tests. Name modules <code>test_<feature>.py</code> and keep them near the code or in a future <code>tests/</code> package. Mock ComfyUI dependencies through helpers in <code>standalone.py</code>, keep filesystem fixtures deterministic, and ensure translations are covered. Run <code>python -m pytest</code> before submitting changes.
|
||||
|
||||
## Commit & Pull Request Guidelines
|
||||
Commits follow the conventional format, e.g. <code>feat(settings): add default model path</code>, and should stay focused on a single concern. Pull requests must outline the problem, summarize the solution, list manual verification steps (server run, targeted pytest), and link related issues. Include screenshots or GIFs for UI or locale updates and call out migration steps such as <code>settings.json</code> adjustments.
|
||||
|
||||
## Configuration & Localization Tips
|
||||
Copy <code>settings.json.example</code> to <code>settings.json</code> and adapt model directories before running the standalone server. Store reference assets in <code>civitai/</code> or <code>docs/</code> to keep runtime directories deploy-ready. Whenever UI text changes, update every <code>locales/<lang>.json</code> file and rerun the translation sync script so ComfyUI surfaces localized strings.
|
||||
103
IFLOW.md
Normal file
103
IFLOW.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# ComfyUI LoRA Manager - iFlow 上下文
|
||||
|
||||
## 项目概述
|
||||
|
||||
ComfyUI LoRA Manager 是一个全面的工具集,用于简化 ComfyUI 中 LoRA 模型的组织、下载和应用。它提供了强大的功能,如配方管理、检查点组织和一键工作流集成,使模型操作更快、更流畅、更简单。
|
||||
|
||||
该项目是一个 Python 后端与 JavaScript 前端结合的 Web 应用程序,既可以作为 ComfyUI 的自定义节点运行,也可以作为独立应用程序运行。
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
D:\Workspace\ComfyUI\custom_nodes\ComfyUI-Lora-Manager\
|
||||
├── py/ # Python 后端代码
|
||||
│ ├── config.py # 全局配置
|
||||
│ ├── lora_manager.py # 主入口点
|
||||
│ ├── controllers/ # 控制器
|
||||
│ ├── metadata_collector/ # 元数据收集器
|
||||
│ ├── middleware/ # 中间件
|
||||
│ ├── nodes/ # ComfyUI 节点
|
||||
│ ├── recipes/ # 配方相关
|
||||
│ ├── routes/ # API 路由
|
||||
│ ├── services/ # 业务逻辑服务
|
||||
│ ├── utils/ # 工具函数
|
||||
│ └── validators/ # 验证器
|
||||
├── static/ # 静态资源 (CSS, JS, 图片)
|
||||
├── templates/ # HTML 模板
|
||||
├── locales/ # 国际化文件
|
||||
├── tests/ # 测试代码
|
||||
├── standalone.py # 独立模式入口
|
||||
├── requirements.txt # Python 依赖
|
||||
├── package.json # Node.js 依赖和脚本
|
||||
└── README.md # 项目说明
|
||||
```
|
||||
|
||||
## 核心组件
|
||||
|
||||
### 后端 (Python)
|
||||
|
||||
- **主入口**: `py/lora_manager.py` 和 `standalone.py`
|
||||
- **配置**: `py/config.py` 管理全局配置和路径
|
||||
- **路由**: `py/routes/` 目录下包含各种 API 路由
|
||||
- **服务**: `py/services/` 目录下包含业务逻辑,如模型扫描、下载管理等
|
||||
- **模型管理**: 使用 `ModelServiceFactory` 来管理不同类型的模型 (LoRA, Checkpoint, Embedding)
|
||||
|
||||
### 前端 (JavaScript)
|
||||
|
||||
- **构建工具**: 使用 Node.js 和 npm 进行依赖管理和测试
|
||||
- **测试**: 使用 Vitest 进行前端测试
|
||||
|
||||
## 构建和运行
|
||||
|
||||
### 安装依赖
|
||||
|
||||
```bash
|
||||
# Python 依赖
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Node.js 依赖 (用于测试)
|
||||
npm install
|
||||
```
|
||||
|
||||
### 运行 (ComfyUI 模式)
|
||||
|
||||
作为 ComfyUI 的自定义节点安装后,在 ComfyUI 中启动即可。
|
||||
|
||||
### 运行 (独立模式)
|
||||
|
||||
```bash
|
||||
# 使用默认配置运行
|
||||
python standalone.py
|
||||
|
||||
# 指定主机和端口
|
||||
python standalone.py --host 127.0.0.1 --port 9000
|
||||
```
|
||||
|
||||
### 测试
|
||||
|
||||
#### 后端测试
|
||||
|
||||
```bash
|
||||
# 安装开发依赖
|
||||
pip install -r requirements-dev.txt
|
||||
|
||||
# 运行测试
|
||||
pytest
|
||||
```
|
||||
|
||||
#### 前端测试
|
||||
|
||||
```bash
|
||||
# 运行测试
|
||||
npm run test
|
||||
|
||||
# 运行测试并生成覆盖率报告
|
||||
npm run test:coverage
|
||||
```
|
||||
|
||||
## 开发约定
|
||||
|
||||
- **代码风格**: Python 代码应遵循 PEP 8 规范
|
||||
- **测试**: 新功能应包含相应的单元测试
|
||||
- **配置**: 使用 `settings.json` 文件进行用户配置
|
||||
- **日志**: 使用 Python 标准库 `logging` 模块进行日志记录
|
||||
687
LICENSE
687
LICENSE
@@ -1,21 +1,674 @@
|
||||
MIT License
|
||||
GNU GENERAL PUBLIC LICENSE
|
||||
Version 3, 29 June 2007
|
||||
|
||||
Copyright (c) 2023 Will Miao
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
Preamble
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
The GNU General Public License is a free, copyleft license for
|
||||
software and other kinds of works.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
the GNU General Public License is intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users. We, the Free Software Foundation, use the
|
||||
GNU General Public License for most of our software; it applies also to
|
||||
any other work released this way by its authors. You can apply it to
|
||||
your programs, too.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
To protect your rights, we need to prevent others from denying you
|
||||
these rights or asking you to surrender the rights. Therefore, you have
|
||||
certain responsibilities if you distribute copies of the software, or if
|
||||
you modify it: responsibilities to respect the freedom of others.
|
||||
|
||||
For example, if you distribute copies of such a program, whether
|
||||
gratis or for a fee, you must pass on to the recipients the same
|
||||
freedoms that you received. You must make sure that they, too, receive
|
||||
or can get the source code. And you must show them these terms so they
|
||||
know their rights.
|
||||
|
||||
Developers that use the GNU GPL protect your rights with two steps:
|
||||
(1) assert copyright on the software, and (2) offer you this License
|
||||
giving you legal permission to copy, distribute and/or modify it.
|
||||
|
||||
For the developers' and authors' protection, the GPL clearly explains
|
||||
that there is no warranty for this free software. For both users' and
|
||||
authors' sake, the GPL requires that modified versions be marked as
|
||||
changed, so that their problems will not be attributed erroneously to
|
||||
authors of previous versions.
|
||||
|
||||
Some devices are designed to deny users access to install or run
|
||||
modified versions of the software inside them, although the manufacturer
|
||||
can do so. This is fundamentally incompatible with the aim of
|
||||
protecting users' freedom to change the software. The systematic
|
||||
pattern of such abuse occurs in the area of products for individuals to
|
||||
use, which is precisely where it is most unacceptable. Therefore, we
|
||||
have designed this version of the GPL to prohibit the practice for those
|
||||
products. If such problems arise substantially in other domains, we
|
||||
stand ready to extend this provision to those domains in future versions
|
||||
of the GPL, as needed to protect the freedom of users.
|
||||
|
||||
Finally, every program is threatened constantly by software patents.
|
||||
States should not allow patents to restrict development and use of
|
||||
software on general-purpose computers, but in those that do, we wish to
|
||||
avoid the special danger that patents applied to a free program could
|
||||
make it effectively proprietary. To prevent this, the GPL assures that
|
||||
patents cannot be used to render the program non-free.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Use with the GNU Affero General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU Affero General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the special requirements of the GNU Affero General Public License,
|
||||
section 13, concerning interaction through a network will apply to the
|
||||
combination as such.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU General Public License from time to time. Such new versions will
|
||||
be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
ComfyUI Lora Manager - A ComfyUI custom node for managing models
|
||||
Copyright (C) 2025 Will Miao
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If the program does terminal interaction, make it output a short
|
||||
notice like this when it starts in an interactive mode:
|
||||
|
||||
ComfyUI Lora Manager Copyright (C) 2025 Will Miao
|
||||
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
||||
This is free software, and you are welcome to redistribute it
|
||||
under certain conditions; type `show c' for details.
|
||||
|
||||
The hypothetical commands `show w' and `show c' should show the appropriate
|
||||
parts of the General Public License. Of course, your program's commands
|
||||
might be different; for a GUI interface, you would use an "about box".
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU GPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
|
||||
The GNU General Public License does not permit incorporating your program
|
||||
into proprietary programs. If your program is a subroutine library, you
|
||||
may consider it more useful to permit linking proprietary applications with
|
||||
the library. If this is what you want to do, use the GNU Lesser General
|
||||
Public License instead of this License. But first, please read
|
||||
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
||||
237
README.md
237
README.md
@@ -13,57 +13,72 @@ A comprehensive toolset that streamlines organizing, downloading, and applying L
|
||||
## 📺 Tutorial: One-Click LoRA Integration
|
||||
Watch this quick tutorial to learn how to use the new one-click LoRA integration feature:
|
||||
|
||||
[](https://youtu.be/qS95OjX3e70)
|
||||
[](https://youtu.be/noN7f_ER7yo)
|
||||
[](https://youtu.be/hvKw31YpE-U)
|
||||
|
||||
## 🌐 Browser Extension
|
||||
Enhance your Civitai browsing experience with our companion browser extension! See which models you already have, download new ones with a single click, and manage your downloads efficiently.
|
||||
|
||||

|
||||
|
||||
<div>
|
||||
<a href="https://chromewebstore.google.com/detail/lm-civitai-extension/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb" style="display: inline-block; background-color: #4285F4; color: white; padding: 8px 16px; text-decoration: none; border-radius: 4px; font-weight: bold; margin: 10px 0;">
|
||||
<img src="https://www.google.com/chrome/static/images/chrome-logo.svg" width="20" style="vertical-align: middle; margin-right: 8px;"> Get Extension from Chrome Web Store
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div id="firefox-install" class="install-ok"><a href="https://github.com/willmiao/lm-civitai-extension-firefox/releases/latest/download/extension.xpi">📦 Install Firefox Extension (reviewed and verified by Mozilla)</a></div>
|
||||
|
||||
📚 [Learn More: Complete Tutorial](https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/LoRA-Manager-Civitai-Extension-(Chrome-Extension))
|
||||
|
||||
---
|
||||
|
||||
## Release Notes
|
||||
|
||||
### v0.8.6 Major Update
|
||||
* **Checkpoint Management** - Added comprehensive management for model checkpoints including scanning, searching, filtering, and deletion
|
||||
* **Enhanced Metadata Support** - New capabilities for retrieving and managing checkpoint metadata with improved operations
|
||||
* **Improved Initial Loading** - Optimized cache initialization with visual progress indicators for better user experience
|
||||
### v0.9.10
|
||||
* **Smarter Update Matching** - Users can now choose to check and group updates by matching base model only or with no base-model constraint; version lists also support toggling between same-base versions or all versions.
|
||||
* **Flexible Tag Filtering** - The filter panel now supports tag exclusion: click a tag to include, click again to exclude, and click a third time to clear, enabling stronger and more flexible tag filters.
|
||||
* **License Visibility & Controls** - Model detail headers and ComfyUI preview popups now show Civitai license icons. The filter panel gains license include/exclude options, and a new global context menu action, "Refresh license metadata," fetches missing license data.
|
||||
* **Recipe Improvements** - Recipes now allow importing with zero LoRAs, and recipe detail pages show the related checkpoint for easier reference.
|
||||
* **Better ZIP Downloads** - When downloading models packaged in ZIPs, model files are extracted into the target model folder; ZIPs containing multiple model files (e.g., WanVideo high/low LoRA pairs) are added as separate models.
|
||||
* **Template Workflow Update** - Refreshed the "Illustrious Pony Example" template workflow with usage guidance for each LoRA Manager node.
|
||||
* **Bug Fixes & Stability** - General fixes and stability improvements.
|
||||
|
||||
### v0.8.5
|
||||
* **Enhanced LoRA & Recipe Connectivity** - Added Recipes tab in LoRA details to see all recipes using a specific LoRA
|
||||
* **Improved Navigation** - New shortcuts to jump between related LoRAs and Recipes with one-click navigation
|
||||
* **Video Preview Controls** - Added "Autoplay Videos on Hover" setting to optimize performance and reduce resource usage
|
||||
* **UI Experience Refinements** - Smoother transitions between related content pages
|
||||
### v0.9.9
|
||||
* **Check for Updates Feature** - Users can now check for updates for all models or selected models in bulk mode. Models with available updates will display an "update available" badge on their model card, and users can filter to show only models with updates.
|
||||
* **Model Versions Management** - Added a new Versions tab in the model modal that centralizes all versions of a model, providing download, delete, and ignore update functions.
|
||||
* **Send Checkpoint to ComfyUI** - Users can now click the send button on a checkpoint card to send the checkpoint directly to the current workflow's checkpoint or diffusion model loader node in ComfyUI.
|
||||
* **Customizable Model Card Display** - Added a new setting that allows users to choose whether to display the model name or filename on model cards.
|
||||
* **New Path Template Placeholders** - Added new path template placeholders: `{model_name}` and `{version_name}` for more flexible organization.
|
||||
* **ComfyUI Auto Path Correction Setting** - Added a new setting within ComfyUI to enable or disable the auto path correction feature.
|
||||
|
||||
### v0.8.4
|
||||
* **Node Layout Improvements** - Fixed layout issues with LoRA Loader and Trigger Words Toggle nodes in newer ComfyUI frontend versions
|
||||
* **Recipe LoRA Reconnection** - Added ability to reconnect deleted LoRAs in recipes by clicking the "deleted" badge in recipe details
|
||||
* **Bug Fixes & Stability** - Resolved various issues for improved reliability
|
||||
### v0.9.8
|
||||
* **Full CivArchive API Support** - Added complete support for the CivArchive API as a fallback metadata source beyond Civitai API. Models deleted from Civitai can now still retrieve metadata through the CivArchive API.
|
||||
* **Download Models from CivArchive** - Added support for downloading models directly from CivArchive, similar to downloading from Civitai. Simply click the Download button and paste the model URL to download the corresponding model.
|
||||
* **Custom Priority Tags** - Introduced Custom Priority Tags feature, allowing users to define custom priority tags. These tags will appear as suggestions when editing tags or during auto organization/download using default paths, providing more precise and controlled folder organization. [Guide](https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/Priority-Tags-Configuration-Guide)
|
||||
* **Drag and Drop Tag Reordering** - Added drag and drop functionality to reorder tags in the tags edit mode for improved usability.
|
||||
* **Download Control in Example Images Panel** - Added stop control in the Download Example Images Panel for better download management.
|
||||
* **Prompt (LoraManager) Node with Autocomplete** - Added new Prompt (LoraManager) node with autocomplete feature for adding embeddings.
|
||||
* **Lora Manager Nodes in Subgraphs** - Lora Manager nodes now support being placed within subgraphs for more flexible workflow organization.
|
||||
|
||||
### v0.8.3
|
||||
* **Enhanced Workflow Parser** - Rebuilt workflow analysis engine with improved support for ComfyUI core nodes and easier extensibility
|
||||
* **Improved Recipe System** - Refined the experimental Save Recipe functionality with better workflow integration
|
||||
* **New Save Image Node** - Added experimental node with metadata support for perfect CivitAI compatibility
|
||||
* Supports dynamic filename prefixes with variables [1](https://github.com/nkchocoai/ComfyUI-SaveImageWithMetaData?tab=readme-ov-file#filename_prefix)
|
||||
* **Default LoRA Root Setting** - Added configuration option for setting your preferred LoRA directory
|
||||
### v0.9.6
|
||||
* **Metadata Archive Database Support** - Added the ability to download and utilize a metadata archive database, enabling access to metadata for models that have been deleted from CivitAI.
|
||||
* **App-Level Proxy Settings** - Introduced support for configuring a global proxy within the application, making it easier to use the manager behind network restrictions.
|
||||
* **Bug Fixes** - Various bug fixes for improved stability and reliability.
|
||||
|
||||
### v0.8.2
|
||||
* **Faster Initialization for Forge Users** - Improved first-run efficiency by utilizing existing `.json` and `.civitai.info` files from Forge’s CivitAI helper extension, making migration smoother.
|
||||
* **LoRA Filename Editing** - Added support for renaming LoRA files directly within LoRA Manager.
|
||||
* **Recipe Editing** - Users can now edit recipe names and tags.
|
||||
* **Retain Deleted LoRAs in Recipes** - Deleted LoRAs will remain listed in recipes, allowing future functionality to reconnect them once re-obtained.
|
||||
* **Download Missing LoRAs from Recipes** - Easily fetch missing LoRAs associated with a recipe.
|
||||
### v0.9.2
|
||||
* **Bulk Auto-Organization Action** - Added a new bulk auto-organization feature. You can now select multiple models and automatically organize them according to your current path template settings for streamlined management.
|
||||
* **Bug Fixes** - Addressed several bugs to improve stability and reliability.
|
||||
|
||||
### v0.8.1
|
||||
* **Base Model Correction** - Added support for modifying base model associations to fix incorrect metadata for non-CivitAI LoRAs
|
||||
* **LoRA Loader Flexibility** - Made CLIP input optional for model-only workflows like Hunyuan video generation
|
||||
* **Expanded Recipe Support** - Added compatibility with 3 additional recipe metadata formats
|
||||
* **Enhanced Showcase Images** - Generation parameters now displayed alongside LoRA preview images
|
||||
* **UI Improvements & Bug Fixes** - Various interface refinements and stability enhancements
|
||||
### v0.9.1
|
||||
* **Enhanced Bulk Operations** - Improved bulk operations with Marquee Selection and a bulk operation context menu, providing a more intuitive, desktop-application-like user experience.
|
||||
* **New Bulk Actions** - Added bulk operations for adding tags and setting base models to multiple models simultaneously.
|
||||
|
||||
### v0.8.0
|
||||
* **Introduced LoRA Recipes** - Create, import, save, and share your favorite LoRA combinations
|
||||
* **Recipe Management System** - Easily browse, search, and organize your LoRA recipes
|
||||
* **Workflow Integration** - Save recipes directly from your workflow with generation parameters preserved
|
||||
* **Simplified Workflow Application** - Quickly apply saved recipes to new projects
|
||||
* **Enhanced UI & UX** - Improved interface design and user experience
|
||||
* **Bug Fixes & Stability** - Resolved various issues and enhanced overall performance
|
||||
### v0.9.0
|
||||
* **UI Overhaul for Enhanced Navigation** - Replaced the top flat folder tags with a new folder sidebar and breadcrumb navigation system for a more intuitive folder browsing and selection experience.
|
||||
* **Dual-Mode Folder Sidebar** - The new folder sidebar offers two display modes: 'List Mode,' which mirrors the classic folder view, and 'Tree Mode,' which presents a hierarchical folder structure for effortless navigation through nested directories.
|
||||
* **Internationalization Support** - Introduced multi-language support, now available in English, Simplified Chinese, Traditional Chinese, Spanish, Japanese, Korean, French, Russian, and German. Feedback from native speakers is welcome to improve the translations.
|
||||
* **Automatic Filename Conflict Resolution** - Implemented automatic file renaming (`original name + short hash`) to prevent conflicts when downloading or moving models.
|
||||
* **Performance Optimizations & Bug Fixes** - Various performance improvements and bug fixes for a more stable and responsive experience.
|
||||
|
||||
[View Update History](./update_logs.md)
|
||||
|
||||
@@ -82,13 +97,6 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
|
||||
- 🚀 **High Performance**
|
||||
- Fast model loading and browsing
|
||||
- Smooth scrolling through large collections
|
||||
- Real-time updates when files change
|
||||
|
||||
- 📂 **Advanced Organization**
|
||||
- Quick search with fuzzy matching
|
||||
- Folder-based categorization
|
||||
- Move LoRAs between folders
|
||||
- Sort by name or date
|
||||
|
||||
- 🌐 **Rich Model Integration**
|
||||
- Direct download from CivitAI
|
||||
@@ -120,19 +128,28 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
|
||||
|
||||
## Installation
|
||||
|
||||
### Option 1: **ComfyUI Manager** (Recommended)
|
||||
### Option 1: **ComfyUI Manager** (Recommended for ComfyUI users)
|
||||
|
||||
1. Open **ComfyUI**.
|
||||
2. Go to **Manager > Custom Node Manager**.
|
||||
3. Search for `lora-manager`.
|
||||
4. Click **Install**.
|
||||
|
||||
### Option 2: **Manual Installation**
|
||||
### Option 2: **Portable Standalone Edition** (No ComfyUI required)
|
||||
|
||||
1. Download the [Portable Package](https://github.com/willmiao/ComfyUI-Lora-Manager/releases/download/v0.9.8/lora_manager_portable.7z)
|
||||
2. Copy the provided `settings.json.example` file to create a new file named `settings.json` in `comfyui-lora-manager` folder.
|
||||
3. Edit the new `settings.json` to include your correct model folder paths and CivitAI API key
|
||||
- Set `"use_portable_settings": true` if you want the configuration to remain inside the repository folder instead of your user settings directory.
|
||||
4. Run run.bat
|
||||
- To change the startup port, edit `run.bat` and modify the parameter (e.g. `--port 9001`)
|
||||
|
||||
### Option 3: **Manual Installation**
|
||||
|
||||
```bash
|
||||
git clone https://github.com/willmiao/ComfyUI-Lora-Manager.git
|
||||
cd ComfyUI-Lora-Manager
|
||||
pip install requirements.txt
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -153,32 +170,138 @@ pip install requirements.txt
|
||||
- Paste into the Lora Loader node's text input
|
||||
- The node will automatically apply preset strength and trigger words
|
||||
|
||||
### Filename Format Patterns for Save Image Node
|
||||
|
||||
The Save Image Node supports dynamic filename generation using pattern codes. You can customize how your images are named using the following format patterns:
|
||||
|
||||
#### Available Pattern Codes
|
||||
|
||||
- `%seed%` - Inserts the generation seed number
|
||||
- `%width%` - Inserts the image width
|
||||
- `%height%` - Inserts the image height
|
||||
- `%pprompt:N%` - Inserts the positive prompt (limited to N characters)
|
||||
- `%nprompt:N%` - Inserts the negative prompt (limited to N characters)
|
||||
- `%model:N%` - Inserts the model/checkpoint name (limited to N characters)
|
||||
- `%date%` - Inserts current date/time as "yyyyMMddhhmmss"
|
||||
- `%date:FORMAT%` - Inserts date using custom format with:
|
||||
- `yyyy` - 4-digit year
|
||||
- `yy` - 2-digit year
|
||||
- `MM` - 2-digit month
|
||||
- `dd` - 2-digit day
|
||||
- `hh` - 2-digit hour
|
||||
- `mm` - 2-digit minute
|
||||
- `ss` - 2-digit second
|
||||
|
||||
#### Examples
|
||||
|
||||
- `image_%seed%` → `image_1234567890`
|
||||
- `gen_%width%x%height%` → `gen_512x768`
|
||||
- `%model:10%_%seed%` → `dreamshape_1234567890`
|
||||
- `%date:yyyy-MM-dd%` → `2025-04-28`
|
||||
- `%pprompt:20%_%seed%` → `beautiful landscape_1234567890`
|
||||
- `%model%_%date:yyMMdd%_%seed%` → `dreamshaper_v8_250428_1234567890`
|
||||
|
||||
You can combine multiple patterns to create detailed, organized filenames for your generated images.
|
||||
|
||||
### Standalone Mode
|
||||
|
||||
You can now run LoRA Manager independently from ComfyUI:
|
||||
|
||||
1. **For ComfyUI users**:
|
||||
- Launch ComfyUI with LoRA Manager at least once to initialize the necessary path information in the `settings.json` file located in your user settings folder (see paths above).
|
||||
- Make sure dependencies are installed: `pip install -r requirements.txt`
|
||||
- From your ComfyUI root directory, run:
|
||||
```bash
|
||||
python custom_nodes\comfyui-lora-manager\standalone.py
|
||||
```
|
||||
- Access the interface at: `http://localhost:8188/loras`
|
||||
- You can specify a different host or port with arguments:
|
||||
```bash
|
||||
python custom_nodes\comfyui-lora-manager\standalone.py --host 127.0.0.1 --port 9000
|
||||
```
|
||||
|
||||
2. **For non-ComfyUI users**:
|
||||
- Copy the provided `settings.json.example` file to create a new file named `settings.json`. Update the API key, optional language, and folder paths only—the library registry is created automatically when LoRA Manager starts.
|
||||
- Edit `settings.json` to include your correct model folder paths and CivitAI API key (you can leave the defaults until ready to configure them)
|
||||
- Enable portable mode by setting `"use_portable_settings": true` if you prefer LoRA Manager to read and write the `settings.json` located in the project directory.
|
||||
- Install required dependencies: `pip install -r requirements.txt`
|
||||
- Run standalone mode:
|
||||
```bash
|
||||
python standalone.py
|
||||
```
|
||||
- Access the interface through your browser at: `http://localhost:8188/loras`
|
||||
|
||||
> **Note:** Existing installations automatically migrate the legacy `settings.json` from the plugin folder to the user settings directory the first time you launch this version.
|
||||
|
||||
This standalone mode provides a lightweight option for managing your model and recipe collection without needing to run the full ComfyUI environment, making it useful even for users who primarily use other stable diffusion interfaces.
|
||||
|
||||
## Testing & Coverage
|
||||
|
||||
### Backend
|
||||
|
||||
Install the development dependencies and run pytest with coverage reports:
|
||||
|
||||
```bash
|
||||
pip install -r requirements-dev.txt
|
||||
COVERAGE_FILE=coverage/backend/.coverage pytest \
|
||||
--cov=py \
|
||||
--cov=standalone \
|
||||
--cov-report=term-missing \
|
||||
--cov-report=html:coverage/backend/html \
|
||||
--cov-report=xml:coverage/backend/coverage.xml \
|
||||
--cov-report=json:coverage/backend/coverage.json
|
||||
```
|
||||
|
||||
HTML, XML, and JSON artifacts are stored under `coverage/backend/` so you can inspect hot spots locally or from CI artifacts.
|
||||
|
||||
### Frontend
|
||||
|
||||
Run the Vitest coverage suite to analyze widget hot spots:
|
||||
|
||||
```bash
|
||||
npm run test:coverage
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Contributing
|
||||
|
||||
Thank you for your interest in contributing to ComfyUI LoRA Manager! As this project is currently in its early stages and undergoing rapid development and refactoring, we are temporarily not accepting pull requests.
|
||||
|
||||
However, your feedback and ideas are extremely valuable to us:
|
||||
- Please feel free to open issues for any bugs you encounter
|
||||
- Submit feature requests through GitHub issues
|
||||
- Share your suggestions for improvements
|
||||
|
||||
We appreciate your understanding and look forward to potentially accepting code contributions once the project architecture stabilizes.
|
||||
|
||||
---
|
||||
|
||||
## Credits
|
||||
|
||||
This project has been inspired by and benefited from other excellent ComfyUI extensions:
|
||||
|
||||
- [ComfyUI-SaveImageWithMetaData](https://github.com/Comfy-Community/ComfyUI-SaveImageWithMetaData) - For the image metadata functionality
|
||||
- [ComfyUI-SaveImageWithMetaData](https://github.com/nkchocoai/ComfyUI-SaveImageWithMetaData) - For the image metadata functionality
|
||||
- [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) - For the lora loader functionality
|
||||
|
||||
---
|
||||
|
||||
## Contributing
|
||||
|
||||
If you have suggestions, bug reports, or improvements, feel free to open an issue or contribute directly to the codebase. Pull requests are always welcome!
|
||||
|
||||
---
|
||||
|
||||
## ☕ Support
|
||||
|
||||
If you find this project helpful, consider supporting its development:
|
||||
|
||||
[](https://ko-fi.com/pixelpawsai)
|
||||
|
||||
[](https://patreon.com/PixelPawsAI)
|
||||
|
||||
WeChat: [Click to view QR code](https://raw.githubusercontent.com/willmiao/ComfyUI-Lora-Manager/main/static/images/wechat-qr.webp)
|
||||
|
||||
## 💬 Community
|
||||
|
||||
Join our Discord community for support, discussions, and updates:
|
||||
[Discord Server](https://discord.gg/vcqNrWVFvM)
|
||||
|
||||
---
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#willmiao/ComfyUI-Lora-Manager&Date)
|
||||
|
||||
46
__init__.py
46
__init__.py
@@ -1,18 +1,52 @@
|
||||
from .py.lora_manager import LoraManager
|
||||
from .py.nodes.lora_loader import LoraManagerLoader
|
||||
from .py.nodes.trigger_word_toggle import TriggerWordToggle
|
||||
from .py.nodes.lora_stacker import LoraStacker
|
||||
from .py.nodes.save_image import SaveImage
|
||||
try: # pragma: no cover - import fallback for pytest collection
|
||||
from .py.lora_manager import LoraManager
|
||||
from .py.nodes.lora_loader import LoraManagerLoader, LoraManagerTextLoader
|
||||
from .py.nodes.trigger_word_toggle import TriggerWordToggle
|
||||
from .py.nodes.prompt import PromptLoraManager
|
||||
from .py.nodes.lora_stacker import LoraStacker
|
||||
from .py.nodes.save_image import SaveImage
|
||||
from .py.nodes.debug_metadata import DebugMetadata
|
||||
from .py.nodes.wanvideo_lora_select import WanVideoLoraSelect
|
||||
from .py.nodes.wanvideo_lora_select_from_text import WanVideoLoraSelectFromText
|
||||
from .py.metadata_collector import init as init_metadata_collector
|
||||
except ImportError: # pragma: no cover - allows running under pytest without package install
|
||||
import importlib
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
package_root = pathlib.Path(__file__).resolve().parent
|
||||
if str(package_root) not in sys.path:
|
||||
sys.path.append(str(package_root))
|
||||
|
||||
PromptLoraManager = importlib.import_module("py.nodes.prompt").PromptLoraManager
|
||||
LoraManager = importlib.import_module("py.lora_manager").LoraManager
|
||||
LoraManagerLoader = importlib.import_module("py.nodes.lora_loader").LoraManagerLoader
|
||||
LoraManagerTextLoader = importlib.import_module("py.nodes.lora_loader").LoraManagerTextLoader
|
||||
TriggerWordToggle = importlib.import_module("py.nodes.trigger_word_toggle").TriggerWordToggle
|
||||
LoraStacker = importlib.import_module("py.nodes.lora_stacker").LoraStacker
|
||||
SaveImage = importlib.import_module("py.nodes.save_image").SaveImage
|
||||
DebugMetadata = importlib.import_module("py.nodes.debug_metadata").DebugMetadata
|
||||
WanVideoLoraSelect = importlib.import_module("py.nodes.wanvideo_lora_select").WanVideoLoraSelect
|
||||
WanVideoLoraSelectFromText = importlib.import_module("py.nodes.wanvideo_lora_select_from_text").WanVideoLoraSelectFromText
|
||||
init_metadata_collector = importlib.import_module("py.metadata_collector").init
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
PromptLoraManager.NAME: PromptLoraManager,
|
||||
LoraManagerLoader.NAME: LoraManagerLoader,
|
||||
LoraManagerTextLoader.NAME: LoraManagerTextLoader,
|
||||
TriggerWordToggle.NAME: TriggerWordToggle,
|
||||
LoraStacker.NAME: LoraStacker,
|
||||
SaveImage.NAME: SaveImage
|
||||
SaveImage.NAME: SaveImage,
|
||||
DebugMetadata.NAME: DebugMetadata,
|
||||
WanVideoLoraSelect.NAME: WanVideoLoraSelect,
|
||||
WanVideoLoraSelectFromText.NAME: WanVideoLoraSelectFromText
|
||||
}
|
||||
|
||||
WEB_DIRECTORY = "./web/comfyui"
|
||||
|
||||
# Initialize metadata collector
|
||||
init_metadata_collector()
|
||||
|
||||
# Register routes on import
|
||||
LoraManager.add_routes()
|
||||
__all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY']
|
||||
|
||||
180
docs/LM-Extension-Wiki.md
Normal file
180
docs/LM-Extension-Wiki.md
Normal file
@@ -0,0 +1,180 @@
|
||||
## Overview
|
||||
|
||||
The **LoRA Manager Civitai Extension** is a Browser extension designed to work seamlessly with [LoRA Manager](https://github.com/willmiao/ComfyUI-Lora-Manager) to significantly enhance your browsing experience on [Civitai](https://civitai.com).
|
||||
It also supports browsing on [CivArchive](https://civarchive.com/) (formerly CivitaiArchive).
|
||||
|
||||
With this extension, you can:
|
||||
|
||||
✅ Instantly see which models are already present in your local library
|
||||
✅ Download new models with a single click
|
||||
✅ Manage downloads efficiently with queue and parallel download support
|
||||
✅ Keep your downloaded models automatically organized according to your custom settings
|
||||
|
||||

|
||||

|
||||
|
||||
---
|
||||
|
||||
## Why Are All Features for Supporters Only?
|
||||
|
||||
I love building tools for the Stable Diffusion and ComfyUI communities, and LoRA Manager is a passion project that I've poured countless hours into. When I created this companion extension, my hope was to offer its core features for free, as a thank-you to all of you.
|
||||
|
||||
Unfortunately, I've reached a point where I need to be realistic. The level of support from the free model has been far lower than what's needed to justify the continuous development and maintenance for both projects. It was a difficult decision, but I've chosen to make the extension's features exclusive to supporters.
|
||||
|
||||
This change is crucial for me to be able to continue dedicating my time to improving the free and open-source LoRA Manager, which I'm committed to keeping available for everyone.
|
||||
|
||||
Your support does more than just unlock a few features—it allows me to keep innovating and ensures the core LoRA Manager project thrives. I'm incredibly grateful for your understanding and any support you can offer. ❤️
|
||||
|
||||
(_For those who previously supported me on Ko-fi with a one-time donation, I'll be sending out license keys individually as a thank-you._)
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Installation
|
||||
|
||||
### Supported Browsers & Installation Methods
|
||||
|
||||
| Browser | Installation Method |
|
||||
|--------------------|-------------------------------------------------------------------------------------|
|
||||
| **Google Chrome** | [Chrome Web Store link](https://chromewebstore.google.com/detail/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb) |
|
||||
| **Microsoft Edge** | Install via Chrome Web Store (compatible) |
|
||||
| **Brave Browser** | Install via Chrome Web Store (compatible) |
|
||||
| **Opera** | Install via Chrome Web Store (compatible) |
|
||||
| **Firefox** | <div id="firefox-install" class="install-ok"><a href="https://github.com/willmiao/lm-civitai-extension-firefox/releases/latest/download/extension.xpi">📦 Install Firefox Extension (reviewed and verified by Mozilla)</a></div> |
|
||||
|
||||
For non-Chrome browsers (e.g., Microsoft Edge), you can typically install extensions from the Chrome Web Store by following these steps: open the extension’s Chrome Web Store page, click 'Get extension', then click 'Allow' when prompted to enable installations from other stores, and finally click 'Add extension' to complete the installation.
|
||||
|
||||
---
|
||||
|
||||
## Privacy & Security
|
||||
|
||||
I understand concerns around browser extensions and privacy, and I want to be fully transparent about how the **LM Civitai Extension** works:
|
||||
|
||||
- **Reviewed and Verified**
|
||||
This extension has been **manually reviewed and approved by the Chrome Web Store**. The Firefox version uses the **exact same code** (only the packaging format differs) and has passed **Mozilla’s Add-on review**.
|
||||
|
||||
- **Minimal Network Access**
|
||||
The only external server this extension connects to is:
|
||||
**`https://willmiao.shop`** — used solely for **license validation**.
|
||||
|
||||
It does **not collect, transmit, or store any personal or usage data**.
|
||||
No browsing history, no user IDs, no analytics, no hidden trackers.
|
||||
|
||||
- **Local-Only Model Detection**
|
||||
Model detection and LoRA Manager communication all happen **locally** within your browser, directly interacting with your local LoRA Manager backend.
|
||||
|
||||
I value your trust and are committed to keeping your local setup private and secure. If you have any questions, feel free to reach out!
|
||||
|
||||
---
|
||||
|
||||
## How to Use
|
||||
|
||||
After installing the extension, you'll automatically receive a **7-day trial** to explore all features.
|
||||
|
||||
When the extension is correctly installed and your license is valid:
|
||||
|
||||
- Open **Civitai**, and you'll see visual indicators added by the extension on model cards, showing:
|
||||
- ✅ Models already present in your local library
|
||||
- ⬇️ A download button for models not in your library
|
||||
|
||||
Clicking the download button adds the corresponding model version to the download queue, waiting to be downloaded. You can set up to **5 models to download simultaneously**.
|
||||
|
||||
### Visual Indicators Appear On:
|
||||
|
||||
- **Home Page** — Featured models
|
||||
- **Models Page**
|
||||
- **Creator Profiles** — If the creator has set their models to be visible
|
||||
- **Recommended Resources** — On individual model pages
|
||||
|
||||
### Version Buttons on Model Pages
|
||||
|
||||
On a specific model page, visual indicators also appear on version buttons, showing which versions are already in your local library.
|
||||
|
||||
When switching to a specific version by clicking a version button:
|
||||
|
||||
- Clicking the download button will open a dropdown:
|
||||
- Download via **LoRA Manager**
|
||||
- Download via **Original Download** (browser download)
|
||||
|
||||
You can check **Remember my choice** to set your preferred default. You can change this setting anytime in the extension's settings.
|
||||
|
||||

|
||||
|
||||
### Resources on Image Pages (2025-08-05) — now shows in-library indicators for image resources. ‘Import image as recipe’ coming soon!
|
||||
|
||||

|
||||
|
||||
---
|
||||
|
||||
## Model Download Location & LoRA Manager Settings
|
||||
|
||||
To use the **one-click download function**, you must first set:
|
||||
|
||||
- Your **Default LoRAs Root**
|
||||
- Your **Default Checkpoints Root**
|
||||
|
||||
These are set within LoRA Manager's settings.
|
||||
|
||||
When everything is configured, downloaded model files will be placed in:
|
||||
|
||||
`<Default_Models_Root>/<Base_Model_of_the_Model>/<First_Tag_of_the_Model>`
|
||||
|
||||
|
||||
### Update: Default Path Customization (2025-07-21)
|
||||
|
||||
A new setting to customize the default download path has been added in the nightly version. You can now personalize where models are saved when downloading via the LM Civitai Extension.
|
||||
|
||||

|
||||
|
||||
The previous YAML path mapping file will be deprecated—settings will now be unified in settings.json to simplify configuration.
|
||||
|
||||
---
|
||||
|
||||
## Backend Port Configuration
|
||||
|
||||
If your **ComfyUI** or **LoRA Manager** backend is running on a port **other than the default 8188**, you must configure the backend port in the extension's settings.
|
||||
|
||||
After correctly setting and saving the port, you'll see in the extension's header area:
|
||||
- A **Healthy** status with the tooltip: `Connected to LoRA Manager on port xxxx`
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Connecting to a Remote LoRA Manager
|
||||
|
||||
If your LoRA Manager is running on another computer, you can still connect from your browser using port forwarding.
|
||||
|
||||
> **Why can't you set a remote IP directly?**
|
||||
>
|
||||
> For privacy and security, the extension only requests access to `http://127.0.0.1/*`. Supporting remote IPs would require much broader permissions, which may be rejected by browser stores and could raise user concerns.
|
||||
|
||||
**Solution: Port Forwarding with `socat`**
|
||||
|
||||
On your browser computer, run:
|
||||
|
||||
`socat TCP-LISTEN:8188,bind=127.0.0.1,fork TCP:REMOTE.IP.ADDRESS.HERE:8188`
|
||||
|
||||
- Replace `REMOTE.IP.ADDRESS.HERE` with the IP of the machine running LoRA Manager.
|
||||
- Adjust the port if needed.
|
||||
|
||||
This lets the extension connect to `127.0.0.1:8188` as usual, with traffic forwarded to your remote server.
|
||||
|
||||
_Thanks to user **Temikus** for sharing this solution!_
|
||||
|
||||
---
|
||||
|
||||
## Roadmap
|
||||
|
||||
The extension will evolve alongside **LoRA Manager** improvements. Planned features include:
|
||||
|
||||
- [x] Support for **additional model types** (e.g., embeddings)
|
||||
- [ ] One-click **Recipe Import**
|
||||
- [x] Display of in-library status for all resources in the **Resources Used** section of the image page
|
||||
- [x] One-click **Auto-organize Models**
|
||||
|
||||
**Stay tuned — and thank you for your support!**
|
||||
|
||||
---
|
||||
|
||||
93
docs/architecture/example_images_routes.md
Normal file
93
docs/architecture/example_images_routes.md
Normal file
@@ -0,0 +1,93 @@
|
||||
# Example image route architecture
|
||||
|
||||
The example image routing stack mirrors the layered model route stack described in
|
||||
[`docs/architecture/model_routes.md`](model_routes.md). HTTP wiring, controller setup,
|
||||
handler orchestration, and long-running workflows now live in clearly separated modules so
|
||||
we can extend download/import behaviour without touching the entire feature surface.
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph HTTP
|
||||
A[ExampleImagesRouteRegistrar] -->|binds| B[ExampleImagesRoutes controller]
|
||||
end
|
||||
subgraph Application
|
||||
B --> C[ExampleImagesHandlerSet]
|
||||
C --> D1[Handlers]
|
||||
D1 --> E1[Use cases]
|
||||
E1 --> F1[Download manager / processor / file manager]
|
||||
end
|
||||
subgraph Side Effects
|
||||
F1 --> G1[Filesystem]
|
||||
F1 --> G2[Model metadata]
|
||||
F1 --> G3[WebSocket progress]
|
||||
end
|
||||
```
|
||||
|
||||
## Layer responsibilities
|
||||
|
||||
| Layer | Module(s) | Responsibility |
|
||||
| --- | --- | --- |
|
||||
| Registrar | `py/routes/example_images_route_registrar.py` | Declarative catalogue of every example image endpoint plus helpers that bind them to an `aiohttp` router. Keeps HTTP concerns symmetrical with the model registrar. |
|
||||
| Controller | `py/routes/example_images_routes.py` | Lazily constructs `ExampleImagesHandlerSet`, injects defaults for the download manager, processor, and file manager, and exposes the registrar-ready mapping just like `BaseModelRoutes`. |
|
||||
| Handler set | `py/routes/handlers/example_images_handlers.py` | Groups HTTP adapters by concern (downloads, imports/deletes, filesystem access). Each handler translates domain errors into HTTP responses and defers to a use case or utility service. |
|
||||
| Use cases | `py/services/use_cases/example_images/*.py` | Encapsulate orchestration for downloads and imports. They validate input, translate concurrency/configuration errors, and keep handler logic declarative. |
|
||||
| Supporting services | `py/utils/example_images_download_manager.py`, `py/utils/example_images_processor.py`, `py/utils/example_images_file_manager.py` | Execute long-running work: pull assets from Civitai, persist uploads, clean metadata, expose filesystem actions with guardrails, and broadcast progress snapshots. |
|
||||
|
||||
## Handler responsibilities & invariants
|
||||
|
||||
`ExampleImagesHandlerSet` flattens the handler objects into the `{"handler_name": coroutine}`
|
||||
mapping consumed by the registrar. The table below outlines how each handler collaborates
|
||||
with the use cases and utilities.
|
||||
|
||||
| Handler | Key endpoints | Collaborators | Contracts |
|
||||
| --- | --- | --- | --- |
|
||||
| `ExampleImagesDownloadHandler` | `/api/lm/download-example-images`, `/api/lm/example-images-status`, `/api/lm/pause-example-images`, `/api/lm/resume-example-images`, `/api/lm/force-download-example-images` | `DownloadExampleImagesUseCase`, `DownloadManager` | Delegates payload validation and concurrency checks to the use case; progress/status endpoints expose the same snapshot used for WebSocket broadcasts; pause/resume surface `DownloadNotRunningError` as HTTP 400 instead of 500. |
|
||||
| `ExampleImagesManagementHandler` | `/api/lm/import-example-images`, `/api/lm/delete-example-image` | `ImportExampleImagesUseCase`, `ExampleImagesProcessor` | Multipart uploads are streamed to disk via the use case; validation failures return HTTP 400 with no filesystem side effects; deletion funnels through the processor to prune metadata and cached images consistently. |
|
||||
| `ExampleImagesFileHandler` | `/api/lm/open-example-images-folder`, `/api/lm/example-image-files`, `/api/lm/has-example-images` | `ExampleImagesFileManager` | Centralises filesystem access, enforcing settings-based root paths and returning HTTP 400/404 for missing configuration or folders; responses always include `success`/`has_images` booleans for UI consumption. |
|
||||
|
||||
## Use case boundaries
|
||||
|
||||
| Use case | Entry point | Dependencies | Guarantees |
|
||||
| --- | --- | --- | --- |
|
||||
| `DownloadExampleImagesUseCase` | `execute(payload)` | `DownloadManager.start_download`, download configuration errors | Raises `DownloadExampleImagesInProgressError` when the manager reports an active job, rewraps configuration errors into `DownloadExampleImagesConfigurationError`, and lets `ExampleImagesDownloadError` bubble as 500s so handlers do not duplicate logging. |
|
||||
| `ImportExampleImagesUseCase` | `execute(request)` | `ExampleImagesProcessor.import_images`, temporary file helpers | Supports multipart or JSON payloads, normalises file paths into a single list, cleans up temp files even on failure, and maps validation issues to `ImportExampleImagesValidationError` for HTTP 400 responses. |
|
||||
|
||||
## Maintaining critical invariants
|
||||
|
||||
* **Shared progress snapshots** - The download handler returns the same snapshot built by
|
||||
`DownloadManager`, guaranteeing parity between HTTP polling endpoints and WebSocket
|
||||
progress events.
|
||||
* **Safe filesystem access** - All folder/file actions flow through
|
||||
`ExampleImagesFileManager`, which validates the configured example image root and ensures
|
||||
responses never leak absolute paths outside the allowed directory.
|
||||
* **Metadata hygiene** - Import/delete operations run through `ExampleImagesProcessor`,
|
||||
which updates model metadata via `MetadataManager` and notifies the relevant scanners so
|
||||
cache state stays in sync.
|
||||
|
||||
## Migration notes
|
||||
|
||||
The refactor brings the example image stack in line with the model/recipe stacks:
|
||||
|
||||
1. `ExampleImagesRouteRegistrar` now owns the declarative route list. Downstream projects
|
||||
should rely on `ExampleImagesRoutes.to_route_mapping()` instead of manually wiring
|
||||
handler callables.
|
||||
2. `ExampleImagesRoutes` caches its `ExampleImagesHandlerSet` just like
|
||||
`BaseModelRoutes`. If you previously instantiated handlers directly, inject custom
|
||||
collaborators via the controller constructor (`download_manager`, `processor`,
|
||||
`file_manager`) to keep test seams predictable.
|
||||
3. Tests that mocked `ExampleImagesRoutes.setup_routes` should switch to patching
|
||||
`DownloadExampleImagesUseCase`/`ImportExampleImagesUseCase` at import time. The handlers
|
||||
expect those abstractions to surface validation/concurrency errors, and bypassing them
|
||||
will skip the HTTP-friendly error mapping.
|
||||
|
||||
## Extending the stack
|
||||
|
||||
1. Add the endpoint to `ROUTE_DEFINITIONS` with a unique `handler_name`.
|
||||
2. Expose the coroutine on an existing handler class (or create a new handler and extend
|
||||
`ExampleImagesHandlerSet`).
|
||||
3. Wire additional services or factories inside `_build_handler_set` on
|
||||
`ExampleImagesRoutes`, mirroring how the model stack introduces new use cases.
|
||||
|
||||
`tests/routes/test_example_images_routes.py` exercises registrar binding, download pause
|
||||
flows, and import validations. Use it as a template when introducing new handler
|
||||
collaborators or error mappings.
|
||||
100
docs/architecture/model_routes.md
Normal file
100
docs/architecture/model_routes.md
Normal file
@@ -0,0 +1,100 @@
|
||||
# Base model route architecture
|
||||
|
||||
The model routing stack now splits HTTP wiring, orchestration logic, and
|
||||
business rules into discrete layers. The goal is to make it obvious where a
|
||||
new collaborator should live and which contract it must honour. The diagram
|
||||
below captures the end-to-end flow for a typical request:
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph HTTP
|
||||
A[ModelRouteRegistrar] -->|binds| B[BaseModelRoutes handler proxy]
|
||||
end
|
||||
subgraph Application
|
||||
B --> C[ModelHandlerSet]
|
||||
C --> D1[Handlers]
|
||||
D1 --> E1[Use cases]
|
||||
E1 --> F1[Services / scanners]
|
||||
end
|
||||
subgraph Side Effects
|
||||
F1 --> G1[Cache & metadata]
|
||||
F1 --> G2[Filesystem]
|
||||
F1 --> G3[WebSocket state]
|
||||
end
|
||||
```
|
||||
|
||||
Every box maps to a concrete module:
|
||||
|
||||
| Layer | Module(s) | Responsibility |
|
||||
| --- | --- | --- |
|
||||
| Registrar | `py/routes/model_route_registrar.py` | Declarative list of routes shared by every model type and helper methods for binding them to an `aiohttp` application. |
|
||||
| Route controller | `py/routes/base_model_routes.py` | Constructs the handler graph, injects shared services, exposes proxies that surface `503 Service not ready` when the model service has not been attached. |
|
||||
| Handler set | `py/routes/handlers/model_handlers.py` | Thin HTTP adapters grouped by concern (page rendering, listings, mutations, queries, downloads, CivitAI integration, move operations, auto-organize). |
|
||||
| Use cases | `py/services/use_cases/*.py` | Encapsulate long-running flows (`DownloadModelUseCase`, `BulkMetadataRefreshUseCase`, `AutoOrganizeUseCase`). They normalise validation errors and concurrency constraints before returning control to the handlers. |
|
||||
| Services | `py/services/*.py` | Existing services and scanners that mutate caches, write metadata, move files, and broadcast WebSocket updates. |
|
||||
|
||||
## Handler responsibilities & contracts
|
||||
|
||||
`ModelHandlerSet` flattens the handler objects into the exact callables used by
|
||||
the registrar. The table below highlights the separation of concerns within
|
||||
the set and the invariants that must hold after each handler returns.
|
||||
|
||||
| Handler | Key endpoints | Collaborators | Contracts |
|
||||
| --- | --- | --- | --- |
|
||||
| `ModelPageView` | `/{prefix}` | `SettingsManager`, `server_i18n`, Jinja environment, `service.scanner` | Template is rendered with `is_initializing` flag when caches are cold; i18n filter is registered exactly once per environment instance. |
|
||||
| `ModelListingHandler` | `/api/lm/{prefix}/list` | `service.get_paginated_data`, `service.format_response` | Listings respect pagination query parameters and cap `page_size` at 100; every item is formatted before response. |
|
||||
| `ModelManagementHandler` | Mutations (delete, exclude, metadata, preview, tags, rename, bulk delete, duplicate verification) | `ModelLifecycleService`, `MetadataSyncService`, `PreviewAssetService`, `TagUpdateService`, scanner cache/index | Cache state mirrors filesystem changes: deletes prune cache & hash index, preview replacements synchronise metadata and cache NSFW levels, metadata saves trigger cache resort when names change. |
|
||||
| `ModelQueryHandler` | Read-only queries (top tags, folders, duplicates, metadata, URLs) | Service query helpers & scanner cache | Outputs always wrapped in `{"success": True}` when no error; duplicate/filename grouping omits empty entries; invalid parameters (e.g. missing `model_root`) return HTTP 400. |
|
||||
| `ModelDownloadHandler` | `/api/lm/download-model`, `/download-model-get`, `/download-progress/{id}`, `/cancel-download-get` | `DownloadModelUseCase`, `DownloadCoordinator`, `WebSocketManager` | Payload validation errors become HTTP 400 without mutating download progress cache; early-access failures surface as HTTP 401; successful downloads cache progress snapshots that back both WebSocket broadcasts and polling endpoints. |
|
||||
| `ModelCivitaiHandler` | CivitAI metadata routes | `MetadataSyncService`, metadata provider factory, `BulkMetadataRefreshUseCase` | `fetch_all_civitai` streams progress via `WebSocketBroadcastCallback`; version lookups validate model type before returning; local availability fields derive from hash lookups without mutating cache state. |
|
||||
| `ModelMoveHandler` | `move_model`, `move_models_bulk` | `ModelMoveService` | Moves execute atomically per request; bulk operations aggregate success/failure per file set. |
|
||||
| `ModelAutoOrganizeHandler` | `/api/lm/{prefix}/auto-organize` (GET/POST), `/auto-organize-progress` | `AutoOrganizeUseCase`, `WebSocketProgressCallback`, `WebSocketManager` | Enforces single-flight execution using the shared lock; progress broadcasts remain available to polling clients until explicitly cleared; conflicts return HTTP 409 with a descriptive error. |
|
||||
|
||||
## Use case boundaries
|
||||
|
||||
Each use case exposes a narrow asynchronous API that hides the underlying
|
||||
services. Their error mapping is essential for predictable HTTP responses.
|
||||
|
||||
| Use case | Entry point | Dependencies | Guarantees |
|
||||
| --- | --- | --- | --- |
|
||||
| `DownloadModelUseCase` | `execute(payload)` | `DownloadCoordinator.schedule_download` | Translates `ValueError` into `DownloadModelValidationError` for HTTP 400, recognises early-access errors (`"401"` in message) and surfaces them as `DownloadModelEarlyAccessError`, forwards success dictionaries untouched. |
|
||||
| `AutoOrganizeUseCase` | `execute(file_paths, progress_callback)` | `ModelFileService.auto_organize_models`, `WebSocketManager` lock | Guarded by `ws_manager` lock + status checks; raises `AutoOrganizeInProgressError` before invoking the file service when another run is already active. |
|
||||
| `BulkMetadataRefreshUseCase` | `execute_with_error_handling(progress_callback)` | `MetadataSyncService`, `SettingsManager`, `WebSocketBroadcastCallback` | Iterates through cached models, applies metadata sync, emits progress snapshots that handlers broadcast unchanged. |
|
||||
|
||||
## Maintaining legacy contracts
|
||||
|
||||
The refactor preserves the invariants called out in the previous architecture
|
||||
notes. The most critical ones are reiterated here to emphasise the
|
||||
collaboration points:
|
||||
|
||||
1. **Cache mutations** – Delete, exclude, rename, and bulk delete operations are
|
||||
channelled through `ModelManagementHandler`. The handler delegates to
|
||||
`ModelLifecycleService` or `MetadataSyncService`, and the scanner cache is
|
||||
mutated in-place before the handler returns. The accompanying tests assert
|
||||
that `scanner._cache.raw_data` and `scanner._hash_index` stay in sync after
|
||||
each mutation.
|
||||
2. **Preview updates** – `PreviewAssetService.replace_preview` writes the new
|
||||
asset, `MetadataSyncService` persists the JSON metadata, and
|
||||
`scanner.update_preview_in_cache` mirrors the change. The handler returns
|
||||
the static URL produced by `config.get_preview_static_url`, keeping browser
|
||||
clients in lockstep with disk state.
|
||||
3. **Download progress** – `DownloadCoordinator.schedule_download` generates the
|
||||
download identifier, registers a WebSocket progress callback, and caches the
|
||||
latest numeric progress via `WebSocketManager`. Both `download_model`
|
||||
responses and `/download-progress/{id}` polling read from the same cache to
|
||||
guarantee consistent progress reporting across transports.
|
||||
|
||||
## Extending the stack
|
||||
|
||||
To add a new shared route:
|
||||
|
||||
1. Declare it in `COMMON_ROUTE_DEFINITIONS` using a unique handler name.
|
||||
2. Implement the corresponding coroutine on one of the handlers inside
|
||||
`ModelHandlerSet` (or introduce a new handler class when the concern does not
|
||||
fit existing ones).
|
||||
3. Inject additional dependencies in `BaseModelRoutes._create_handler_set` by
|
||||
wiring services or use cases through the constructor parameters.
|
||||
|
||||
Model-specific routes should continue to be registered inside the subclass
|
||||
implementation of `setup_specific_routes`, reusing the shared registrar where
|
||||
possible.
|
||||
34
docs/architecture/multi_library_design.md
Normal file
34
docs/architecture/multi_library_design.md
Normal file
@@ -0,0 +1,34 @@
|
||||
# Multi-Library Management for Standalone Mode
|
||||
|
||||
## Requirements Summary
|
||||
- **Independent libraries**: In standalone mode, users can maintain multiple libraries, where each library represents a distinct set of model folders (LoRAs, checkpoints, embeddings, etc.). Only one library is active at any given time, but users need a fast way to switch between them.
|
||||
- **Library-specific settings**: The fields that vary per library are `folder_paths`, `default_lora_root`, `default_checkpoint_root`, and `default_embedding_root` inside `settings.json`.
|
||||
- **Persistent caches**: Every library must have its own SQLite persistent model cache so that metadata generated for one library does not leak into another.
|
||||
- **Backward compatibility**: Existing single-library setups should continue to work. When no multi-library configuration is provided, the application should behave exactly as before.
|
||||
|
||||
## Proposed Design
|
||||
1. **Library registry**
|
||||
- Extend the standalone configuration to hold a list of libraries, each identified by a unique name.
|
||||
- Each entry stores the folder path configuration plus any library-scoped metadata (e.g. creation time, display name).
|
||||
- The active library key is stored separately to allow quick switching without rewriting the full config.
|
||||
2. **Settings management**
|
||||
- Update `settings_manager` to load and persist the library registry. When a library is activated, hydrate the in-memory settings object with that library's folder configuration.
|
||||
- Provide helper methods for creating, renaming, and deleting libraries, ensuring validation for duplicate names and path collisions.
|
||||
- Continue writing the active library settings to `settings.json` for compatibility, while storing the registry in a new section such as `libraries`.
|
||||
3. **Persistent model cache**
|
||||
- Derive the SQLite file path from the active library, e.g. `model_cache_<library>.sqlite` or a nested directory structure like `model_cache/<library>/models.sqlite`.
|
||||
- Update `PersistentModelCache` so it resolves the database path dynamically whenever the active library changes. Ensure connections are closed before switching to avoid locking issues.
|
||||
- Migrate existing single cache files by treating them as the default library's cache.
|
||||
4. **Model scanning workflow**
|
||||
- Modify `ModelScanner` and related services to react to library switches by clearing in-memory caches, re-reading folder paths, and rehydrating metadata from the library-specific SQLite cache.
|
||||
- Provide API endpoints in standalone mode to list libraries, activate one, and trigger a rescan.
|
||||
5. **UI/UX considerations**
|
||||
- In the standalone UI, introduce a library selector component that surfaces available libraries and offers quick switching.
|
||||
- Offer feedback when switching libraries (e.g. spinner while rescanning) and guard destructive actions with confirmation prompts.
|
||||
|
||||
## Implementation Notes
|
||||
- **Data migration**: On startup, detect if the old `settings.json` structure is present. If so, create a default library entry using the current folder paths and point the active library to it.
|
||||
- **Thread safety**: Ensure that any long-running scans are cancelled or awaited before switching libraries to prevent race conditions in cache writes.
|
||||
- **Testing**: Add unit tests for the settings manager to cover library CRUD operations and cache path resolution. Include integration tests that simulate switching libraries and verifying that the correct models are loaded.
|
||||
- **Documentation**: Update user guides to explain how to define libraries, switch between them, and where the new cache files are stored.
|
||||
- **Extensibility**: Keep the design open to future per-library settings (e.g. auto-refresh intervals, metadata overrides) by storing library data as objects instead of flat maps.
|
||||
89
docs/architecture/recipe_routes.md
Normal file
89
docs/architecture/recipe_routes.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# Recipe route architecture
|
||||
|
||||
The recipe routing stack now mirrors the modular model route design. HTTP
|
||||
bindings, controller wiring, handler orchestration, and business rules live in
|
||||
separate layers so new behaviours can be added without re-threading the entire
|
||||
feature. The diagram below outlines the flow for a typical request:
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph HTTP
|
||||
A[RecipeRouteRegistrar] -->|binds| B[RecipeRoutes controller]
|
||||
end
|
||||
subgraph Application
|
||||
B --> C[RecipeHandlerSet]
|
||||
C --> D1[Handlers]
|
||||
D1 --> E1[Use cases]
|
||||
E1 --> F1[Services / scanners]
|
||||
end
|
||||
subgraph Side Effects
|
||||
F1 --> G1[Cache & fingerprint index]
|
||||
F1 --> G2[Metadata files]
|
||||
F1 --> G3[Temporary shares]
|
||||
end
|
||||
```
|
||||
|
||||
## Layer responsibilities
|
||||
|
||||
| Layer | Module(s) | Responsibility |
|
||||
| --- | --- | --- |
|
||||
| Registrar | `py/routes/recipe_route_registrar.py` | Declarative list of every recipe endpoint and helper methods that bind them to an `aiohttp` application. |
|
||||
| Controller | `py/routes/base_recipe_routes.py`, `py/routes/recipe_routes.py` | Lazily resolves scanners/clients from the service registry, wires shared templates/i18n, instantiates `RecipeHandlerSet`, and exposes a `{handler_name: coroutine}` mapping for the registrar. |
|
||||
| Handler set | `py/routes/handlers/recipe_handlers.py` | Thin HTTP adapters grouped by concern (page view, listings, queries, mutations, sharing). They normalise responses and translate service exceptions into HTTP status codes. |
|
||||
| Services & scanners | `py/services/recipes/*.py`, `py/services/recipe_scanner.py`, `py/services/service_registry.py` | Concrete business logic: metadata parsing, persistence, sharing, fingerprint/index maintenance, and cache refresh. |
|
||||
|
||||
## Handler responsibilities & invariants
|
||||
|
||||
`RecipeHandlerSet` flattens purpose-built handler objects into the callables the
|
||||
registrar binds. Each handler is responsible for a narrow concern and enforces a
|
||||
set of invariants before returning:
|
||||
|
||||
| Handler | Key endpoints | Collaborators | Contracts |
|
||||
| --- | --- | --- | --- |
|
||||
| `RecipePageView` | `/loras/recipes` | `SettingsManager`, `server_i18n`, Jinja environment, recipe scanner getter | Template rendered with `is_initializing` flag when caches are still warming; i18n filter registered exactly once per environment instance. |
|
||||
| `RecipeListingHandler` | `/api/lm/recipes`, `/api/lm/recipe/{id}` | `recipe_scanner.get_paginated_data`, `recipe_scanner.get_recipe_by_id` | Listings respect pagination and search filters; every item receives a `file_url` fallback even when metadata is incomplete; missing recipes become HTTP 404. |
|
||||
| `RecipeQueryHandler` | Tag/base-model stats, syntax, LoRA lookups | Recipe scanner cache, `format_recipe_file_url` helper | Cache snapshots are reused without forcing refresh; duplicate lookups collapse groups by fingerprint; syntax lookups return helpful errors when LoRAs are absent. |
|
||||
| `RecipeManagementHandler` | Save, update, reconnect, bulk delete, widget ingest | `RecipePersistenceService`, `RecipeAnalysisService`, recipe scanner | Persistence results propagate HTTP status codes; fingerprint/index updates flow through the scanner before returning; validation errors surface as HTTP 400 without touching disk. |
|
||||
| `RecipeAnalysisHandler` | Uploaded/local/remote analysis | `RecipeAnalysisService`, `civitai_client`, recipe scanner | Unsupported content types map to HTTP 400; download errors (`RecipeDownloadError`) are not retried; every response includes a `loras` array for client compatibility. |
|
||||
| `RecipeSharingHandler` | Share + download | `RecipeSharingService`, recipe scanner | Share responses provide a stable download URL and filename; expired shares surface as HTTP 404; downloads stream via `web.FileResponse` with attachment headers. |
|
||||
|
||||
## Use case boundaries
|
||||
|
||||
The dedicated services encapsulate long-running work so handlers stay thin.
|
||||
|
||||
| Use case | Entry point | Dependencies | Guarantees |
|
||||
| --- | --- | --- | --- |
|
||||
| `RecipeAnalysisService` | `analyze_uploaded_image`, `analyze_remote_image`, `analyze_local_image`, `analyze_widget_metadata` | `ExifUtils`, `RecipeParserFactory`, downloader factory, optional metadata collector/processor | Normalises missing/invalid payloads into `RecipeValidationError`; generates consistent fingerprint data to keep duplicate detection stable; temporary files are cleaned up after every analysis path. |
|
||||
| `RecipePersistenceService` | `save_recipe`, `delete_recipe`, `update_recipe`, `reconnect_lora`, `bulk_delete`, `save_recipe_from_widget` | `ExifUtils`, recipe scanner, card preview sizing constants | Writes images/JSON metadata atomically; updates scanner caches and hash indices before returning; recalculates fingerprints whenever LoRA assignments change. |
|
||||
| `RecipeSharingService` | `share_recipe`, `prepare_download` | `tempfile`, recipe scanner | Copies originals to TTL-managed temp files; metadata lookups re-use the scanner; expired shares trigger cleanup and `RecipeNotFoundError`. |
|
||||
|
||||
## Maintaining critical invariants
|
||||
|
||||
* **Cache updates** – Mutations (`save`, `delete`, `bulk_delete`, `update`) call
|
||||
back into the recipe scanner to mutate the in-memory cache and fingerprint
|
||||
index before returning a response. Tests assert that these methods are invoked
|
||||
even when stubbing persistence.
|
||||
* **Fingerprint management** – `RecipePersistenceService` recomputes
|
||||
fingerprints whenever LoRA metadata changes and duplicate lookups use those
|
||||
fingerprints to group recipes. Handlers bubble the resulting IDs so clients
|
||||
can merge duplicates without an extra fetch.
|
||||
* **Metadata synchronisation** – Saving or reconnecting a recipe updates the
|
||||
JSON sidecar, refreshes embedded metadata via `ExifUtils`, and instructs the
|
||||
scanner to resort its cache. Sharing relies on this metadata to generate
|
||||
filenames and ensure downloads stay in sync with on-disk state.
|
||||
|
||||
## Extending the stack
|
||||
|
||||
1. Declare the new endpoint in `ROUTE_DEFINITIONS` with a unique handler name.
|
||||
2. Implement the coroutine on an existing handler or introduce a new handler
|
||||
class inside `py/routes/handlers/recipe_handlers.py` when the concern does
|
||||
not fit existing ones.
|
||||
3. Wire additional collaborators inside
|
||||
`BaseRecipeRoutes._create_handler_set` (inject new services or factories) and
|
||||
expose helper getters on the handler owner if the handler needs to share
|
||||
utilities.
|
||||
|
||||
Integration tests in `tests/routes/test_recipe_routes.py` exercise the listing,
|
||||
mutation, analysis-error, and sharing paths end-to-end, ensuring the controller
|
||||
and handler wiring remains valid as new capabilities are added.
|
||||
|
||||
46
docs/custom_priority_tags_format.md
Normal file
46
docs/custom_priority_tags_format.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# Custom Priority Tag Format Proposal
|
||||
|
||||
To support user-defined priority tags with flexible aliasing across different model types, the configuration will be stored as editable strings. The format balances readability with enough structure for parsing on both the backend and frontend.
|
||||
|
||||
## Format Overview
|
||||
|
||||
- Each model type is declared on its own line: `model_type: entries`.
|
||||
- Entries are comma-separated and ordered by priority from highest to lowest.
|
||||
- An entry may be a single canonical tag (e.g., `realistic`) or a canonical tag with aliases.
|
||||
- Canonical tags define the final folder name that should be used when matching that entry.
|
||||
- Aliases are enclosed in parentheses and separated by `|` (vertical bar).
|
||||
- All matching is case-insensitive; stored canonical names preserve the user-specified casing for folder creation and UI suggestions.
|
||||
|
||||
### Grammar
|
||||
|
||||
```
|
||||
priority-config := model-config { "\n" model-config }
|
||||
model-config := model-type ":" entry-list
|
||||
model-type := <identifier without spaces>
|
||||
entry-list := entry { "," entry }
|
||||
entry := canonical [ "(" alias { "|" alias } ")" ]
|
||||
canonical := <tag text without parentheses or commas>
|
||||
alias := <tag text without parentheses, commas, or pipes>
|
||||
```
|
||||
|
||||
Examples:
|
||||
|
||||
```
|
||||
lora: celebrity(celeb|celebrity), stylized, character(char)
|
||||
checkpoint: realistic(realism|realistic), anime(anime-style|toon)
|
||||
embedding: face, celeb(celebrity|celeb)
|
||||
```
|
||||
|
||||
## Parsing Notes
|
||||
|
||||
- Whitespace around separators is ignored to make manual editing more forgiving.
|
||||
- Duplicate canonical tags within the same model type collapse to a single entry; the first definition wins.
|
||||
- Aliases map to their canonical tag. When generating folder names, the canonical form is used.
|
||||
- Tags that do not match any alias or canonical entry fall back to the first tag in the model's tag list, preserving current behavior.
|
||||
|
||||
## Usage
|
||||
|
||||
- **Backend:** Convert each model type's string into an ordered list of canonical tags with alias sets. During path generation, iterate by priority order and match tags against both canonical names and their aliases.
|
||||
- **Frontend:** Surface canonical tags as suggestions, optionally displaying aliases in tooltips or secondary text. Input validation should warn about duplicate aliases within the same model type.
|
||||
|
||||
This format allows users to customize priority tag handling per model type while keeping editing simple and avoiding proliferation of folder names through alias normalization.
|
||||
51
docs/frontend-dom-fixtures.md
Normal file
51
docs/frontend-dom-fixtures.md
Normal file
@@ -0,0 +1,51 @@
|
||||
# Frontend DOM Fixture Strategy
|
||||
|
||||
This guide outlines how to reproduce the markup emitted by the Django templates while running Vitest in jsdom. The aim is to make it straightforward to write integration-style unit tests for managers and UI helpers without having to duplicate template fragments inline.
|
||||
|
||||
## Loading Template Markup
|
||||
|
||||
Vitest executes inside Node, so we can read the same HTML templates that ship with the extension:
|
||||
|
||||
1. Use the helper utilities from `tests/frontend/utils/domFixtures.js` to read files under the `templates/` directory.
|
||||
2. Mount the returned markup into `document.body` (or any custom container) before importing the module under test so its query selectors resolve correctly.
|
||||
|
||||
```js
|
||||
import { renderTemplate } from '../utils/domFixtures.js'; // adjust the relative path to your spec
|
||||
|
||||
beforeEach(() => {
|
||||
renderTemplate('loras.html', {
|
||||
dataset: { page: 'loras' }
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
The helper ensures the dataset is applied to the container, which mirrors how Django sets `data-page` in production.
|
||||
|
||||
## Working with Partial Components
|
||||
|
||||
Many features are implemented as template partials located under `templates/components/`. When a test only needs a fragment (for example, the progress panel or context menu markup), load the component file directly:
|
||||
|
||||
```js
|
||||
const container = renderTemplate('components/progress_panel.html');
|
||||
|
||||
const progressPanel = container.querySelector('#progress-panel');
|
||||
```
|
||||
|
||||
This pattern avoids hand-written fixture strings and keeps the tests aligned with the actual markup.
|
||||
|
||||
## Resetting Between Tests
|
||||
|
||||
The shared Vitest setup clears `document.body` and storage APIs before each test. If a suite adds additional DOM nodes outside of the body or needs to reset custom attributes mid-test, use `resetDom()` exported from `domFixtures.js`.
|
||||
|
||||
```js
|
||||
import { resetDom } from '../utils/domFixtures.js';
|
||||
|
||||
afterEach(() => {
|
||||
resetDom();
|
||||
});
|
||||
```
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- Provide typed helpers for injecting mock script tags (e.g., replicating ComfyUI globals).
|
||||
- Compose higher-level fixtures that mimic specific pages (loras, checkpoints, recipes) once those managers receive dedicated suites.
|
||||
44
docs/frontend-filtering-test-matrix.md
Normal file
44
docs/frontend-filtering-test-matrix.md
Normal file
@@ -0,0 +1,44 @@
|
||||
# LoRA & Checkpoints Filtering/Sorting Test Matrix
|
||||
|
||||
This matrix captures the scenarios that Phase 3 frontend tests should cover for the LoRA and Checkpoint managers. It focuses on how search, filter, sort, and duplicate badge toggles interact so future specs can share fixtures and expectations.
|
||||
|
||||
## Scope
|
||||
|
||||
- **Components**: `PageControls`, `FilterManager`, `SearchManager`, and `ModelDuplicatesManager` wiring invoked through `CheckpointsPageManager` and `LorasPageManager`.
|
||||
- **Templates**: `templates/loras.html` and `templates/checkpoints.html` along with shared filter panel and toolbar partials.
|
||||
- **APIs**: Requests issued through `baseModelApi.fetchModels` (via `resetAndReload`/`refreshModels`) and duplicates badge updates.
|
||||
|
||||
## Shared Setup Considerations
|
||||
|
||||
1. Render full page templates using `renderLorasPage` / `renderCheckpointsPage` helpers before importing modules so DOM queries resolve.
|
||||
2. Stub storage helpers (`getStorageItem`, `setStorageItem`, `getSessionItem`, `setSessionItem`) to observe persistence behavior without mutating real storage.
|
||||
3. Mock `sidebarManager` to capture refresh calls triggered after sort/filter actions.
|
||||
4. Provide fake API implementations exposing `resetAndReload`, `refreshModels`, `fetchFromCivitai`, `toggleBulkMode`, and `clearCustomFilter` so control events remain asynchronous but deterministic.
|
||||
5. Supply a minimal `ModelDuplicatesManager` mock exposing `toggleDuplicateMode`, `checkDuplicatesCount`, and `updateDuplicatesBadgeAfterRefresh` to validate duplicate badge wiring.
|
||||
|
||||
## Scenario Matrix
|
||||
|
||||
| ID | Feature | Scenario | LoRAs Expectations | Checkpoints Expectations | Notes |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| F-01 | Search filter | Typing a query updates `pageState.filters.search`, persists to session, and triggers `resetAndReload` on submit | Validate `SearchManager` writes query and reloads via API stub; confirm LoRA cards pass query downstream | Same as LoRAs | Cover `enter` press and clicking search icon |
|
||||
| F-02 | Tag filter | Selecting a tag chip cycles include ➜ exclude ➜ clear, updates storage, and reloads results | Tag state stored under `filters.tags[tagName] = 'include'|'exclude'`; `FilterManager.applyFilters` persists and triggers `resetAndReload(true)` | Same; ensure base model tag set is scoped to checkpoints dataset | Include removal path |
|
||||
| F-03 | Base model filter | Toggling base model checkboxes updates `filters.baseModel`, persists, and reloads | Ensure only LoRA-supported models show; toggle multi-select | Ensure SDXL/Flux base models appear as expected | Capture UI state restored from storage on next init |
|
||||
| F-04 | Favorites-only | Clicking favorites toggle updates session flag and calls `resetAndReload(true)` | Button gains `.active` class and API called | Same | Verify duplicates badge refresh when active |
|
||||
| F-05 | Sort selection | Changing sort select saves preference (legacy + new format) and reloads | Confirm `PageControls.saveSortPreference` invoked with option and API called | Same with checkpoints-specific defaults | Cover `convertLegacySortFormat` branch |
|
||||
| F-06 | Filter persistence | Re-initializing manager loads stored filters/sort and updates DOM | Filters pre-populate chips/checkboxes; favorites state restored | Same | Requires simulating repeated construction |
|
||||
| F-07 | Combined filters | Applying search + tag + base model yields aggregated query params for fetch | Assert API receives merged filter payload | Same | Validate toast messaging for active filters |
|
||||
| F-08 | Clearing filters | Using "Clear filters" resets state, storage, and reloads list | `FilterManager.clearFilters` empties `filters`, removes active class, shows toast | Same | Ensure favorites-only toggle unaffected |
|
||||
| F-09 | Duplicate badge toggle | Pressing "Find duplicates" toggles duplicate mode and updates badge counts post-refresh | `ModelDuplicatesManager.toggleDuplicateMode` invoked and badge refresh called after API rebuild | Same plus checkpoint-specific duplicate badge dataset | Connects to future duplicate-specific specs |
|
||||
| F-10 | Bulk actions menu | Opening bulk dropdown keeps filters intact and closes on outside click | Validate dropdown class toggling and no unintended reload | Same | Guard against regression when dropdown interacts with filters |
|
||||
|
||||
## Automation Coverage Status
|
||||
|
||||
- ✅ F-01 Search filter, F-02 Tag filter, F-03 Base model filter, F-04 Favorites-only toggle, F-05 Sort selection, and F-09 Duplicate badge toggle are covered by `tests/frontend/components/pageControls.filtering.test.js` for both LoRA and checkpoint pages.
|
||||
- ⏳ F-06 Filter persistence, F-07 Combined filters, F-08 Clearing filters, and F-10 Bulk actions remain to be automated alongside upcoming bulk mode refinements.
|
||||
|
||||
## Coverage Gaps & Follow-Ups
|
||||
|
||||
- Write Vitest suites that exercise the matrix for both managers, sharing fixtures through page helpers to avoid duplication.
|
||||
- Capture API parameter assertions by inspecting `baseModelApi.fetchModels` mocks rather than relying solely on state mutations.
|
||||
- Add regression cases for legacy storage migrations (old filter keys) once fixtures exist for older payloads.
|
||||
- Extend duplicate badge coverage with scenarios where `checkDuplicatesCount` signals zero duplicates versus pending calculations.
|
||||
33
docs/frontend-testing-roadmap.md
Normal file
33
docs/frontend-testing-roadmap.md
Normal file
@@ -0,0 +1,33 @@
|
||||
# Frontend Automation Testing Roadmap
|
||||
|
||||
This roadmap tracks the planned rollout of automated testing for the ComfyUI LoRA Manager frontend. Each phase builds on the infrastructure introduced in this change set and records progress so future contributors can quickly identify the next tasks.
|
||||
|
||||
## Phase Overview
|
||||
|
||||
| Phase | Goal | Primary Focus | Status | Notes |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| Phase 0 | Establish baseline tooling | Add Node test runner, jsdom environment, and seed smoke tests | ✅ Complete | Vitest + jsdom configured, example state tests committed |
|
||||
| Phase 1 | Cover state management logic | Unit test selectors, derived data helpers, and storage utilities under `static/js/state` and `static/js/utils` | ✅ Complete | Storage helpers and state selectors now exercised via deterministic suites |
|
||||
| Phase 2 | Test AppCore orchestration | Simulate page bootstrapping, infinite scroll hooks, and manager registration using JSDOM DOM fixtures | ✅ Complete | AppCore initialization + page feature suites now validate manager wiring, infinite scroll hooks, and onboarding gating |
|
||||
| Phase 3 | Validate page-specific managers | Add focused suites for `loras`, `checkpoints`, `embeddings`, and `recipes` managers covering filtering, sorting, and bulk actions | ✅ Complete | LoRA/checkpoint suites expanded; embeddings + recipes managers now covered with initialization, filtering, and duplicate workflows |
|
||||
| Phase 4 | Interaction-level regression tests | Exercise template fragments, modals, and menus to ensure UI wiring remains intact | ✅ Complete | Vitest DOM suites cover NSFW selector, recipe modal editing, and global context menus |
|
||||
| Phase 5 | Continuous integration & coverage | Integrate frontend tests into CI workflow and track coverage metrics | ✅ Complete | CI workflow runs Vitest and aggregates V8 coverage into `coverage/frontend` via a dedicated script |
|
||||
|
||||
## Next Steps Checklist
|
||||
|
||||
- [x] Expand unit tests for `storageHelpers` covering migrations and namespace behavior.
|
||||
- [x] Document DOM fixture strategy for reproducing template structures in tests.
|
||||
- [x] Prototype AppCore initialization test that verifies manager bootstrapping with stubbed dependencies.
|
||||
- [x] Add AppCore page feature suite exercising context menu creation and infinite scroll registration via DOM fixtures.
|
||||
- [x] Extend AppCore orchestration tests to cover manager wiring, bulk menu setup, and onboarding gating scenarios.
|
||||
- [x] Add interaction regression suites for context menus and recipe modals to complete Phase 4.
|
||||
- [x] Evaluate integrating coverage reporting once test surface grows (> 20 specs).
|
||||
- [x] Create shared fixtures for the loras and checkpoints pages once dedicated manager suites are added.
|
||||
- [x] Draft focused test matrix for loras/checkpoints manager filtering and sorting paths ahead of Phase 3.
|
||||
- [x] Implement LoRAs manager filtering/sorting specs for scenarios F-01–F-05 & F-09; queue remaining edge cases after duplicate/bulk flows stabilize.
|
||||
- [x] Implement checkpoints manager filtering/sorting specs for scenarios F-01–F-05 & F-09; cover remaining paths alongside bulk action work.
|
||||
- [x] Implement checkpoints page manager smoke tests covering initialization and duplicate badge wiring.
|
||||
- [x] Outline focused checkpoints scenarios (filtering, sorting, duplicate badge toggles) to feed into the shared test matrix.
|
||||
- [ ] Add duplicate badge regression coverage for zero/pending states after API refreshes.
|
||||
|
||||
Maintaining this roadmap alongside code changes will make it easier to append new automated test tasks and update their progress.
|
||||
28
docs/library-switching.md
Normal file
28
docs/library-switching.md
Normal file
@@ -0,0 +1,28 @@
|
||||
# Library Switching and Preview Routes
|
||||
|
||||
Library switching no longer requires restarting the backend. The preview
|
||||
thumbnails shown in the UI are now served through a dynamic endpoint that
|
||||
resolves files against the folders registered for the active library at request
|
||||
time. This allows the multi-library flow to update model roots without touching
|
||||
the aiohttp router, so previews remain available immediately after a switch.
|
||||
|
||||
## How the dynamic preview endpoint works
|
||||
|
||||
* `config.get_preview_static_url()` now returns `/api/lm/previews?path=<encoded>`
|
||||
for any preview path. The raw filesystem location is URL encoded so that it
|
||||
can be passed through the query string without leaking directory structure in
|
||||
the route itself.【F:py/config.py†L398-L404】
|
||||
* `PreviewRoutes` exposes the `/api/lm/previews` handler which validates the
|
||||
decoded path against the directories registered for the current library. The
|
||||
request is rejected if it falls outside those roots or if the file does not
|
||||
exist.【F:py/routes/preview_routes.py†L5-L21】【F:py/routes/handlers/preview_handlers.py†L9-L48】
|
||||
* `Config` keeps an up-to-date cache of allowed preview roots. Every time a
|
||||
library is applied the cache is rebuilt using the declared LoRA, checkpoint
|
||||
and embedding directories (including symlink targets). The validation logic
|
||||
checks preview requests against this cache.【F:py/config.py†L51-L68】【F:py/config.py†L180-L248】【F:py/config.py†L332-L346】
|
||||
|
||||
Both the ComfyUI runtime (`LoraManager.add_routes`) and the standalone launcher
|
||||
(`StandaloneLoraManager.add_routes`) register the new preview routes instead of
|
||||
mounting a static directory per root. Switching libraries therefore works
|
||||
without restarting the application, and preview URLs generated before or after a
|
||||
switch continue to resolve correctly.【F:py/lora_manager.py†L21-L82】【F:standalone.py†L302-L315】
|
||||
71
docs/priority_tags_help.md
Normal file
71
docs/priority_tags_help.md
Normal file
@@ -0,0 +1,71 @@
|
||||
# Priority Tags Configuration Guide
|
||||
|
||||
This guide explains how to tailor the tag priority order that powers folder naming and tag suggestions in the LoRA Manager. You only need to edit the comma-separated list of entries shown in the **Priority Tags** field for each model type.
|
||||
|
||||
## 1. Pick the Model Type
|
||||
|
||||
In the **Priority Tags** dialog you will find one tab per model type (LoRA, Checkpoint, Embedding). Select the tab you want to update; changes on one tab do not affect the others.
|
||||
|
||||
## 2. Edit the Entry List
|
||||
|
||||
Inside the textarea you will see a line similar to:
|
||||
|
||||
```
|
||||
character, concept, style(toon|toon_style)
|
||||
```
|
||||
|
||||
This entire line is the **entry list**. Replace it with your own ordered list.
|
||||
|
||||
### Entry Rules
|
||||
|
||||
Each entry is separated by a comma, in order from highest to lowest priority:
|
||||
|
||||
- **Canonical tag only:** `realistic`
|
||||
- **Canonical tag with aliases:** `character(char|chars)`
|
||||
|
||||
Aliases live inside `()` and are separated with `|`. The canonical name is what appears in folder names and UI suggestions when any of the aliases are detected. Matching is case-insensitive.
|
||||
|
||||
## Use `{first_tag}` in Path Templates
|
||||
|
||||
When your path template contains `{first_tag}`, the app picks a folder name based on your priority list and the model’s own tags:
|
||||
|
||||
- It checks the priority list from top to bottom. If a canonical tag or any of its aliases appear in the model tags, that canonical name becomes the folder name.
|
||||
- If no priority tags are found but the model has tags, the very first model tag is used.
|
||||
- If the model has no tags at all, the folder falls back to `no tags`.
|
||||
|
||||
### Example
|
||||
|
||||
With a template like `/{model_type}/{first_tag}` and the priority entry list `character(char|chars), style(anime|toon)`:
|
||||
|
||||
| Model Tags | Folder Name | Why |
|
||||
| --- | --- | --- |
|
||||
| `["chars", "female"]` | `character` | `chars` matches the `character` alias, so the canonical wins. |
|
||||
| `["anime", "portrait"]` | `style` | `anime` hits the `style` entry, so its canonical label is used. |
|
||||
| `["portrait", "bw"]` | `portrait` | No priority match, so the first model tag is used. |
|
||||
| `[]` | `no tags` | Nothing to match, so the fallback is applied. |
|
||||
|
||||
## 3. Save the Settings
|
||||
|
||||
After editing the entry list, press **Enter** to save. Use **Shift+Enter** whenever you need a new line. Clicking outside the field also saves automatically. A success toast confirms the update.
|
||||
|
||||
## Examples
|
||||
|
||||
| Goal | Entry List |
|
||||
| --- | --- |
|
||||
| Prefer people over styles | `character, portraits, style(anime\|toon)` |
|
||||
| Group sci-fi variants | `sci-fi(scifi\|science_fiction), cyberpunk(cyber\|punk)` |
|
||||
| Alias shorthand tags | `realistic(real\|realisim), photorealistic(photo_real)` |
|
||||
|
||||
## Tips
|
||||
|
||||
- Keep canonical names short and meaningful—they become folder names.
|
||||
- Place the most important categories first; the first match wins.
|
||||
- Avoid duplicate canonical names within the same list; only the first instance is used.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- **Unexpected folder name?** Check that the canonical name you want is placed before other matches.
|
||||
- **Alias not working?** Ensure the alias is inside parentheses and separated with `|`, e.g. `character(char|chars)`.
|
||||
- **Validation error?** Look for missing parentheses or stray commas. Each entry must follow the `canonical(alias|alias)` pattern or just `canonical`.
|
||||
|
||||
With these basics you can quickly adapt Priority Tags to match your library’s organization style.
|
||||
26
docs/testing/coverage_analysis.md
Normal file
26
docs/testing/coverage_analysis.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# Backend Test Coverage Notes
|
||||
|
||||
## Pytest Execution
|
||||
- Command: `python -m pytest`
|
||||
- Result: All 283 collected tests passed in the current environment.
|
||||
- Coverage tooling (``pytest-cov``/``coverage``) is unavailable in the offline sandbox, so line-level metrics could not be generated. The earlier attempt to install ``pytest-cov`` failed because the package index cannot be reached from the container.
|
||||
|
||||
## High-Priority Gaps to Address
|
||||
|
||||
### 1. Standalone server bootstrapping
|
||||
* **Source:** [`standalone.py`](../../standalone.py)
|
||||
* **Why it matters:** The standalone entry point wires together the aiohttp application, static asset routes, model-route registration, and configuration validation. None of these behaviours are covered by automated tests, leaving regressions in bootstrapping logic undetected.
|
||||
* **Suggested coverage:** Add integration-style tests that instantiate `StandaloneServer`/`StandaloneLoraManager` with temporary settings and assert that routes (HTTP + websocket) are registered, configuration warnings fire for missing paths, and the mock ComfyUI shims behave as expected.
|
||||
|
||||
### 2. Model service registration factory
|
||||
* **Source:** [`py/services/model_service_factory.py`](../../py/services/model_service_factory.py)
|
||||
* **Why it matters:** The factory coordinates which model services and routes the API exposes, including error handling when unknown model types are requested. No current tests verify registration, memoization of route instances, or the logging path on failures.
|
||||
* **Suggested coverage:** Unit tests that exercise `register_model_type`, `get_route_instance`, error branches in `get_service_class`/`get_route_class`, and `setup_all_routes` when a route setup raises. Use lightweight fakes to confirm the logger is called and state is cleared via `clear_registrations`.
|
||||
|
||||
### 3. Server-side i18n helper
|
||||
* **Source:** [`py/services/server_i18n.py`](../../py/services/server_i18n.py)
|
||||
* **Why it matters:** Template rendering relies on the `ServerI18nManager` to load locale JSON, perform key lookups, and format parameters. The fallback logic (dot-notation lookup, English fallbacks, placeholder substitution) is untested, so malformed locale files or regressions in placeholder handling would slip through.
|
||||
* **Suggested coverage:** Tests that load fixture locale dictionaries, assert `set_locale` fallbacks, verify nested key resolution and placeholder substitution, and ensure missing keys return the original identifier.
|
||||
|
||||
## Next Steps
|
||||
Prioritize creating focused unit tests around these modules, then re-run pytest once coverage tooling is available to confirm the new tests close the identified gaps.
|
||||
BIN
example_workflows/Illustrious Pony Example.jpg
Normal file
BIN
example_workflows/Illustrious Pony Example.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 668 KiB |
1
example_workflows/Illustrious Pony Example.json
Normal file
1
example_workflows/Illustrious Pony Example.json
Normal file
File diff suppressed because one or more lines are too long
BIN
example_workflows/nunchaku-flux.1-dev.jpg
Normal file
BIN
example_workflows/nunchaku-flux.1-dev.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 68 KiB |
1
example_workflows/nunchaku-flux.1-dev.json
Normal file
1
example_workflows/nunchaku-flux.1-dev.json
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1474
locales/de.json
Normal file
1474
locales/de.json
Normal file
File diff suppressed because it is too large
Load Diff
1474
locales/en.json
Normal file
1474
locales/en.json
Normal file
File diff suppressed because it is too large
Load Diff
1474
locales/es.json
Normal file
1474
locales/es.json
Normal file
File diff suppressed because it is too large
Load Diff
1474
locales/fr.json
Normal file
1474
locales/fr.json
Normal file
File diff suppressed because it is too large
Load Diff
1474
locales/he.json
Normal file
1474
locales/he.json
Normal file
File diff suppressed because it is too large
Load Diff
1474
locales/ja.json
Normal file
1474
locales/ja.json
Normal file
File diff suppressed because it is too large
Load Diff
1474
locales/ko.json
Normal file
1474
locales/ko.json
Normal file
File diff suppressed because it is too large
Load Diff
1474
locales/ru.json
Normal file
1474
locales/ru.json
Normal file
File diff suppressed because it is too large
Load Diff
1474
locales/zh-CN.json
Normal file
1474
locales/zh-CN.json
Normal file
File diff suppressed because it is too large
Load Diff
1474
locales/zh-TW.json
Normal file
1474
locales/zh-TW.json
Normal file
File diff suppressed because it is too large
Load Diff
2575
package-lock.json
generated
Normal file
2575
package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
15
package.json
Normal file
15
package.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"name": "comfyui-lora-manager-frontend",
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest",
|
||||
"test:coverage": "node scripts/run_frontend_coverage.js"
|
||||
},
|
||||
"devDependencies": {
|
||||
"jsdom": "^24.0.0",
|
||||
"vitest": "^1.6.0"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
"""Project namespace package."""
|
||||
|
||||
# pytest's internal compatibility layer still imports ``py.path.local`` from the
|
||||
# historical ``py`` dependency. Because this project reuses the ``py`` package
|
||||
# name, we expose a minimal shim so ``py.path.local`` resolves to ``pathlib.Path``
|
||||
# during test runs without pulling in the external dependency.
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
path = SimpleNamespace(local=Path)
|
||||
|
||||
__all__ = ["path"]
|
||||
|
||||
510
py/config.py
510
py/config.py
@@ -1,26 +1,207 @@
|
||||
import os
|
||||
import platform
|
||||
from pathlib import Path
|
||||
import folder_paths # type: ignore
|
||||
from typing import List
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Set
|
||||
import logging
|
||||
import json
|
||||
import urllib.parse
|
||||
|
||||
from .utils.settings_paths import ensure_settings_file, load_settings_template
|
||||
|
||||
# Use an environment variable to control standalone mode
|
||||
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_folder_paths_for_comparison(
|
||||
folder_paths: Mapping[str, Iterable[str]]
|
||||
) -> Dict[str, Set[str]]:
|
||||
"""Normalize folder paths for comparison across libraries."""
|
||||
|
||||
normalized: Dict[str, Set[str]] = {}
|
||||
for key, values in folder_paths.items():
|
||||
if isinstance(values, str):
|
||||
candidate_values: Iterable[str] = [values]
|
||||
else:
|
||||
try:
|
||||
candidate_values = iter(values)
|
||||
except TypeError:
|
||||
continue
|
||||
|
||||
normalized_values: Set[str] = set()
|
||||
for value in candidate_values:
|
||||
if not isinstance(value, str):
|
||||
continue
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
normalized_values.add(os.path.normcase(os.path.normpath(stripped)))
|
||||
|
||||
if normalized_values:
|
||||
normalized[key] = normalized_values
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def _normalize_library_folder_paths(
|
||||
library_payload: Mapping[str, Any]
|
||||
) -> Dict[str, Set[str]]:
|
||||
"""Return normalized folder paths extracted from a library payload."""
|
||||
|
||||
folder_paths = library_payload.get("folder_paths")
|
||||
if isinstance(folder_paths, Mapping):
|
||||
return _normalize_folder_paths_for_comparison(folder_paths)
|
||||
return {}
|
||||
|
||||
|
||||
def _get_template_folder_paths() -> Dict[str, Set[str]]:
|
||||
"""Return normalized folder paths defined in the bundled template."""
|
||||
|
||||
template_payload = load_settings_template()
|
||||
if not template_payload:
|
||||
return {}
|
||||
|
||||
folder_paths = template_payload.get("folder_paths")
|
||||
if isinstance(folder_paths, Mapping):
|
||||
return _normalize_folder_paths_for_comparison(folder_paths)
|
||||
return {}
|
||||
|
||||
|
||||
class Config:
|
||||
"""Global configuration for LoRA Manager"""
|
||||
|
||||
def __init__(self):
|
||||
self.templates_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'templates')
|
||||
self.static_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static')
|
||||
# 路径映射字典, target to link mapping
|
||||
self._path_mappings = {}
|
||||
# 静态路由映射字典, target to route mapping
|
||||
self._route_mappings = {}
|
||||
self.i18n_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'locales')
|
||||
# Path mapping dictionary, target to link mapping
|
||||
self._path_mappings: Dict[str, str] = {}
|
||||
# Normalized preview root directories used to validate preview access
|
||||
self._preview_root_paths: Set[Path] = set()
|
||||
self.loras_roots = self._init_lora_paths()
|
||||
self.checkpoints_roots = self._init_checkpoint_paths()
|
||||
self.temp_directory = folder_paths.get_temp_directory()
|
||||
# 在初始化时扫描符号链接
|
||||
self.checkpoints_roots = None
|
||||
self.unet_roots = None
|
||||
self.embeddings_roots = None
|
||||
self.base_models_roots = self._init_checkpoint_paths()
|
||||
self.embeddings_roots = self._init_embedding_paths()
|
||||
# Scan symbolic links during initialization
|
||||
self._scan_symbolic_links()
|
||||
self._rebuild_preview_roots()
|
||||
|
||||
if not standalone_mode:
|
||||
# Save the paths to settings.json when running in ComfyUI mode
|
||||
self.save_folder_paths_to_settings()
|
||||
|
||||
def save_folder_paths_to_settings(self):
|
||||
"""Persist ComfyUI-derived folder paths to the multi-library settings."""
|
||||
try:
|
||||
ensure_settings_file(logger)
|
||||
from .services.settings_manager import get_settings_manager
|
||||
|
||||
settings_service = get_settings_manager()
|
||||
libraries = settings_service.get_libraries()
|
||||
comfy_library = libraries.get("comfyui", {})
|
||||
default_library = libraries.get("default", {})
|
||||
|
||||
template_folder_paths = _get_template_folder_paths()
|
||||
default_library_paths: Dict[str, Set[str]] = {}
|
||||
if isinstance(default_library, Mapping):
|
||||
default_library_paths = _normalize_library_folder_paths(default_library)
|
||||
|
||||
libraries_changed = False
|
||||
if (
|
||||
isinstance(default_library, Mapping)
|
||||
and template_folder_paths
|
||||
and default_library_paths == template_folder_paths
|
||||
):
|
||||
if "comfyui" in libraries:
|
||||
try:
|
||||
settings_service.delete_library("default")
|
||||
libraries_changed = True
|
||||
logger.info("Removed template 'default' library entry")
|
||||
except Exception as delete_error:
|
||||
logger.debug(
|
||||
"Failed to delete template 'default' library: %s",
|
||||
delete_error,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
settings_service.rename_library("default", "comfyui")
|
||||
libraries_changed = True
|
||||
logger.info("Renamed template 'default' library to 'comfyui'")
|
||||
except Exception as rename_error:
|
||||
logger.debug(
|
||||
"Failed to rename template 'default' library: %s",
|
||||
rename_error,
|
||||
)
|
||||
|
||||
if libraries_changed:
|
||||
libraries = settings_service.get_libraries()
|
||||
comfy_library = libraries.get("comfyui", {})
|
||||
default_library = libraries.get("default", {})
|
||||
|
||||
target_folder_paths = {
|
||||
'loras': list(self.loras_roots),
|
||||
'checkpoints': list(self.checkpoints_roots or []),
|
||||
'unet': list(self.unet_roots or []),
|
||||
'embeddings': list(self.embeddings_roots or []),
|
||||
}
|
||||
|
||||
normalized_target_paths = _normalize_folder_paths_for_comparison(target_folder_paths)
|
||||
|
||||
normalized_default_paths: Optional[Dict[str, Set[str]]] = None
|
||||
if isinstance(default_library, Mapping):
|
||||
normalized_default_paths = _normalize_library_folder_paths(default_library)
|
||||
|
||||
if (
|
||||
not comfy_library
|
||||
and default_library
|
||||
and normalized_target_paths
|
||||
and normalized_default_paths == normalized_target_paths
|
||||
):
|
||||
try:
|
||||
settings_service.rename_library("default", "comfyui")
|
||||
logger.info("Renamed legacy 'default' library to 'comfyui'")
|
||||
libraries = settings_service.get_libraries()
|
||||
comfy_library = libraries.get("comfyui", {})
|
||||
except Exception as rename_error:
|
||||
logger.debug(
|
||||
"Failed to rename legacy 'default' library: %s", rename_error
|
||||
)
|
||||
|
||||
default_lora_root = comfy_library.get("default_lora_root", "")
|
||||
if not default_lora_root and len(self.loras_roots) == 1:
|
||||
default_lora_root = self.loras_roots[0]
|
||||
|
||||
default_checkpoint_root = comfy_library.get("default_checkpoint_root", "")
|
||||
if (not default_checkpoint_root and self.checkpoints_roots and
|
||||
len(self.checkpoints_roots) == 1):
|
||||
default_checkpoint_root = self.checkpoints_roots[0]
|
||||
|
||||
default_embedding_root = comfy_library.get("default_embedding_root", "")
|
||||
if (not default_embedding_root and self.embeddings_roots and
|
||||
len(self.embeddings_roots) == 1):
|
||||
default_embedding_root = self.embeddings_roots[0]
|
||||
|
||||
metadata = dict(comfy_library.get("metadata", {}))
|
||||
metadata.setdefault("display_name", "ComfyUI")
|
||||
metadata["source"] = "comfyui"
|
||||
|
||||
settings_service.upsert_library(
|
||||
"comfyui",
|
||||
folder_paths=target_folder_paths,
|
||||
default_lora_root=default_lora_root,
|
||||
default_checkpoint_root=default_checkpoint_root,
|
||||
default_embedding_root=default_embedding_root,
|
||||
metadata=metadata,
|
||||
activate=True,
|
||||
)
|
||||
|
||||
logger.info("Updated 'comfyui' library with current folder paths")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save folder paths: {e}")
|
||||
|
||||
def _is_link(self, path: str) -> bool:
|
||||
try:
|
||||
@@ -40,15 +221,18 @@ class Config:
|
||||
return False
|
||||
|
||||
def _scan_symbolic_links(self):
|
||||
"""扫描所有 LoRA 和 Checkpoint 根目录中的符号链接"""
|
||||
"""Scan all symbolic links in LoRA, Checkpoint, and Embedding root directories"""
|
||||
for root in self.loras_roots:
|
||||
self._scan_directory_links(root)
|
||||
|
||||
for root in self.checkpoints_roots:
|
||||
for root in self.base_models_roots:
|
||||
self._scan_directory_links(root)
|
||||
|
||||
for root in self.embeddings_roots:
|
||||
self._scan_directory_links(root)
|
||||
|
||||
def _scan_directory_links(self, root: str):
|
||||
"""递归扫描目录中的符号链接"""
|
||||
"""Recursively scan symbolic links in a directory"""
|
||||
try:
|
||||
with os.scandir(root) as it:
|
||||
for entry in it:
|
||||
@@ -63,104 +247,282 @@ class Config:
|
||||
logger.error(f"Error scanning links in {root}: {e}")
|
||||
|
||||
def add_path_mapping(self, link_path: str, target_path: str):
|
||||
"""添加符号链接路径映射
|
||||
target_path: 实际目标路径
|
||||
link_path: 符号链接路径
|
||||
"""Add a symbolic link path mapping
|
||||
target_path: actual target path
|
||||
link_path: symbolic link path
|
||||
"""
|
||||
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
|
||||
normalized_target = os.path.normpath(target_path).replace(os.sep, '/')
|
||||
# 保持原有的映射关系:目标路径 -> 链接路径
|
||||
# Keep the original mapping: target path -> link path
|
||||
self._path_mappings[normalized_target] = normalized_link
|
||||
logger.info(f"Added path mapping: {normalized_target} -> {normalized_link}")
|
||||
self._preview_root_paths.update(self._expand_preview_root(normalized_target))
|
||||
self._preview_root_paths.update(self._expand_preview_root(normalized_link))
|
||||
|
||||
def add_route_mapping(self, path: str, route: str):
|
||||
"""添加静态路由映射"""
|
||||
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
||||
self._route_mappings[normalized_path] = route
|
||||
# logger.info(f"Added route mapping: {normalized_path} -> {route}")
|
||||
def _expand_preview_root(self, path: str) -> Set[Path]:
|
||||
"""Return normalized ``Path`` objects representing a preview root."""
|
||||
|
||||
roots: Set[Path] = set()
|
||||
if not path:
|
||||
return roots
|
||||
|
||||
try:
|
||||
raw_path = Path(path).expanduser()
|
||||
except Exception:
|
||||
return roots
|
||||
|
||||
if raw_path.is_absolute():
|
||||
roots.add(raw_path)
|
||||
|
||||
try:
|
||||
resolved = raw_path.resolve(strict=False)
|
||||
except RuntimeError:
|
||||
resolved = raw_path.absolute()
|
||||
roots.add(resolved)
|
||||
|
||||
try:
|
||||
real_path = raw_path.resolve()
|
||||
except (FileNotFoundError, RuntimeError):
|
||||
real_path = resolved
|
||||
roots.add(real_path)
|
||||
|
||||
normalized: Set[Path] = set()
|
||||
for candidate in roots:
|
||||
if candidate.is_absolute():
|
||||
normalized.add(candidate)
|
||||
else:
|
||||
try:
|
||||
normalized.add(candidate.resolve(strict=False))
|
||||
except RuntimeError:
|
||||
normalized.add(candidate.absolute())
|
||||
|
||||
return normalized
|
||||
|
||||
def _rebuild_preview_roots(self) -> None:
|
||||
"""Recompute the cache of directories permitted for previews."""
|
||||
|
||||
preview_roots: Set[Path] = set()
|
||||
|
||||
for root in self.loras_roots or []:
|
||||
preview_roots.update(self._expand_preview_root(root))
|
||||
for root in self.base_models_roots or []:
|
||||
preview_roots.update(self._expand_preview_root(root))
|
||||
for root in self.embeddings_roots or []:
|
||||
preview_roots.update(self._expand_preview_root(root))
|
||||
|
||||
for target, link in self._path_mappings.items():
|
||||
preview_roots.update(self._expand_preview_root(target))
|
||||
preview_roots.update(self._expand_preview_root(link))
|
||||
|
||||
self._preview_root_paths = {path for path in preview_roots if path.is_absolute()}
|
||||
|
||||
def map_path_to_link(self, path: str) -> str:
|
||||
"""将目标路径映射回符号链接路径"""
|
||||
"""Map a target path back to its symbolic link path"""
|
||||
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
||||
# 检查路径是否包含在任何映射的目标路径中
|
||||
# Check if the path is contained in any mapped target path
|
||||
for target_path, link_path in self._path_mappings.items():
|
||||
if normalized_path.startswith(target_path):
|
||||
# 如果路径以目标路径开头,则替换为链接路径
|
||||
# If the path starts with the target path, replace with link path
|
||||
mapped_path = normalized_path.replace(target_path, link_path, 1)
|
||||
return mapped_path
|
||||
return path
|
||||
|
||||
def map_link_to_path(self, link_path: str) -> str:
|
||||
"""将符号链接路径映射回实际路径"""
|
||||
"""Map a symbolic link path back to the actual path"""
|
||||
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
|
||||
# 检查路径是否包含在任何映射的目标路径中
|
||||
# Check if the path is contained in any mapped target path
|
||||
for target_path, link_path in self._path_mappings.items():
|
||||
if normalized_link.startswith(target_path):
|
||||
# 如果路径以目标路径开头,则替换为实际路径
|
||||
# If the path starts with the target path, replace with actual path
|
||||
mapped_path = normalized_link.replace(target_path, link_path, 1)
|
||||
return mapped_path
|
||||
return link_path
|
||||
|
||||
def _dedupe_existing_paths(self, raw_paths: Iterable[str]) -> Dict[str, str]:
|
||||
dedup: Dict[str, str] = {}
|
||||
for path in raw_paths:
|
||||
if not isinstance(path, str):
|
||||
continue
|
||||
if not os.path.exists(path):
|
||||
continue
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
normalized = os.path.normpath(path).replace(os.sep, '/')
|
||||
if real_path not in dedup:
|
||||
dedup[real_path] = normalized
|
||||
return dedup
|
||||
|
||||
def _prepare_lora_paths(self, raw_paths: Iterable[str]) -> List[str]:
|
||||
path_map = self._dedupe_existing_paths(raw_paths)
|
||||
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
|
||||
|
||||
for original_path in unique_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return unique_paths
|
||||
|
||||
def _prepare_checkpoint_paths(
|
||||
self, checkpoint_paths: Iterable[str], unet_paths: Iterable[str]
|
||||
) -> List[str]:
|
||||
checkpoint_map = self._dedupe_existing_paths(checkpoint_paths)
|
||||
unet_map = self._dedupe_existing_paths(unet_paths)
|
||||
|
||||
merged_map: Dict[str, str] = {}
|
||||
for real_path, original in {**checkpoint_map, **unet_map}.items():
|
||||
if real_path not in merged_map:
|
||||
merged_map[real_path] = original
|
||||
|
||||
unique_paths = sorted(merged_map.values(), key=lambda p: p.lower())
|
||||
|
||||
checkpoint_values = set(checkpoint_map.values())
|
||||
unet_values = set(unet_map.values())
|
||||
self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_values]
|
||||
self.unet_roots = [p for p in unique_paths if p in unet_values]
|
||||
|
||||
for original_path in unique_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return unique_paths
|
||||
|
||||
def _prepare_embedding_paths(self, raw_paths: Iterable[str]) -> List[str]:
|
||||
path_map = self._dedupe_existing_paths(raw_paths)
|
||||
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
|
||||
|
||||
for original_path in unique_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return unique_paths
|
||||
|
||||
def _apply_library_paths(self, folder_paths: Mapping[str, Iterable[str]]) -> None:
|
||||
self._path_mappings.clear()
|
||||
self._preview_root_paths = set()
|
||||
|
||||
lora_paths = folder_paths.get('loras', []) or []
|
||||
checkpoint_paths = folder_paths.get('checkpoints', []) or []
|
||||
unet_paths = folder_paths.get('unet', []) or []
|
||||
embedding_paths = folder_paths.get('embeddings', []) or []
|
||||
|
||||
self.loras_roots = self._prepare_lora_paths(lora_paths)
|
||||
self.base_models_roots = self._prepare_checkpoint_paths(checkpoint_paths, unet_paths)
|
||||
self.embeddings_roots = self._prepare_embedding_paths(embedding_paths)
|
||||
|
||||
self._scan_symbolic_links()
|
||||
self._rebuild_preview_roots()
|
||||
|
||||
def _init_lora_paths(self) -> List[str]:
|
||||
"""Initialize and validate LoRA paths from ComfyUI settings"""
|
||||
paths = sorted(set(path.replace(os.sep, "/")
|
||||
for path in folder_paths.get_folder_paths("loras")
|
||||
if os.path.exists(path)), key=lambda p: p.lower())
|
||||
print("Found LoRA roots:", "\n - " + "\n - ".join(paths))
|
||||
|
||||
if not paths:
|
||||
raise ValueError("No valid loras folders found in ComfyUI configuration")
|
||||
|
||||
# 初始化路径映射
|
||||
for path in paths:
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
if real_path != path:
|
||||
self.add_path_mapping(path, real_path)
|
||||
|
||||
return paths
|
||||
try:
|
||||
raw_paths = folder_paths.get_folder_paths("loras")
|
||||
unique_paths = self._prepare_lora_paths(raw_paths)
|
||||
logger.info("Found LoRA roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
|
||||
|
||||
if not unique_paths:
|
||||
logger.warning("No valid loras folders found in ComfyUI configuration")
|
||||
return []
|
||||
|
||||
return unique_paths
|
||||
except Exception as e:
|
||||
logger.warning(f"Error initializing LoRA paths: {e}")
|
||||
return []
|
||||
|
||||
def _init_checkpoint_paths(self) -> List[str]:
|
||||
"""Initialize and validate checkpoint paths from ComfyUI settings"""
|
||||
# Get checkpoint paths from folder_paths
|
||||
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
diffusion_paths = folder_paths.get_folder_paths("diffusers")
|
||||
unet_paths = folder_paths.get_folder_paths("unet")
|
||||
|
||||
# Combine all checkpoint-related paths
|
||||
all_paths = checkpoint_paths + diffusion_paths + unet_paths
|
||||
|
||||
# Filter and normalize paths
|
||||
paths = sorted(set(path.replace(os.sep, "/")
|
||||
for path in all_paths
|
||||
if os.path.exists(path)), key=lambda p: p.lower())
|
||||
|
||||
print("Found checkpoint roots:", paths)
|
||||
|
||||
if not paths:
|
||||
logger.warning("No valid checkpoint folders found in ComfyUI configuration")
|
||||
try:
|
||||
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
raw_unet_paths = folder_paths.get_folder_paths("unet")
|
||||
unique_paths = self._prepare_checkpoint_paths(raw_checkpoint_paths, raw_unet_paths)
|
||||
|
||||
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
|
||||
|
||||
if not unique_paths:
|
||||
logger.warning("No valid checkpoint folders found in ComfyUI configuration")
|
||||
return []
|
||||
|
||||
return unique_paths
|
||||
except Exception as e:
|
||||
logger.warning(f"Error initializing checkpoint paths: {e}")
|
||||
return []
|
||||
|
||||
def _init_embedding_paths(self) -> List[str]:
|
||||
"""Initialize and validate embedding paths from ComfyUI settings"""
|
||||
try:
|
||||
raw_paths = folder_paths.get_folder_paths("embeddings")
|
||||
unique_paths = self._prepare_embedding_paths(raw_paths)
|
||||
logger.info("Found embedding roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
|
||||
|
||||
if not unique_paths:
|
||||
logger.warning("No valid embeddings folders found in ComfyUI configuration")
|
||||
return []
|
||||
|
||||
return unique_paths
|
||||
except Exception as e:
|
||||
logger.warning(f"Error initializing embedding paths: {e}")
|
||||
return []
|
||||
|
||||
# 初始化路径映射,与 LoRA 路径处理方式相同
|
||||
for path in paths:
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
if real_path != path:
|
||||
self.add_path_mapping(path, real_path)
|
||||
|
||||
return paths
|
||||
|
||||
def get_preview_static_url(self, preview_path: str) -> str:
|
||||
"""Convert local preview path to static URL"""
|
||||
if not preview_path:
|
||||
return ""
|
||||
|
||||
real_path = os.path.realpath(preview_path).replace(os.sep, '/')
|
||||
|
||||
for path, route in self._route_mappings.items():
|
||||
if real_path.startswith(path):
|
||||
relative_path = os.path.relpath(real_path, path)
|
||||
return f'{route}/{relative_path.replace(os.sep, "/")}'
|
||||
normalized = os.path.normpath(preview_path).replace(os.sep, '/')
|
||||
encoded_path = urllib.parse.quote(normalized, safe='')
|
||||
return f'/api/lm/previews?path={encoded_path}'
|
||||
|
||||
return ""
|
||||
def is_preview_path_allowed(self, preview_path: str) -> bool:
|
||||
"""Return ``True`` if ``preview_path`` is within an allowed directory."""
|
||||
|
||||
if not preview_path:
|
||||
return False
|
||||
|
||||
try:
|
||||
candidate = Path(preview_path).expanduser().resolve(strict=False)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
for root in self._preview_root_paths:
|
||||
try:
|
||||
candidate.relative_to(root)
|
||||
return True
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def apply_library_settings(self, library_config: Mapping[str, object]) -> None:
|
||||
"""Update runtime paths to match the provided library configuration."""
|
||||
folder_paths = library_config.get('folder_paths') if isinstance(library_config, Mapping) else {}
|
||||
if not isinstance(folder_paths, Mapping):
|
||||
folder_paths = {}
|
||||
|
||||
self._apply_library_paths(folder_paths)
|
||||
|
||||
logger.info(
|
||||
"Applied library settings with %d lora roots, %d checkpoint roots, and %d embedding roots",
|
||||
len(self.loras_roots or []),
|
||||
len(self.base_models_roots or []),
|
||||
len(self.embeddings_roots or []),
|
||||
)
|
||||
|
||||
def get_library_registry_snapshot(self) -> Dict[str, object]:
|
||||
"""Return the current library registry and active library name."""
|
||||
|
||||
try:
|
||||
from .services.settings_manager import get_settings_manager
|
||||
|
||||
settings_service = get_settings_manager()
|
||||
libraries = settings_service.get_libraries()
|
||||
active_library = settings_service.get_active_library_name()
|
||||
return {
|
||||
"active_library": active_library,
|
||||
"libraries": libraries,
|
||||
}
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.debug("Failed to collect library registry snapshot: %s", exc)
|
||||
return {"active_library": "", "libraries": {}}
|
||||
|
||||
# Global config instance
|
||||
config = Config()
|
||||
|
||||
@@ -1,172 +1,354 @@
|
||||
import asyncio
|
||||
from server import PromptServer # type: ignore
|
||||
from .config import config
|
||||
from .routes.lora_routes import LoraRoutes
|
||||
from .routes.api_routes import ApiRoutes
|
||||
from .routes.recipe_routes import RecipeRoutes
|
||||
from .routes.checkpoints_routes import CheckpointsRoutes
|
||||
from .services.service_registry import ServiceRegistry
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
from server import PromptServer # type: ignore
|
||||
|
||||
from .config import config
|
||||
from .services.model_service_factory import ModelServiceFactory, register_default_model_types
|
||||
from .routes.recipe_routes import RecipeRoutes
|
||||
from .routes.stats_routes import StatsRoutes
|
||||
from .routes.update_routes import UpdateRoutes
|
||||
from .routes.misc_routes import MiscRoutes
|
||||
from .routes.preview_routes import PreviewRoutes
|
||||
from .routes.example_images_routes import ExampleImagesRoutes
|
||||
from .services.service_registry import ServiceRegistry
|
||||
from .services.settings_manager import get_settings_manager
|
||||
from .utils.example_images_migration import ExampleImagesMigration
|
||||
from .services.websocket_manager import ws_manager
|
||||
from .services.example_images_cleanup_service import ExampleImagesCleanupService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Check if we're in standalone mode
|
||||
STANDALONE_MODE = 'nodes' not in sys.modules
|
||||
|
||||
HEADER_SIZE_LIMIT = 16384
|
||||
|
||||
|
||||
def _sanitize_size_limit(value):
|
||||
"""Return a non-negative integer size for ``handler_args`` comparisons."""
|
||||
|
||||
try:
|
||||
coerced = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
return coerced if coerced >= 0 else 0
|
||||
|
||||
|
||||
class _SettingsProxy:
|
||||
def __init__(self):
|
||||
self._manager = None
|
||||
|
||||
def _resolve(self):
|
||||
if self._manager is None:
|
||||
self._manager = get_settings_manager()
|
||||
return self._manager
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
return self._resolve().get(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self._resolve(), item)
|
||||
|
||||
|
||||
settings = _SettingsProxy()
|
||||
|
||||
class LoraManager:
|
||||
"""Main entry point for LoRA Manager plugin"""
|
||||
|
||||
@classmethod
|
||||
def add_routes(cls):
|
||||
"""Initialize and register all routes"""
|
||||
"""Initialize and register all routes using the new refactored architecture"""
|
||||
app = PromptServer.instance.app
|
||||
|
||||
added_targets = set() # Track already added target paths
|
||||
|
||||
# Add static routes for each lora root
|
||||
for idx, root in enumerate(config.loras_roots, start=1):
|
||||
preview_path = f'/loras_static/root{idx}/preview'
|
||||
|
||||
real_root = root
|
||||
if root in config._path_mappings.values():
|
||||
for target, link in config._path_mappings.items():
|
||||
if link == root:
|
||||
real_root = target
|
||||
break
|
||||
# Add static route for original path
|
||||
app.router.add_static(preview_path, real_root)
|
||||
logger.info(f"Added static route {preview_path} -> {real_root}")
|
||||
|
||||
# Record route mapping
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Add static routes for each checkpoint root
|
||||
for idx, root in enumerate(config.checkpoints_roots, start=1):
|
||||
preview_path = f'/checkpoints_static/root{idx}/preview'
|
||||
|
||||
real_root = root
|
||||
if root in config._path_mappings.values():
|
||||
for target, link in config._path_mappings.items():
|
||||
if link == root:
|
||||
real_root = target
|
||||
break
|
||||
# Add static route for original path
|
||||
app.router.add_static(preview_path, real_root)
|
||||
logger.info(f"Added static route {preview_path} -> {real_root}")
|
||||
|
||||
# Record route mapping
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Add static routes for symlink target paths
|
||||
link_idx = {
|
||||
'lora': 1,
|
||||
'checkpoint': 1
|
||||
}
|
||||
|
||||
for target_path, link_path in config._path_mappings.items():
|
||||
if target_path not in added_targets:
|
||||
# Determine if this is a checkpoint or lora link based on path
|
||||
is_checkpoint = any(cp_root in link_path for cp_root in config.checkpoints_roots)
|
||||
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.checkpoints_roots)
|
||||
|
||||
if is_checkpoint:
|
||||
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
|
||||
link_idx["checkpoint"] += 1
|
||||
else:
|
||||
route_path = f'/loras_static/link_{link_idx["lora"]}/preview'
|
||||
link_idx["lora"] += 1
|
||||
|
||||
app.router.add_static(route_path, target_path)
|
||||
logger.info(f"Added static route for link target {route_path} -> {target_path}")
|
||||
config.add_route_mapping(target_path, route_path)
|
||||
added_targets.add(target_path)
|
||||
|
||||
# Increase allowed header sizes so browsers with large localhost cookie
|
||||
# jars (multiple UIs on 127.0.0.1) don't trip aiohttp's 8KB default
|
||||
# limits. Cookies for unrelated apps are still sent to the plugin and
|
||||
# may otherwise raise LineTooLong errors when the request parser reads
|
||||
# them. Preserve any previously configured handler arguments while
|
||||
# ensuring our minimum sizes are applied.
|
||||
handler_args = getattr(app, "_handler_args", {}) or {}
|
||||
updated_handler_args = dict(handler_args)
|
||||
updated_handler_args["max_field_size"] = max(
|
||||
_sanitize_size_limit(handler_args.get("max_field_size", 0)),
|
||||
HEADER_SIZE_LIMIT,
|
||||
)
|
||||
updated_handler_args["max_line_size"] = max(
|
||||
_sanitize_size_limit(handler_args.get("max_line_size", 0)),
|
||||
HEADER_SIZE_LIMIT,
|
||||
)
|
||||
app._handler_args = updated_handler_args
|
||||
|
||||
# Configure aiohttp access logger to be less verbose
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# Add specific suppression for connection reset errors
|
||||
class ConnectionResetFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
# Filter out connection reset errors that are not critical
|
||||
if "ConnectionResetError" in str(record.getMessage()):
|
||||
return False
|
||||
if "_call_connection_lost" in str(record.getMessage()):
|
||||
return False
|
||||
if "WinError 10054" in str(record.getMessage()):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Apply the filter to asyncio logger
|
||||
asyncio_logger = logging.getLogger("asyncio")
|
||||
asyncio_logger.addFilter(ConnectionResetFilter())
|
||||
|
||||
# Add static route for example images if the path exists in settings
|
||||
example_images_path = settings.get('example_images_path')
|
||||
logger.info(f"Example images path: {example_images_path}")
|
||||
if example_images_path and os.path.exists(example_images_path):
|
||||
app.router.add_static('/example_images_static', example_images_path)
|
||||
logger.info(f"Added static route for example images: /example_images_static -> {example_images_path}")
|
||||
|
||||
# Add static route for locales JSON files
|
||||
if os.path.exists(config.i18n_path):
|
||||
app.router.add_static('/locales', config.i18n_path)
|
||||
logger.info(f"Added static route for locales: /locales -> {config.i18n_path}")
|
||||
|
||||
# Add static route for plugin assets
|
||||
app.router.add_static('/loras_static', config.static_path)
|
||||
|
||||
# Setup feature routes
|
||||
lora_routes = LoraRoutes()
|
||||
checkpoints_routes = CheckpointsRoutes()
|
||||
# Register default model types with the factory
|
||||
register_default_model_types()
|
||||
|
||||
# Initialize routes
|
||||
lora_routes.setup_routes(app)
|
||||
checkpoints_routes.setup_routes(app)
|
||||
ApiRoutes.setup_routes(app)
|
||||
# Setup all model routes using the factory
|
||||
ModelServiceFactory.setup_all_routes(app)
|
||||
|
||||
# Setup non-model-specific routes
|
||||
stats_routes = StatsRoutes()
|
||||
stats_routes.setup_routes(app)
|
||||
RecipeRoutes.setup_routes(app)
|
||||
UpdateRoutes.setup_routes(app)
|
||||
MiscRoutes.setup_routes(app)
|
||||
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
|
||||
PreviewRoutes.setup_routes(app)
|
||||
|
||||
# Setup WebSocket routes that are shared across all model types
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
|
||||
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
|
||||
|
||||
# Schedule service initialization
|
||||
app.on_startup.append(lambda app: cls._initialize_services())
|
||||
|
||||
# Add cleanup
|
||||
app.on_shutdown.append(cls._cleanup)
|
||||
app.on_shutdown.append(ApiRoutes.cleanup)
|
||||
|
||||
logger.info(f"LoRA Manager: Set up routes for {len(ModelServiceFactory.get_registered_types())} model types: {', '.join(ModelServiceFactory.get_registered_types())}")
|
||||
|
||||
@classmethod
|
||||
async def _initialize_services(cls):
|
||||
"""Initialize all services using the ServiceRegistry"""
|
||||
try:
|
||||
logger.info("LoRA Manager: Initializing services via ServiceRegistry")
|
||||
|
||||
# Initialize CivitaiClient first to ensure it's ready for other services
|
||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Get file monitors through ServiceRegistry
|
||||
lora_monitor = await ServiceRegistry.get_lora_monitor()
|
||||
checkpoint_monitor = await ServiceRegistry.get_checkpoint_monitor()
|
||||
|
||||
# Start monitors
|
||||
lora_monitor.start()
|
||||
logger.info("Lora monitor started")
|
||||
|
||||
# Make sure checkpoint monitor has paths before starting
|
||||
await checkpoint_monitor.initialize_paths()
|
||||
checkpoint_monitor.start()
|
||||
logger.info("Checkpoint monitor started")
|
||||
await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Register DownloadManager with ServiceRegistry
|
||||
download_manager = await ServiceRegistry.get_download_manager()
|
||||
await ServiceRegistry.get_download_manager()
|
||||
|
||||
from .services.metadata_service import initialize_metadata_providers
|
||||
await initialize_metadata_providers()
|
||||
|
||||
# Initialize WebSocket manager
|
||||
ws_manager = await ServiceRegistry.get_websocket_manager()
|
||||
await ServiceRegistry.get_websocket_manager()
|
||||
|
||||
# Initialize scanners in background
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
# Initialize recipe scanner if needed
|
||||
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
|
||||
# Create low-priority initialization tasks
|
||||
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
|
||||
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
|
||||
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
|
||||
init_tasks = [
|
||||
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init'),
|
||||
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init'),
|
||||
asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init'),
|
||||
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
|
||||
]
|
||||
|
||||
await ExampleImagesMigration.check_and_run_migrations()
|
||||
|
||||
logger.info("LoRA Manager: All services initialized and background tasks scheduled")
|
||||
# Schedule post-initialization tasks to run after scanners complete
|
||||
asyncio.create_task(
|
||||
cls._run_post_initialization_tasks(init_tasks),
|
||||
name='post_init_tasks'
|
||||
)
|
||||
|
||||
logger.debug("LoRA Manager: All services initialized and background tasks scheduled")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
async def _run_post_initialization_tasks(cls, init_tasks):
|
||||
"""Run post-initialization tasks after all scanners complete"""
|
||||
try:
|
||||
logger.debug("LoRA Manager: Waiting for scanner initialization to complete...")
|
||||
|
||||
# Wait for all scanner initialization tasks to complete
|
||||
await asyncio.gather(*init_tasks, return_exceptions=True)
|
||||
|
||||
logger.debug("LoRA Manager: Scanner initialization completed, starting post-initialization tasks...")
|
||||
|
||||
# Run post-initialization tasks
|
||||
post_tasks = [
|
||||
asyncio.create_task(cls._cleanup_backup_files(), name='cleanup_bak_files'),
|
||||
# Add more post-initialization tasks here as needed
|
||||
# asyncio.create_task(cls._another_post_task(), name='another_task'),
|
||||
]
|
||||
|
||||
# Run all post-initialization tasks
|
||||
results = await asyncio.gather(*post_tasks, return_exceptions=True)
|
||||
|
||||
# Log results
|
||||
for i, result in enumerate(results):
|
||||
task_name = post_tasks[i].get_name()
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Post-initialization task '{task_name}' failed: {result}")
|
||||
else:
|
||||
logger.debug(f"Post-initialization task '{task_name}' completed successfully")
|
||||
|
||||
logger.debug("LoRA Manager: All post-initialization tasks completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LoRA Manager: Error in post-initialization tasks: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
async def _cleanup_backup_files(cls):
|
||||
"""Clean up .bak files in all model roots"""
|
||||
try:
|
||||
logger.debug("Starting cleanup of .bak files in model directories...")
|
||||
|
||||
# Collect all model roots
|
||||
all_roots = set()
|
||||
all_roots.update(config.loras_roots)
|
||||
all_roots.update(config.base_models_roots)
|
||||
all_roots.update(config.embeddings_roots)
|
||||
|
||||
total_deleted = 0
|
||||
total_size_freed = 0
|
||||
|
||||
for root_path in all_roots:
|
||||
if not os.path.exists(root_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
deleted_count, size_freed = await cls._cleanup_backup_files_in_directory(root_path)
|
||||
total_deleted += deleted_count
|
||||
total_size_freed += size_freed
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"Cleaned up {deleted_count} .bak files in {root_path} (freed {size_freed / (1024*1024):.2f} MB)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up .bak files in {root_path}: {e}")
|
||||
|
||||
# Yield control periodically
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
if total_deleted > 0:
|
||||
logger.debug(f"Backup cleanup completed: removed {total_deleted} .bak files, freed {total_size_freed / (1024*1024):.2f} MB total")
|
||||
else:
|
||||
logger.debug("Backup cleanup completed: no .bak files found")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during backup file cleanup: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
async def _cleanup_backup_files_in_directory(cls, directory_path: str):
|
||||
"""Clean up .bak files in a specific directory recursively
|
||||
|
||||
Args:
|
||||
directory_path: Path to the directory to clean
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: (number of files deleted, total size freed in bytes)
|
||||
"""
|
||||
deleted_count = 0
|
||||
size_freed = 0
|
||||
visited_paths = set()
|
||||
|
||||
def cleanup_recursive(path):
|
||||
nonlocal deleted_count, size_freed
|
||||
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
for entry in it:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True) and entry.name.endswith('.bak'):
|
||||
file_size = entry.stat().st_size
|
||||
os.remove(entry.path)
|
||||
deleted_count += 1
|
||||
size_freed += file_size
|
||||
logger.debug(f"Deleted .bak file: {entry.path}")
|
||||
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
cleanup_recursive(entry.path)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not delete .bak file {entry.path}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning directory {path} for .bak files: {e}")
|
||||
|
||||
# Run the recursive cleanup in a thread pool to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, cleanup_recursive, directory_path)
|
||||
|
||||
return deleted_count, size_freed
|
||||
|
||||
@classmethod
|
||||
async def _cleanup_example_images_folders(cls):
|
||||
"""Invoke the example images cleanup service for manual execution."""
|
||||
try:
|
||||
service = ExampleImagesCleanupService()
|
||||
result = await service.cleanup_example_image_folders()
|
||||
|
||||
if result.get('success'):
|
||||
logger.debug(
|
||||
"Manual example images cleanup completed: moved=%s",
|
||||
result.get('moved_total'),
|
||||
)
|
||||
elif result.get('partial_success'):
|
||||
logger.warning(
|
||||
"Manual example images cleanup partially succeeded: moved=%s failures=%s",
|
||||
result.get('moved_total'),
|
||||
result.get('move_failures'),
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Manual example images cleanup skipped or failed: %s",
|
||||
result.get('error', 'no changes'),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e: # pragma: no cover - defensive guard
|
||||
logger.error(f"Error during example images cleanup: {e}", exc_info=True)
|
||||
return {
|
||||
'success': False,
|
||||
'error': str(e),
|
||||
'error_code': 'unexpected_error',
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def _cleanup(cls, app):
|
||||
"""Cleanup resources using ServiceRegistry"""
|
||||
try:
|
||||
logger.info("LoRA Manager: Cleaning up services")
|
||||
|
||||
# Get monitors from ServiceRegistry
|
||||
lora_monitor = await ServiceRegistry.get_service("lora_monitor")
|
||||
if lora_monitor:
|
||||
lora_monitor.stop()
|
||||
logger.info("Stopped LoRA monitor")
|
||||
|
||||
checkpoint_monitor = await ServiceRegistry.get_service("checkpoint_monitor")
|
||||
if checkpoint_monitor:
|
||||
checkpoint_monitor.stop()
|
||||
logger.info("Stopped checkpoint monitor")
|
||||
|
||||
# Close CivitaiClient gracefully
|
||||
civitai_client = await ServiceRegistry.get_service("civitai_client")
|
||||
if civitai_client:
|
||||
await civitai_client.close()
|
||||
logger.info("Closed CivitaiClient connection")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during cleanup: {e}", exc_info=True)
|
||||
|
||||
30
py/metadata_collector/__init__.py
Normal file
30
py/metadata_collector/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import os
|
||||
|
||||
# Check if running in standalone mode
|
||||
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
|
||||
if not standalone_mode:
|
||||
from .metadata_hook import MetadataHook
|
||||
from .metadata_registry import MetadataRegistry
|
||||
|
||||
def init():
|
||||
# Install hooks to collect metadata during execution
|
||||
MetadataHook.install()
|
||||
|
||||
# Initialize registry
|
||||
registry = MetadataRegistry()
|
||||
|
||||
print("ComfyUI Metadata Collector initialized")
|
||||
|
||||
def get_metadata(prompt_id=None):
|
||||
"""Helper function to get metadata from the registry"""
|
||||
registry = MetadataRegistry()
|
||||
return registry.get_metadata(prompt_id)
|
||||
else:
|
||||
# Standalone mode - provide dummy implementations
|
||||
def init():
|
||||
print("ComfyUI Metadata Collector disabled in standalone mode")
|
||||
|
||||
def get_metadata(prompt_id=None):
|
||||
"""Dummy implementation for standalone mode"""
|
||||
return {}
|
||||
13
py/metadata_collector/constants.py
Normal file
13
py/metadata_collector/constants.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Constants used by the metadata collector"""
|
||||
|
||||
# Metadata categories
|
||||
MODELS = "models"
|
||||
PROMPTS = "prompts"
|
||||
SAMPLING = "sampling"
|
||||
LORAS = "loras"
|
||||
SIZE = "size"
|
||||
IMAGES = "images"
|
||||
IS_SAMPLER = "is_sampler" # New constant to mark sampler nodes
|
||||
|
||||
# Complete list of categories to track
|
||||
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES]
|
||||
204
py/metadata_collector/metadata_hook.py
Normal file
204
py/metadata_collector/metadata_hook.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import sys
|
||||
import inspect
|
||||
from .metadata_registry import MetadataRegistry
|
||||
|
||||
class MetadataHook:
|
||||
"""Install hooks for metadata collection"""
|
||||
|
||||
@staticmethod
|
||||
def install():
|
||||
"""Install hooks to collect metadata during execution"""
|
||||
try:
|
||||
# Import ComfyUI's execution module
|
||||
execution = None
|
||||
try:
|
||||
# Try direct import first
|
||||
import execution # type: ignore
|
||||
except ImportError:
|
||||
# Try to locate from system modules
|
||||
for module_name in sys.modules:
|
||||
if module_name.endswith('.execution'):
|
||||
execution = sys.modules[module_name]
|
||||
break
|
||||
|
||||
# If we can't find the execution module, we can't install hooks
|
||||
if execution is None:
|
||||
print("Could not locate ComfyUI execution module, metadata collection disabled")
|
||||
return
|
||||
|
||||
# Detect whether we're using the new async version of ComfyUI
|
||||
is_async = False
|
||||
map_node_func_name = '_map_node_over_list'
|
||||
|
||||
if hasattr(execution, '_async_map_node_over_list'):
|
||||
is_async = inspect.iscoroutinefunction(execution._async_map_node_over_list)
|
||||
map_node_func_name = '_async_map_node_over_list'
|
||||
elif hasattr(execution, '_map_node_over_list'):
|
||||
is_async = inspect.iscoroutinefunction(execution._map_node_over_list)
|
||||
|
||||
if is_async:
|
||||
print("Detected async ComfyUI execution, installing async metadata hooks")
|
||||
MetadataHook._install_async_hooks(execution, map_node_func_name)
|
||||
else:
|
||||
print("Detected sync ComfyUI execution, installing sync metadata hooks")
|
||||
MetadataHook._install_sync_hooks(execution)
|
||||
|
||||
print("Metadata collection hooks installed for runtime values")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error installing metadata hooks: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _install_sync_hooks(execution):
|
||||
"""Install hooks for synchronous execution model"""
|
||||
# Store the original _map_node_over_list function
|
||||
original_map_node_over_list = execution._map_node_over_list
|
||||
|
||||
# Define the wrapped _map_node_over_list function
|
||||
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
# Only collect metadata when calling the main function of nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
prompt_id = registry.current_prompt_id
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Unique ID might be available through the obj if it has a unique_id field
|
||||
node_id = getattr(obj, 'unique_id', None)
|
||||
if node_id is None and pre_execute_cb:
|
||||
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if 'unique_id' in frame.f_locals:
|
||||
node_id = frame.f_locals['unique_id']
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
# Record inputs before execution
|
||||
if node_id is not None:
|
||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
|
||||
# Execute the original function
|
||||
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
|
||||
|
||||
# After execution, collect outputs for relevant nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
prompt_id = registry.current_prompt_id
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Unique ID might be available through the obj if it has a unique_id field
|
||||
node_id = getattr(obj, 'unique_id', None)
|
||||
if node_id is None and pre_execute_cb:
|
||||
# Try to extract node_id through reflection
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if 'unique_id' in frame.f_locals:
|
||||
node_id = frame.f_locals['unique_id']
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
# Record outputs after execution
|
||||
if node_id is not None:
|
||||
registry.update_node_execution(node_id, class_type, results)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
# Also hook the execute function to track the current prompt_id
|
||||
original_execute = execution.execute
|
||||
|
||||
def execute_with_prompt_tracking(*args, **kwargs):
|
||||
if len(args) >= 7: # Check if we have enough arguments
|
||||
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||
registry = MetadataRegistry()
|
||||
|
||||
# Start collection if this is a new prompt
|
||||
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||
registry.start_collection(prompt_id)
|
||||
|
||||
# Store the dynprompt reference for node lookups
|
||||
if hasattr(prompt, 'original_prompt'):
|
||||
registry.set_current_prompt(prompt)
|
||||
|
||||
# Execute the original function
|
||||
return original_execute(*args, **kwargs)
|
||||
|
||||
# Replace the functions
|
||||
execution._map_node_over_list = map_node_over_list_with_metadata
|
||||
execution.execute = execute_with_prompt_tracking
|
||||
|
||||
@staticmethod
|
||||
def _install_async_hooks(execution, map_node_func_name='_async_map_node_over_list'):
|
||||
"""Install hooks for asynchronous execution model"""
|
||||
# Store the original _async_map_node_over_list function
|
||||
original_map_node_over_list = getattr(execution, map_node_func_name)
|
||||
|
||||
# Wrapped async function, compatible with both stable and nightly
|
||||
async def async_map_node_over_list_with_metadata(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, *args, **kwargs):
|
||||
hidden_inputs = kwargs.get('hidden_inputs', None)
|
||||
# Only collect metadata when calling the main function of nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
registry = MetadataRegistry()
|
||||
if prompt_id is not None:
|
||||
class_type = obj.__class__.__name__
|
||||
node_id = unique_id
|
||||
if node_id is not None:
|
||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
|
||||
# Call original function with all args/kwargs
|
||||
results = await original_map_node_over_list(
|
||||
prompt_id, unique_id, obj, input_data_all, func,
|
||||
allow_interrupt, execution_block_cb, pre_execute_cb, *args, **kwargs
|
||||
)
|
||||
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
registry = MetadataRegistry()
|
||||
if prompt_id is not None:
|
||||
class_type = obj.__class__.__name__
|
||||
node_id = unique_id
|
||||
if node_id is not None:
|
||||
registry.update_node_execution(node_id, class_type, results)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
# Also hook the execute function to track the current prompt_id
|
||||
original_execute = execution.execute
|
||||
|
||||
async def async_execute_with_prompt_tracking(*args, **kwargs):
|
||||
if len(args) >= 7: # Check if we have enough arguments
|
||||
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||
registry = MetadataRegistry()
|
||||
|
||||
# Start collection if this is a new prompt
|
||||
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||
registry.start_collection(prompt_id)
|
||||
|
||||
# Store the dynprompt reference for node lookups
|
||||
if hasattr(prompt, 'original_prompt'):
|
||||
registry.set_current_prompt(prompt)
|
||||
|
||||
# Execute the original function
|
||||
return await original_execute(*args, **kwargs)
|
||||
|
||||
# Replace the functions with async versions
|
||||
setattr(execution, map_node_func_name, async_map_node_over_list_with_metadata)
|
||||
execution.execute = async_execute_with_prompt_tracking
|
||||
479
py/metadata_collector/metadata_processor.py
Normal file
479
py/metadata_collector/metadata_processor.py
Normal file
@@ -0,0 +1,479 @@
|
||||
import json
|
||||
import os
|
||||
from .constants import IMAGES
|
||||
|
||||
# Check if running in standalone mode
|
||||
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
|
||||
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IS_SAMPLER
|
||||
|
||||
class MetadataProcessor:
|
||||
"""Process and format collected metadata"""
|
||||
|
||||
@staticmethod
|
||||
def find_primary_sampler(metadata, downstream_id=None):
|
||||
"""
|
||||
Find the primary KSampler node that executed before the given downstream node
|
||||
|
||||
Parameters:
|
||||
- metadata: The workflow metadata
|
||||
- downstream_id: Optional ID of a downstream node to help identify the specific primary sampler
|
||||
"""
|
||||
if downstream_id is None:
|
||||
if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
|
||||
downstream_id = metadata[IMAGES]["first_decode"]["node_id"]
|
||||
|
||||
# If we have a downstream_id and execution_order, use it to narrow down potential samplers
|
||||
if downstream_id and "execution_order" in metadata:
|
||||
execution_order = metadata["execution_order"]
|
||||
|
||||
# Find the index of the downstream node in the execution order
|
||||
if downstream_id in execution_order:
|
||||
downstream_index = execution_order.index(downstream_id)
|
||||
|
||||
# Extract all sampler nodes that executed before the downstream node
|
||||
candidate_samplers = {}
|
||||
for i in range(downstream_index):
|
||||
node_id = execution_order[i]
|
||||
# Use IS_SAMPLER flag to identify true sampler nodes
|
||||
if node_id in metadata.get(SAMPLING, {}) and metadata[SAMPLING][node_id].get(IS_SAMPLER, False):
|
||||
candidate_samplers[node_id] = metadata[SAMPLING][node_id]
|
||||
|
||||
# If we found candidate samplers, apply primary sampler logic to these candidates only
|
||||
if candidate_samplers:
|
||||
# Collect potential primary samplers based on different criteria
|
||||
custom_advanced_samplers = []
|
||||
advanced_add_noise_samplers = []
|
||||
high_denoise_samplers = []
|
||||
max_denoise = -1
|
||||
high_denoise_id = None
|
||||
|
||||
# First, check for SamplerCustomAdvanced among candidates
|
||||
prompt = metadata.get("current_prompt")
|
||||
if prompt and prompt.original_prompt:
|
||||
for node_id in candidate_samplers:
|
||||
node_info = prompt.original_prompt.get(node_id, {})
|
||||
if node_info.get("class_type") == "SamplerCustomAdvanced":
|
||||
custom_advanced_samplers.append(node_id)
|
||||
|
||||
# Next, check for KSamplerAdvanced with add_noise="enable" among candidates
|
||||
for node_id, sampler_info in candidate_samplers.items():
|
||||
parameters = sampler_info.get("parameters", {})
|
||||
add_noise = parameters.get("add_noise")
|
||||
if add_noise == "enable":
|
||||
advanced_add_noise_samplers.append(node_id)
|
||||
|
||||
# Find the sampler with highest denoise value among candidates
|
||||
for node_id, sampler_info in candidate_samplers.items():
|
||||
parameters = sampler_info.get("parameters", {})
|
||||
denoise = parameters.get("denoise")
|
||||
if denoise is not None and denoise > max_denoise:
|
||||
max_denoise = denoise
|
||||
high_denoise_id = node_id
|
||||
|
||||
if high_denoise_id:
|
||||
high_denoise_samplers.append(high_denoise_id)
|
||||
|
||||
# Combine all potential primary samplers
|
||||
potential_samplers = custom_advanced_samplers + advanced_add_noise_samplers + high_denoise_samplers
|
||||
|
||||
# Find the most recent potential primary sampler (closest to downstream node)
|
||||
for i in range(downstream_index - 1, -1, -1):
|
||||
node_id = execution_order[i]
|
||||
if node_id in potential_samplers:
|
||||
return node_id, candidate_samplers[node_id]
|
||||
|
||||
# If no potential sampler found from our criteria, return the most recent sampler
|
||||
if candidate_samplers:
|
||||
for i in range(downstream_index - 1, -1, -1):
|
||||
node_id = execution_order[i]
|
||||
if node_id in candidate_samplers:
|
||||
return node_id, candidate_samplers[node_id]
|
||||
|
||||
# If no downstream_id provided or no suitable sampler found, fall back to original logic
|
||||
primary_sampler = None
|
||||
primary_sampler_id = None
|
||||
max_denoise = -1
|
||||
|
||||
# First, check for SamplerCustomAdvanced
|
||||
prompt = metadata.get("current_prompt")
|
||||
if prompt and prompt.original_prompt:
|
||||
for node_id, node_info in prompt.original_prompt.items():
|
||||
if node_info.get("class_type") == "SamplerCustomAdvanced":
|
||||
# Check if the node is in SAMPLING and has IS_SAMPLER flag
|
||||
if node_id in metadata.get(SAMPLING, {}) and metadata[SAMPLING][node_id].get(IS_SAMPLER, False):
|
||||
return node_id, metadata[SAMPLING][node_id]
|
||||
|
||||
# Next, check for KSamplerAdvanced with add_noise="enable" using IS_SAMPLER flag
|
||||
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
|
||||
# Skip if not marked as a sampler
|
||||
if not sampler_info.get(IS_SAMPLER, False):
|
||||
continue
|
||||
|
||||
parameters = sampler_info.get("parameters", {})
|
||||
add_noise = parameters.get("add_noise")
|
||||
if add_noise == "enable":
|
||||
primary_sampler = sampler_info
|
||||
primary_sampler_id = node_id
|
||||
break
|
||||
|
||||
# If no specialized sampler found, find the sampler with highest denoise value
|
||||
if primary_sampler is None:
|
||||
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
|
||||
# Skip if not marked as a sampler
|
||||
if not sampler_info.get(IS_SAMPLER, False):
|
||||
continue
|
||||
|
||||
parameters = sampler_info.get("parameters", {})
|
||||
denoise = parameters.get("denoise")
|
||||
if denoise is not None and denoise > max_denoise:
|
||||
max_denoise = denoise
|
||||
primary_sampler = sampler_info
|
||||
primary_sampler_id = node_id
|
||||
|
||||
return primary_sampler_id, primary_sampler
|
||||
|
||||
@staticmethod
|
||||
def trace_node_input(prompt, node_id, input_name, target_class=None, max_depth=10):
|
||||
"""
|
||||
Trace an input connection from a node to find the source node
|
||||
|
||||
Parameters:
|
||||
- prompt: The prompt object containing node connections
|
||||
- node_id: ID of the starting node
|
||||
- input_name: Name of the input to trace
|
||||
- target_class: Optional class name to search for (e.g., "CLIPTextEncode")
|
||||
- max_depth: Maximum depth to follow the node chain to prevent infinite loops
|
||||
|
||||
Returns:
|
||||
- node_id of the found node, or None if not found
|
||||
"""
|
||||
if not prompt or not prompt.original_prompt or node_id not in prompt.original_prompt:
|
||||
return None
|
||||
|
||||
# For depth tracking
|
||||
current_depth = 0
|
||||
|
||||
current_node_id = node_id
|
||||
current_input = input_name
|
||||
|
||||
# If we're just tracing to origin (no target_class), keep track of the last valid node
|
||||
last_valid_node = None
|
||||
|
||||
while current_depth < max_depth:
|
||||
if current_node_id not in prompt.original_prompt:
|
||||
return last_valid_node if not target_class else None
|
||||
|
||||
node_inputs = prompt.original_prompt[current_node_id].get("inputs", {})
|
||||
if current_input not in node_inputs:
|
||||
# We've reached a node without the specified input - this is our origin node
|
||||
# if we're not looking for a specific target_class
|
||||
return current_node_id if not target_class else None
|
||||
|
||||
input_value = node_inputs[current_input]
|
||||
# Input connections are formatted as [node_id, output_index]
|
||||
if isinstance(input_value, list) and len(input_value) >= 2:
|
||||
found_node_id = input_value[0] # Connected node_id
|
||||
|
||||
# If we're looking for a specific node class
|
||||
if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class:
|
||||
return found_node_id
|
||||
|
||||
# If we're not looking for a specific class, update the last valid node
|
||||
if not target_class:
|
||||
last_valid_node = found_node_id
|
||||
|
||||
# Continue tracing through intermediate nodes
|
||||
current_node_id = found_node_id
|
||||
# For most conditioning nodes, the input we want to follow is named "conditioning"
|
||||
if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}):
|
||||
current_input = "conditioning"
|
||||
else:
|
||||
# If there's no "conditioning" input, return the current node
|
||||
# if we're not looking for a specific target_class
|
||||
return found_node_id if not target_class else None
|
||||
else:
|
||||
# We've reached a node with no further connections
|
||||
return last_valid_node if not target_class else None
|
||||
|
||||
current_depth += 1
|
||||
|
||||
# If we've reached max depth without finding target_class
|
||||
return last_valid_node if not target_class else None
|
||||
|
||||
@staticmethod
|
||||
def find_primary_checkpoint(metadata):
|
||||
"""Find the primary checkpoint model in the workflow"""
|
||||
if not metadata.get(MODELS):
|
||||
return None
|
||||
|
||||
# In most workflows, there's only one checkpoint, so we can just take the first one
|
||||
for node_id, model_info in metadata.get(MODELS, {}).items():
|
||||
if model_info.get("type") == "checkpoint":
|
||||
return model_info.get("name")
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def match_conditioning_to_prompts(metadata, sampler_id):
|
||||
"""
|
||||
Match conditioning objects from a sampler to prompts in metadata
|
||||
|
||||
Parameters:
|
||||
- metadata: The workflow metadata
|
||||
- sampler_id: ID of the sampler node to match
|
||||
|
||||
Returns:
|
||||
- Dictionary with 'prompt' and 'negative_prompt' if found
|
||||
"""
|
||||
result = {
|
||||
"prompt": "",
|
||||
"negative_prompt": ""
|
||||
}
|
||||
|
||||
# Check if we have stored conditioning objects for this sampler
|
||||
if sampler_id in metadata.get(PROMPTS, {}) and (
|
||||
"pos_conditioning" in metadata[PROMPTS][sampler_id] or
|
||||
"neg_conditioning" in metadata[PROMPTS][sampler_id]):
|
||||
|
||||
pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning")
|
||||
neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning")
|
||||
|
||||
# Helper function to recursively find prompt text for a conditioning object
|
||||
def find_prompt_text_for_conditioning(conditioning_obj, is_positive=True):
|
||||
if conditioning_obj is None:
|
||||
return ""
|
||||
|
||||
# Try to match conditioning objects with those stored by extractors
|
||||
for prompt_node_id, prompt_data in metadata[PROMPTS].items():
|
||||
# For nodes with single conditioning output
|
||||
if "conditioning" in prompt_data:
|
||||
if id(prompt_data["conditioning"]) == id(conditioning_obj):
|
||||
return prompt_data.get("text", "")
|
||||
|
||||
# For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader)
|
||||
if is_positive and "positive_encoded" in prompt_data:
|
||||
if id(prompt_data["positive_encoded"]) == id(conditioning_obj):
|
||||
if "positive_text" in prompt_data:
|
||||
return prompt_data["positive_text"]
|
||||
else:
|
||||
orig_conditioning = prompt_data.get("orig_pos_cond", None)
|
||||
if orig_conditioning is not None:
|
||||
# Recursively find the prompt text for the original conditioning
|
||||
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=True)
|
||||
|
||||
if not is_positive and "negative_encoded" in prompt_data:
|
||||
if id(prompt_data["negative_encoded"]) == id(conditioning_obj):
|
||||
if "negative_text" in prompt_data:
|
||||
return prompt_data["negative_text"]
|
||||
else:
|
||||
orig_conditioning = prompt_data.get("orig_neg_cond", None)
|
||||
if orig_conditioning is not None:
|
||||
# Recursively find the prompt text for the original conditioning
|
||||
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=False)
|
||||
|
||||
return ""
|
||||
|
||||
# Find prompt texts using the helper function
|
||||
result["prompt"] = find_prompt_text_for_conditioning(pos_conditioning, is_positive=True)
|
||||
result["negative_prompt"] = find_prompt_text_for_conditioning(neg_conditioning, is_positive=False)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def extract_generation_params(metadata, id=None):
|
||||
"""
|
||||
Extract generation parameters from metadata using node relationships
|
||||
|
||||
Parameters:
|
||||
- metadata: The workflow metadata
|
||||
- id: Optional ID of a downstream node to help identify the specific primary sampler
|
||||
"""
|
||||
params = {
|
||||
"prompt": "",
|
||||
"negative_prompt": "",
|
||||
"seed": None,
|
||||
"steps": None,
|
||||
"cfg_scale": None,
|
||||
# "guidance": None, # Add guidance parameter
|
||||
"sampler": None,
|
||||
"scheduler": None,
|
||||
"checkpoint": None,
|
||||
"loras": "",
|
||||
"size": None,
|
||||
"clip_skip": None
|
||||
}
|
||||
|
||||
# Get the prompt object for node relationship tracing
|
||||
prompt = metadata.get("current_prompt")
|
||||
|
||||
# Find the primary KSampler node
|
||||
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id)
|
||||
|
||||
# Directly get checkpoint from metadata instead of tracing
|
||||
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata)
|
||||
if checkpoint:
|
||||
params["checkpoint"] = checkpoint
|
||||
|
||||
# Check if guidance parameter exists in any sampling node
|
||||
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
|
||||
parameters = sampler_info.get("parameters", {})
|
||||
if "guidance" in parameters and parameters["guidance"] is not None:
|
||||
params["guidance"] = parameters["guidance"]
|
||||
break
|
||||
|
||||
if primary_sampler:
|
||||
# Extract sampling parameters
|
||||
sampling_params = primary_sampler.get("parameters", {})
|
||||
# Handle both seed and noise_seed
|
||||
params["seed"] = sampling_params.get("seed") if sampling_params.get("seed") is not None else sampling_params.get("noise_seed")
|
||||
params["steps"] = sampling_params.get("steps")
|
||||
params["cfg_scale"] = sampling_params.get("cfg")
|
||||
params["sampler"] = sampling_params.get("sampler_name")
|
||||
params["scheduler"] = sampling_params.get("scheduler")
|
||||
|
||||
if prompt and primary_sampler_id:
|
||||
# Check if this is a SamplerCustomAdvanced node
|
||||
is_custom_advanced = False
|
||||
if prompt.original_prompt and primary_sampler_id in prompt.original_prompt:
|
||||
is_custom_advanced = prompt.original_prompt[primary_sampler_id].get("class_type") == "SamplerCustomAdvanced"
|
||||
|
||||
if is_custom_advanced:
|
||||
# For SamplerCustomAdvanced, use the new handler method
|
||||
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
|
||||
|
||||
else:
|
||||
# For standard samplers, match conditioning objects to prompts
|
||||
prompt_results = MetadataProcessor.match_conditioning_to_prompts(metadata, primary_sampler_id)
|
||||
params["prompt"] = prompt_results["prompt"]
|
||||
params["negative_prompt"] = prompt_results["negative_prompt"]
|
||||
|
||||
# If prompts were still not found, fall back to tracing connections
|
||||
if not params["prompt"]:
|
||||
# Original tracing for standard samplers
|
||||
# Trace positive prompt - look specifically for CLIPTextEncode
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
else:
|
||||
# If CLIPTextEncode is not found, try to find CLIPTextEncodeFlux
|
||||
positive_flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncodeFlux", max_depth=10)
|
||||
if positive_flux_node_id and positive_flux_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_flux_node_id].get("text", "")
|
||||
|
||||
# Trace negative prompt - look specifically for CLIPTextEncode
|
||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", max_depth=10)
|
||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||
|
||||
# For SamplerCustom, handle any additional parameters
|
||||
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
|
||||
|
||||
# Size extraction is same for all sampler types
|
||||
# Check if the sampler itself has size information (from latent_image)
|
||||
if primary_sampler_id in metadata.get(SIZE, {}):
|
||||
width = metadata[SIZE][primary_sampler_id].get("width")
|
||||
height = metadata[SIZE][primary_sampler_id].get("height")
|
||||
if width and height:
|
||||
params["size"] = f"{width}x{height}"
|
||||
|
||||
# Extract LoRAs using the standardized format
|
||||
lora_parts = []
|
||||
for node_id, lora_info in metadata.get(LORAS, {}).items():
|
||||
# Access the lora_list from the standardized format
|
||||
lora_list = lora_info.get("lora_list", [])
|
||||
for lora in lora_list:
|
||||
name = lora.get("name", "unknown")
|
||||
strength = lora.get("strength", 1.0)
|
||||
lora_parts.append(f"<lora:{name}:{strength}>")
|
||||
|
||||
params["loras"] = " ".join(lora_parts)
|
||||
|
||||
# Set default clip_skip value
|
||||
params["clip_skip"] = "1" # Common default
|
||||
|
||||
return params
|
||||
|
||||
@staticmethod
|
||||
def to_dict(metadata, id=None):
|
||||
"""
|
||||
Convert extracted metadata to the ComfyUI output.json format
|
||||
|
||||
Parameters:
|
||||
- metadata: The workflow metadata
|
||||
- id: Optional ID of a downstream node to help identify the specific primary sampler
|
||||
"""
|
||||
if standalone_mode:
|
||||
# Return empty dictionary in standalone mode
|
||||
return {}
|
||||
|
||||
params = MetadataProcessor.extract_generation_params(metadata, id)
|
||||
|
||||
# Convert all values to strings to match output.json format
|
||||
for key in params:
|
||||
if params[key] is not None:
|
||||
params[key] = str(params[key])
|
||||
|
||||
return params
|
||||
|
||||
@staticmethod
|
||||
def to_json(metadata, id=None):
|
||||
"""Convert metadata to JSON string"""
|
||||
params = MetadataProcessor.to_dict(metadata, id)
|
||||
return json.dumps(params, indent=4)
|
||||
|
||||
@staticmethod
|
||||
def handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params):
|
||||
"""
|
||||
Handle parameter extraction for SamplerCustomAdvanced nodes
|
||||
|
||||
Parameters:
|
||||
- metadata: The workflow metadata
|
||||
- prompt: The prompt object containing node connections
|
||||
- primary_sampler_id: ID of the SamplerCustomAdvanced node
|
||||
- params: Parameters dictionary to update
|
||||
"""
|
||||
if not prompt.original_prompt or primary_sampler_id not in prompt.original_prompt:
|
||||
return
|
||||
|
||||
sampler_inputs = prompt.original_prompt[primary_sampler_id].get("inputs", {})
|
||||
|
||||
# 1. Trace sigmas input to find BasicScheduler (only if sigmas input exists)
|
||||
if "sigmas" in sampler_inputs:
|
||||
scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", None, max_depth=5)
|
||||
if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}):
|
||||
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
|
||||
params["steps"] = scheduler_params.get("steps")
|
||||
params["scheduler"] = scheduler_params.get("scheduler")
|
||||
|
||||
# 2. Trace sampler input to find KSamplerSelect (only if sampler input exists)
|
||||
if "sampler" in sampler_inputs:
|
||||
sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
|
||||
if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}):
|
||||
sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
|
||||
params["sampler"] = sampler_params.get("sampler_name")
|
||||
|
||||
# 3. Trace guider input for CFGGuider and CLIPTextEncode
|
||||
if "guider" in sampler_inputs:
|
||||
guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
|
||||
if guider_node_id and guider_node_id in prompt.original_prompt:
|
||||
# Check if the guider node is a CFGGuider
|
||||
if prompt.original_prompt[guider_node_id].get("class_type") == "CFGGuider":
|
||||
# Extract cfg value from the CFGGuider
|
||||
if guider_node_id in metadata.get(SAMPLING, {}):
|
||||
cfg_params = metadata[SAMPLING][guider_node_id].get("parameters", {})
|
||||
params["cfg_scale"] = cfg_params.get("cfg")
|
||||
|
||||
# Find CLIPTextEncode for positive prompt
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", "CLIPTextEncode", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
|
||||
# Find CLIPTextEncode for negative prompt
|
||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", "CLIPTextEncode", max_depth=10)
|
||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||
else:
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
277
py/metadata_collector/metadata_registry.py
Normal file
277
py/metadata_collector/metadata_registry.py
Normal file
@@ -0,0 +1,277 @@
|
||||
import time
|
||||
from nodes import NODE_CLASS_MAPPINGS
|
||||
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
|
||||
from .constants import METADATA_CATEGORIES, IMAGES
|
||||
|
||||
class MetadataRegistry:
|
||||
"""A singleton registry to store and retrieve workflow metadata"""
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._reset()
|
||||
return cls._instance
|
||||
|
||||
def _reset(self):
|
||||
self.current_prompt_id = None
|
||||
self.current_prompt = None
|
||||
self.metadata = {}
|
||||
self.prompt_metadata = {}
|
||||
self.executed_nodes = set()
|
||||
|
||||
# Node-level cache for metadata
|
||||
self.node_cache = {}
|
||||
|
||||
# Limit the number of stored prompts
|
||||
self.max_prompt_history = 3
|
||||
|
||||
# Categories we want to track and retrieve from cache
|
||||
self.metadata_categories = METADATA_CATEGORIES
|
||||
|
||||
def _clean_old_prompts(self):
|
||||
"""Clean up old prompt metadata, keeping only recent ones"""
|
||||
if len(self.prompt_metadata) <= self.max_prompt_history:
|
||||
return
|
||||
|
||||
# Sort all prompt_ids by timestamp
|
||||
sorted_prompts = sorted(
|
||||
self.prompt_metadata.keys(),
|
||||
key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0)
|
||||
)
|
||||
|
||||
# Remove oldest records
|
||||
prompts_to_remove = sorted_prompts[:len(sorted_prompts) - self.max_prompt_history]
|
||||
for pid in prompts_to_remove:
|
||||
del self.prompt_metadata[pid]
|
||||
|
||||
def start_collection(self, prompt_id):
|
||||
"""Begin metadata collection for a new prompt"""
|
||||
self.current_prompt_id = prompt_id
|
||||
self.executed_nodes = set()
|
||||
self.prompt_metadata[prompt_id] = {
|
||||
category: {} for category in METADATA_CATEGORIES
|
||||
}
|
||||
# Add additional metadata fields
|
||||
self.prompt_metadata[prompt_id].update({
|
||||
"execution_order": [],
|
||||
"current_prompt": None, # Will store the prompt object
|
||||
"timestamp": time.time()
|
||||
})
|
||||
|
||||
# Clean up old prompt data
|
||||
self._clean_old_prompts()
|
||||
|
||||
def set_current_prompt(self, prompt):
|
||||
"""Set the current prompt object reference"""
|
||||
self.current_prompt = prompt
|
||||
if self.current_prompt_id and self.current_prompt_id in self.prompt_metadata:
|
||||
# Store the prompt in the metadata for later relationship tracing
|
||||
self.prompt_metadata[self.current_prompt_id]["current_prompt"] = prompt
|
||||
|
||||
def get_metadata(self, prompt_id=None):
|
||||
"""Get collected metadata for a prompt"""
|
||||
key = prompt_id if prompt_id is not None else self.current_prompt_id
|
||||
if key not in self.prompt_metadata:
|
||||
return {}
|
||||
|
||||
metadata = self.prompt_metadata[key]
|
||||
|
||||
# If we have a current prompt object, check for non-executed nodes
|
||||
prompt_obj = metadata.get("current_prompt")
|
||||
if prompt_obj and hasattr(prompt_obj, "original_prompt"):
|
||||
original_prompt = prompt_obj.original_prompt
|
||||
|
||||
# Fill in missing metadata from cache for nodes that weren't executed
|
||||
self._fill_missing_metadata(key, original_prompt)
|
||||
|
||||
return self.prompt_metadata.get(key, {})
|
||||
|
||||
def _fill_missing_metadata(self, prompt_id, original_prompt):
|
||||
"""Fill missing metadata from cache for non-executed nodes"""
|
||||
if not original_prompt:
|
||||
return
|
||||
|
||||
executed_nodes = self.executed_nodes
|
||||
metadata = self.prompt_metadata[prompt_id]
|
||||
|
||||
# Iterate through nodes in the original prompt
|
||||
for node_id, node_data in original_prompt.items():
|
||||
# Skip if already executed in this run
|
||||
if node_id in executed_nodes:
|
||||
continue
|
||||
|
||||
# Get the node type from the prompt (this is the key in NODE_CLASS_MAPPINGS)
|
||||
prompt_class_type = node_data.get("class_type")
|
||||
if not prompt_class_type:
|
||||
continue
|
||||
|
||||
# Convert to actual class name (which is what we use in our cache)
|
||||
class_type = prompt_class_type
|
||||
if prompt_class_type in NODE_CLASS_MAPPINGS:
|
||||
class_obj = NODE_CLASS_MAPPINGS[prompt_class_type]
|
||||
class_type = class_obj.__name__
|
||||
|
||||
# Create cache key using the actual class name
|
||||
cache_key = f"{node_id}:{class_type}"
|
||||
|
||||
# Check if this node type is relevant for metadata collection
|
||||
if class_type in NODE_EXTRACTORS:
|
||||
# Check if we have cached metadata for this node
|
||||
if cache_key in self.node_cache:
|
||||
cached_data = self.node_cache[cache_key]
|
||||
|
||||
# Apply cached metadata to the current metadata
|
||||
for category in self.metadata_categories:
|
||||
if category in cached_data and node_id in cached_data[category]:
|
||||
if node_id not in metadata[category]:
|
||||
metadata[category][node_id] = cached_data[category][node_id]
|
||||
|
||||
def record_node_execution(self, node_id, class_type, inputs, outputs):
|
||||
"""Record information about a node's execution"""
|
||||
if not self.current_prompt_id:
|
||||
return
|
||||
|
||||
# Add to execution order and mark as executed
|
||||
if node_id not in self.executed_nodes:
|
||||
self.executed_nodes.add(node_id)
|
||||
self.prompt_metadata[self.current_prompt_id]["execution_order"].append(node_id)
|
||||
|
||||
# Process inputs to simplify working with them
|
||||
processed_inputs = {}
|
||||
for input_name, input_values in inputs.items():
|
||||
if isinstance(input_values, list) and len(input_values) > 0:
|
||||
# For single values, just use the first one (most common case)
|
||||
processed_inputs[input_name] = input_values[0]
|
||||
else:
|
||||
processed_inputs[input_name] = input_values
|
||||
|
||||
# Extract node-specific metadata
|
||||
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
|
||||
extractor.extract(
|
||||
node_id,
|
||||
processed_inputs,
|
||||
outputs,
|
||||
self.prompt_metadata[self.current_prompt_id]
|
||||
)
|
||||
|
||||
# Cache this node's metadata
|
||||
self._cache_node_metadata(node_id, class_type)
|
||||
|
||||
def update_node_execution(self, node_id, class_type, outputs):
|
||||
"""Update node metadata with output information"""
|
||||
if not self.current_prompt_id:
|
||||
return
|
||||
|
||||
# Process outputs to make them more usable
|
||||
processed_outputs = outputs
|
||||
|
||||
# Use the same extractor to update with outputs
|
||||
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
|
||||
if hasattr(extractor, 'update'):
|
||||
extractor.update(
|
||||
node_id,
|
||||
processed_outputs,
|
||||
self.prompt_metadata[self.current_prompt_id]
|
||||
)
|
||||
|
||||
# Update the cached metadata for this node
|
||||
self._cache_node_metadata(node_id, class_type)
|
||||
|
||||
def _cache_node_metadata(self, node_id, class_type):
|
||||
"""Cache the metadata for a specific node"""
|
||||
if not self.current_prompt_id or not node_id or not class_type:
|
||||
return
|
||||
|
||||
# Create a cache key combining node_id and class_type
|
||||
cache_key = f"{node_id}:{class_type}"
|
||||
|
||||
# Create a shallow copy of the node's metadata
|
||||
node_metadata = {}
|
||||
current_metadata = self.prompt_metadata[self.current_prompt_id]
|
||||
|
||||
for category in self.metadata_categories:
|
||||
if category in current_metadata and node_id in current_metadata[category]:
|
||||
if category not in node_metadata:
|
||||
node_metadata[category] = {}
|
||||
node_metadata[category][node_id] = current_metadata[category][node_id]
|
||||
|
||||
# Save new metadata or clear stale cache entries when metadata is empty
|
||||
if any(node_metadata.values()):
|
||||
self.node_cache[cache_key] = node_metadata
|
||||
else:
|
||||
self.node_cache.pop(cache_key, None)
|
||||
|
||||
def clear_unused_cache(self):
|
||||
"""Clean up node_cache entries that are no longer in use"""
|
||||
# Collect all node_ids currently in prompt_metadata
|
||||
active_node_ids = set()
|
||||
for prompt_data in self.prompt_metadata.values():
|
||||
for category in self.metadata_categories:
|
||||
if category in prompt_data:
|
||||
active_node_ids.update(prompt_data[category].keys())
|
||||
|
||||
# Find cache keys that are no longer needed
|
||||
keys_to_remove = []
|
||||
for cache_key in self.node_cache:
|
||||
node_id = cache_key.split(':')[0]
|
||||
if node_id not in active_node_ids:
|
||||
keys_to_remove.append(cache_key)
|
||||
|
||||
# Remove cache entries that are no longer needed
|
||||
for key in keys_to_remove:
|
||||
del self.node_cache[key]
|
||||
|
||||
def clear_metadata(self, prompt_id=None):
|
||||
"""Clear metadata for a specific prompt or reset all data"""
|
||||
if prompt_id is not None:
|
||||
if prompt_id in self.prompt_metadata:
|
||||
del self.prompt_metadata[prompt_id]
|
||||
# Clean up cache after removing prompt
|
||||
self.clear_unused_cache()
|
||||
else:
|
||||
# Reset all data
|
||||
self._reset()
|
||||
|
||||
def get_first_decoded_image(self, prompt_id=None):
|
||||
"""Get the first decoded image result"""
|
||||
key = prompt_id if prompt_id is not None else self.current_prompt_id
|
||||
if key not in self.prompt_metadata:
|
||||
return None
|
||||
|
||||
metadata = self.prompt_metadata[key]
|
||||
if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
|
||||
image_data = metadata[IMAGES]["first_decode"]["image"]
|
||||
|
||||
# If it's an image batch or tuple, handle various formats
|
||||
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
|
||||
# Return first element of list/tuple
|
||||
return image_data[0]
|
||||
|
||||
# If it's a tensor, return as is for processing in the route handler
|
||||
return image_data
|
||||
|
||||
# If no image is found in the current metadata, try to find it in the cache
|
||||
# This handles the case where VAEDecode was cached by ComfyUI and not executed
|
||||
prompt_obj = metadata.get("current_prompt")
|
||||
if prompt_obj and hasattr(prompt_obj, "original_prompt"):
|
||||
original_prompt = prompt_obj.original_prompt
|
||||
for node_id, node_data in original_prompt.items():
|
||||
class_type = node_data.get("class_type")
|
||||
if class_type and class_type in NODE_CLASS_MAPPINGS:
|
||||
class_obj = NODE_CLASS_MAPPINGS[class_type]
|
||||
class_name = class_obj.__name__
|
||||
# Check if this is a VAEDecode node
|
||||
if class_name == "VAEDecode":
|
||||
# Try to find this node in the cache
|
||||
cache_key = f"{node_id}:{class_name}"
|
||||
if cache_key in self.node_cache:
|
||||
cached_data = self.node_cache[cache_key]
|
||||
if IMAGES in cached_data and node_id in cached_data[IMAGES]:
|
||||
image_data = cached_data[IMAGES][node_id]["image"]
|
||||
# Handle different image formats
|
||||
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
|
||||
return image_data[0]
|
||||
return image_data
|
||||
|
||||
return None
|
||||
735
py/metadata_collector/node_extractors.py
Normal file
735
py/metadata_collector/node_extractors.py
Normal file
@@ -0,0 +1,735 @@
|
||||
import os
|
||||
|
||||
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES, IS_SAMPLER
|
||||
|
||||
|
||||
def _store_checkpoint_metadata(metadata, node_id, model_name):
|
||||
"""Store checkpoint model information when available."""
|
||||
if not model_name:
|
||||
return
|
||||
metadata.setdefault(MODELS, {})
|
||||
metadata[MODELS][node_id] = {
|
||||
"name": model_name,
|
||||
"type": "checkpoint",
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
|
||||
class NodeMetadataExtractor:
|
||||
"""Base class for node-specific metadata extraction"""
|
||||
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
"""Extract metadata from node inputs/outputs"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
"""Update metadata with node outputs after execution"""
|
||||
pass
|
||||
|
||||
class GenericNodeExtractor(NodeMetadataExtractor):
|
||||
"""Default extractor for nodes without specific handling"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
pass
|
||||
|
||||
class CheckpointLoaderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "ckpt_name" not in inputs:
|
||||
return
|
||||
|
||||
model_name = inputs.get("ckpt_name")
|
||||
_store_checkpoint_metadata(metadata, node_id, model_name)
|
||||
|
||||
|
||||
class NunchakuFluxDiTLoaderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "model_path" not in inputs:
|
||||
return
|
||||
|
||||
model_name = inputs.get("model_path")
|
||||
_store_checkpoint_metadata(metadata, node_id, model_name)
|
||||
|
||||
|
||||
class NunchakuQwenImageDiTLoaderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "model_name" not in inputs:
|
||||
return
|
||||
|
||||
model_name = inputs.get("model_name")
|
||||
_store_checkpoint_metadata(metadata, node_id, model_name)
|
||||
|
||||
class GGUFLoaderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "gguf_name" not in inputs:
|
||||
return
|
||||
|
||||
model_name = inputs.get("gguf_name")
|
||||
_store_checkpoint_metadata(metadata, node_id, model_name)
|
||||
|
||||
|
||||
class KJNodesModelLoaderExtractor(NodeMetadataExtractor):
|
||||
"""Extract metadata from KJNodes loaders that expose `model_name`."""
|
||||
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "model_name" not in inputs:
|
||||
return
|
||||
|
||||
model_name = inputs.get("model_name")
|
||||
_store_checkpoint_metadata(metadata, node_id, model_name)
|
||||
|
||||
class TSCCheckpointLoaderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "ckpt_name" not in inputs:
|
||||
return
|
||||
|
||||
model_name = inputs.get("ckpt_name")
|
||||
_store_checkpoint_metadata(metadata, node_id, model_name)
|
||||
|
||||
# For loader node has lora_stack input, like Efficient Loader from Efficient Nodes
|
||||
active_loras = []
|
||||
|
||||
# Process lora_stack if available
|
||||
if "lora_stack" in inputs:
|
||||
lora_stack = inputs.get("lora_stack", [])
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Extract lora name from path (following the format in lora_loader.py)
|
||||
lora_name = os.path.splitext(os.path.basename(lora_path))[0]
|
||||
active_loras.append({
|
||||
"name": lora_name,
|
||||
"strength": model_strength
|
||||
})
|
||||
|
||||
if active_loras:
|
||||
metadata[LORAS][node_id] = {
|
||||
"lora_list": active_loras,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
# Extract positive and negative prompt text if available
|
||||
positive_text = inputs.get("positive", "")
|
||||
negative_text = inputs.get("negative", "")
|
||||
|
||||
if positive_text or negative_text:
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
# Store both positive and negative text
|
||||
metadata[PROMPTS][node_id]["positive_text"] = positive_text
|
||||
metadata[PROMPTS][node_id]["negative_text"] = negative_text
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
# Handle conditioning outputs from TSC_EfficientLoader
|
||||
# outputs is a list with [(model, positive_encoded, negative_encoded, {"samples":latent}, vae, clip, dependencies,)]
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, tuple) and len(first_output) >= 3:
|
||||
positive_conditioning = first_output[1]
|
||||
negative_conditioning = first_output[2]
|
||||
|
||||
# Save both conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["positive_encoded"] = positive_conditioning
|
||||
metadata[PROMPTS][node_id]["negative_encoded"] = negative_conditioning
|
||||
|
||||
class CLIPTextEncodeExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "text" not in inputs:
|
||||
return
|
||||
|
||||
text = inputs.get("text", "")
|
||||
metadata[PROMPTS][node_id] = {
|
||||
"text": text,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
if isinstance(outputs[0], tuple) and len(outputs[0]) > 0:
|
||||
conditioning = outputs[0][0]
|
||||
metadata[PROMPTS][node_id]["conditioning"] = conditioning
|
||||
|
||||
# Base Sampler Extractor to reduce code redundancy
|
||||
class BaseSamplerExtractor(NodeMetadataExtractor):
|
||||
"""Base extractor for sampler nodes with common functionality"""
|
||||
@staticmethod
|
||||
def extract_sampling_params(node_id, inputs, metadata, param_keys):
|
||||
"""Extract sampling parameters from inputs"""
|
||||
sampling_params = {}
|
||||
for key in param_keys:
|
||||
if key in inputs:
|
||||
sampling_params[key] = inputs[key]
|
||||
|
||||
metadata[SAMPLING][node_id] = {
|
||||
"parameters": sampling_params,
|
||||
"node_id": node_id,
|
||||
IS_SAMPLER: True # Add sampler flag
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def extract_conditioning(node_id, inputs, metadata):
|
||||
"""Extract conditioning objects from inputs"""
|
||||
# Store the conditioning objects directly in metadata for later matching
|
||||
pos_conditioning = inputs.get("positive", None)
|
||||
neg_conditioning = inputs.get("negative", None)
|
||||
|
||||
# Save conditioning objects in metadata for later matching
|
||||
if pos_conditioning is not None or neg_conditioning is not None:
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
|
||||
@staticmethod
|
||||
def extract_latent_dimensions(node_id, inputs, metadata):
|
||||
"""Extract dimensions from latent image"""
|
||||
# Extract latent image dimensions if available
|
||||
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
||||
latent = inputs["latent_image"]
|
||||
if isinstance(latent, dict) and "samples" in latent:
|
||||
# Extract dimensions from latent tensor
|
||||
samples = latent["samples"]
|
||||
if hasattr(samples, "shape") and len(samples.shape) >= 3:
|
||||
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
|
||||
# Multiply by 8 to get actual pixel dimensions
|
||||
height = int(samples.shape[2] * 8)
|
||||
width = int(samples.shape[3] * 8)
|
||||
|
||||
if SIZE not in metadata:
|
||||
metadata[SIZE] = {}
|
||||
|
||||
metadata[SIZE][node_id] = {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class SamplerExtractor(BaseSamplerExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]
|
||||
)
|
||||
|
||||
# Extract conditioning objects
|
||||
BaseSamplerExtractor.extract_conditioning(node_id, inputs, metadata)
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class KSamplerAdvancedExtractor(BaseSamplerExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]
|
||||
)
|
||||
|
||||
# Extract conditioning objects
|
||||
BaseSamplerExtractor.extract_conditioning(node_id, inputs, metadata)
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class KSamplerBasicPipeExtractor(BaseSamplerExtractor):
|
||||
"""Extractor for KSamplerBasicPipe and KSampler_inspire_pipe nodes"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]
|
||||
)
|
||||
|
||||
# Extract conditioning objects from basic_pipe
|
||||
if "basic_pipe" in inputs and inputs["basic_pipe"] is not None:
|
||||
basic_pipe = inputs["basic_pipe"]
|
||||
# Typically, basic_pipe structure is (model, clip, vae, positive, negative)
|
||||
if isinstance(basic_pipe, tuple) and len(basic_pipe) >= 5:
|
||||
pos_conditioning = basic_pipe[3] # positive is at index 3
|
||||
neg_conditioning = basic_pipe[4] # negative is at index 4
|
||||
|
||||
# Save conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class KSamplerAdvancedBasicPipeExtractor(BaseSamplerExtractor):
|
||||
"""Extractor for KSamplerAdvancedBasicPipe nodes"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]
|
||||
)
|
||||
|
||||
# Extract conditioning objects from basic_pipe
|
||||
if "basic_pipe" in inputs and inputs["basic_pipe"] is not None:
|
||||
basic_pipe = inputs["basic_pipe"]
|
||||
# Typically, basic_pipe structure is (model, clip, vae, positive, negative)
|
||||
if isinstance(basic_pipe, tuple) and len(basic_pipe) >= 5:
|
||||
pos_conditioning = basic_pipe[3] # positive is at index 3
|
||||
neg_conditioning = basic_pipe[4] # negative is at index 4
|
||||
|
||||
# Save conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class TSCSamplerBaseExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
# Store vae_decode setting for later use in update
|
||||
if inputs and "vae_decode" in inputs:
|
||||
if SAMPLING not in metadata:
|
||||
metadata[SAMPLING] = {}
|
||||
|
||||
if node_id not in metadata[SAMPLING]:
|
||||
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
|
||||
|
||||
# Store the vae_decode setting
|
||||
metadata[SAMPLING][node_id]["vae_decode"] = inputs["vae_decode"]
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
# Check if vae_decode was set to "true"
|
||||
should_save_image = True
|
||||
if SAMPLING in metadata and node_id in metadata[SAMPLING]:
|
||||
vae_decode = metadata[SAMPLING][node_id].get("vae_decode")
|
||||
if vae_decode is not None:
|
||||
should_save_image = (vae_decode == "true")
|
||||
|
||||
# Skip image saving if vae_decode isn't "true"
|
||||
if not should_save_image:
|
||||
return
|
||||
|
||||
# Ensure IMAGES category exists
|
||||
if IMAGES not in metadata:
|
||||
metadata[IMAGES] = {}
|
||||
|
||||
# Extract output_images from the TSC sampler format
|
||||
# outputs = [{"ui": {"images": preview_images}, "result": result}]
|
||||
# where result = (original_model, original_positive, original_negative, latent_list, optional_vae, output_images,)
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
# Get the first item in the list
|
||||
output_item = outputs[0]
|
||||
if isinstance(output_item, dict) and "result" in output_item:
|
||||
result = output_item["result"]
|
||||
if isinstance(result, tuple) and len(result) >= 6:
|
||||
# The output_images is the last element in the result tuple
|
||||
output_images = (result[5],)
|
||||
|
||||
# Save image data under node ID index to be captured by caching mechanism
|
||||
metadata[IMAGES][node_id] = {
|
||||
"node_id": node_id,
|
||||
"image": output_images
|
||||
}
|
||||
|
||||
# Only set first_decode if it hasn't been recorded yet
|
||||
if "first_decode" not in metadata[IMAGES]:
|
||||
metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id]
|
||||
|
||||
class TSCKSamplerExtractor(SamplerExtractor, TSCSamplerBaseExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
# Call parent extract methods
|
||||
SamplerExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
TSCSamplerBaseExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
|
||||
# Update method is inherited from TSCSamplerBaseExtractor
|
||||
|
||||
|
||||
class TSCKSamplerAdvancedExtractor(KSamplerAdvancedExtractor, TSCSamplerBaseExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
# Call parent extract methods
|
||||
KSamplerAdvancedExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
TSCSamplerBaseExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
|
||||
# Update method is inherited from TSCSamplerBaseExtractor
|
||||
|
||||
class LoraLoaderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "lora_name" not in inputs:
|
||||
return
|
||||
|
||||
lora_name = inputs.get("lora_name")
|
||||
# Extract base filename without extension from path
|
||||
lora_name = os.path.splitext(os.path.basename(lora_name))[0]
|
||||
strength_model = round(float(inputs.get("strength_model", 1.0)), 2)
|
||||
|
||||
# Use the standardized format with lora_list
|
||||
metadata[LORAS][node_id] = {
|
||||
"lora_list": [
|
||||
{
|
||||
"name": lora_name,
|
||||
"strength": strength_model
|
||||
}
|
||||
],
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class ImageSizeExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
width = inputs.get("width", 512)
|
||||
height = inputs.get("height", 512)
|
||||
|
||||
if SIZE not in metadata:
|
||||
metadata[SIZE] = {}
|
||||
|
||||
metadata[SIZE][node_id] = {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class LoraLoaderManagerExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
active_loras = []
|
||||
|
||||
# Process lora_stack if available
|
||||
if "lora_stack" in inputs:
|
||||
lora_stack = inputs.get("lora_stack", [])
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Extract lora name from path (following the format in lora_loader.py)
|
||||
lora_name = os.path.splitext(os.path.basename(lora_path))[0]
|
||||
active_loras.append({
|
||||
"name": lora_name,
|
||||
"strength": model_strength
|
||||
})
|
||||
|
||||
# Process loras from inputs
|
||||
if "loras" in inputs:
|
||||
loras_data = inputs.get("loras", [])
|
||||
|
||||
# Handle new format: {'loras': {'__value__': [...]}}
|
||||
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
||||
loras_list = loras_data['__value__']
|
||||
# Handle old format: {'loras': [...]}
|
||||
elif isinstance(loras_data, list):
|
||||
loras_list = loras_data
|
||||
else:
|
||||
loras_list = []
|
||||
|
||||
# Filter for active loras
|
||||
for lora in loras_list:
|
||||
if isinstance(lora, dict) and lora.get("active", True) and not lora.get("_isDummy", False):
|
||||
active_loras.append({
|
||||
"name": lora.get("name", ""),
|
||||
"strength": float(lora.get("strength", 1.0))
|
||||
})
|
||||
|
||||
if active_loras:
|
||||
metadata[LORAS][node_id] = {
|
||||
"lora_list": active_loras,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class FluxGuidanceExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "guidance" not in inputs:
|
||||
return
|
||||
|
||||
guidance_value = inputs.get("guidance")
|
||||
|
||||
# Store the guidance value in SAMPLING category
|
||||
if node_id not in metadata[SAMPLING]:
|
||||
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
|
||||
|
||||
metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value
|
||||
|
||||
class UNETLoaderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "unet_name" not in inputs:
|
||||
return
|
||||
|
||||
model_name = inputs.get("unet_name")
|
||||
if model_name:
|
||||
metadata[MODELS][node_id] = {
|
||||
"name": model_name,
|
||||
"type": "checkpoint",
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class VAEDecodeExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
# Ensure IMAGES category exists
|
||||
if IMAGES not in metadata:
|
||||
metadata[IMAGES] = {}
|
||||
|
||||
# Save image data under node ID index to be captured by caching mechanism
|
||||
metadata[IMAGES][node_id] = {
|
||||
"node_id": node_id,
|
||||
"image": outputs
|
||||
}
|
||||
|
||||
# Only set first_decode if it hasn't been recorded yet
|
||||
if "first_decode" not in metadata[IMAGES]:
|
||||
metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id]
|
||||
|
||||
class KSamplerSelectExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "sampler_name" not in inputs:
|
||||
return
|
||||
|
||||
sampling_params = {}
|
||||
if "sampler_name" in inputs:
|
||||
sampling_params["sampler_name"] = inputs["sampler_name"]
|
||||
|
||||
metadata[SAMPLING][node_id] = {
|
||||
"parameters": sampling_params,
|
||||
"node_id": node_id,
|
||||
IS_SAMPLER: False # Mark as non-primary sampler
|
||||
}
|
||||
|
||||
class BasicSchedulerExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
sampling_params = {}
|
||||
for key in ["scheduler", "steps", "denoise"]:
|
||||
if key in inputs:
|
||||
sampling_params[key] = inputs[key]
|
||||
|
||||
metadata[SAMPLING][node_id] = {
|
||||
"parameters": sampling_params,
|
||||
"node_id": node_id,
|
||||
IS_SAMPLER: False # Mark as non-primary sampler
|
||||
}
|
||||
|
||||
class SamplerCustomAdvancedExtractor(BaseSamplerExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
sampling_params = {}
|
||||
|
||||
# Handle noise.seed as seed
|
||||
if "noise" in inputs and inputs["noise"] is not None and hasattr(inputs["noise"], "seed"):
|
||||
noise = inputs["noise"]
|
||||
sampling_params["seed"] = noise.seed
|
||||
|
||||
metadata[SAMPLING][node_id] = {
|
||||
"parameters": sampling_params,
|
||||
"node_id": node_id,
|
||||
IS_SAMPLER: True # Add sampler flag
|
||||
}
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
import json
|
||||
|
||||
class CLIPTextEncodeFluxExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "clip_l" not in inputs or "t5xxl" not in inputs:
|
||||
return
|
||||
|
||||
clip_l_text = inputs.get("clip_l", "")
|
||||
t5xxl_text = inputs.get("t5xxl", "")
|
||||
|
||||
# If both are empty, use empty string
|
||||
if not clip_l_text and not t5xxl_text:
|
||||
combined_text = ""
|
||||
# If one is empty, use the non-empty one
|
||||
elif not clip_l_text:
|
||||
combined_text = t5xxl_text
|
||||
elif not t5xxl_text:
|
||||
combined_text = clip_l_text
|
||||
# If both have content, use JSON format
|
||||
else:
|
||||
combined_text = json.dumps({
|
||||
"T5": t5xxl_text,
|
||||
"CLIP-L": clip_l_text
|
||||
})
|
||||
|
||||
metadata[PROMPTS][node_id] = {
|
||||
"text": combined_text,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
# Extract guidance value if available
|
||||
if "guidance" in inputs:
|
||||
guidance_value = inputs.get("guidance")
|
||||
|
||||
# Store the guidance value in SAMPLING category
|
||||
if SAMPLING not in metadata:
|
||||
metadata[SAMPLING] = {}
|
||||
|
||||
if node_id not in metadata[SAMPLING]:
|
||||
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
|
||||
|
||||
metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
if isinstance(outputs[0], tuple) and len(outputs[0]) > 0:
|
||||
conditioning = outputs[0][0]
|
||||
metadata[PROMPTS][node_id]["conditioning"] = conditioning
|
||||
|
||||
class CFGGuiderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "cfg" not in inputs:
|
||||
return
|
||||
|
||||
cfg_value = inputs.get("cfg")
|
||||
|
||||
# Store the cfg value in SAMPLING category
|
||||
if SAMPLING not in metadata:
|
||||
metadata[SAMPLING] = {}
|
||||
|
||||
if node_id not in metadata[SAMPLING]:
|
||||
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
|
||||
|
||||
metadata[SAMPLING][node_id]["parameters"]["cfg"] = cfg_value
|
||||
|
||||
class CR_ApplyControlNetStackExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Save the original conditioning inputs
|
||||
base_positive = inputs.get("base_positive")
|
||||
base_negative = inputs.get("base_negative")
|
||||
|
||||
if base_positive is not None or base_negative is not None:
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["orig_pos_cond"] = base_positive
|
||||
metadata[PROMPTS][node_id]["orig_neg_cond"] = base_negative
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
# Extract transformed conditionings from outputs
|
||||
# outputs structure: [(base_positive, base_negative, show_help, )]
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, tuple) and len(first_output) >= 2:
|
||||
transformed_positive = first_output[0]
|
||||
transformed_negative = first_output[1]
|
||||
|
||||
# Save transformed conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["positive_encoded"] = transformed_positive
|
||||
metadata[PROMPTS][node_id]["negative_encoded"] = transformed_negative
|
||||
|
||||
# Registry of node-specific extractors
|
||||
# Keys are node class names
|
||||
NODE_EXTRACTORS = {
|
||||
# Sampling
|
||||
"KSampler": SamplerExtractor,
|
||||
"KSamplerAdvanced": KSamplerAdvancedExtractor,
|
||||
"SamplerCustom": KSamplerAdvancedExtractor,
|
||||
"SamplerCustomAdvanced": SamplerCustomAdvancedExtractor,
|
||||
"ClownsharKSampler_Beta": SamplerExtractor,
|
||||
"TSC_KSampler": TSCKSamplerExtractor, # Efficient Nodes
|
||||
"TSC_KSamplerAdvanced": TSCKSamplerAdvancedExtractor, # Efficient Nodes
|
||||
"KSamplerBasicPipe": KSamplerBasicPipeExtractor, # comfyui-impact-pack
|
||||
"KSamplerAdvancedBasicPipe": KSamplerAdvancedBasicPipeExtractor, # comfyui-impact-pack
|
||||
"KSampler_inspire_pipe": KSamplerBasicPipeExtractor, # comfyui-inspire-pack
|
||||
"KSamplerAdvanced_inspire_pipe": KSamplerAdvancedBasicPipeExtractor, # comfyui-inspire-pack
|
||||
"KSampler_inspire": SamplerExtractor, # comfyui-inspire-pack
|
||||
# Sampling Selectors
|
||||
"KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect
|
||||
"BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler
|
||||
"AlignYourStepsScheduler": BasicSchedulerExtractor, # Add AlignYourStepsScheduler
|
||||
# Loaders
|
||||
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
|
||||
"comfyLoader": CheckpointLoaderExtractor, # easy comfyLoader
|
||||
"CheckpointLoaderSimpleWithImages": CheckpointLoaderExtractor, # CheckpointLoader|pysssss
|
||||
"TSC_EfficientLoader": TSCCheckpointLoaderExtractor, # Efficient Nodes
|
||||
"NunchakuFluxDiTLoader": NunchakuFluxDiTLoaderExtractor, # ComfyUI-Nunchaku
|
||||
"NunchakuQwenImageDiTLoader": NunchakuQwenImageDiTLoaderExtractor, # ComfyUI-Nunchaku
|
||||
"LoaderGGUF": GGUFLoaderExtractor, # calcuis gguf
|
||||
"LoaderGGUFAdvanced": GGUFLoaderExtractor, # calcuis gguf
|
||||
"GGUFLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
|
||||
"DiffusionModelLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
|
||||
"CheckpointLoaderKJ": CheckpointLoaderExtractor, # KJNodes
|
||||
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
"LoraLoader": LoraLoaderExtractor,
|
||||
"LoraManagerLoader": LoraLoaderManagerExtractor,
|
||||
# Conditioning
|
||||
"CLIPTextEncode": CLIPTextEncodeExtractor,
|
||||
"PromptLoraManager": CLIPTextEncodeExtractor,
|
||||
"CLIPTextEncodeFlux": CLIPTextEncodeFluxExtractor, # Add CLIPTextEncodeFlux
|
||||
"WAS_Text_to_Conditioning": CLIPTextEncodeExtractor,
|
||||
"AdvancedCLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/BlenderNeko/ComfyUI_ADV_CLIP_emb
|
||||
"smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes
|
||||
"CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack
|
||||
"PCTextEncode": CLIPTextEncodeExtractor, # From https://github.com/asagi4/comfyui-prompt-control
|
||||
# Latent
|
||||
"EmptyLatentImage": ImageSizeExtractor,
|
||||
# Flux
|
||||
"FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance
|
||||
"CFGGuider": CFGGuiderExtractor, # Add CFGGuider
|
||||
# Image
|
||||
"VAEDecode": VAEDecodeExtractor, # Added VAEDecode extractor
|
||||
# Add other nodes as needed
|
||||
}
|
||||
1
py/middleware/__init__.py
Normal file
1
py/middleware/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Server middleware modules"""
|
||||
53
py/middleware/cache_middleware.py
Normal file
53
py/middleware/cache_middleware.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Cache control middleware for ComfyUI server"""
|
||||
|
||||
from aiohttp import web
|
||||
from typing import Callable, Awaitable
|
||||
|
||||
# Time in seconds
|
||||
ONE_HOUR: int = 3600
|
||||
ONE_DAY: int = 86400
|
||||
IMG_EXTENSIONS = (
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".ppm",
|
||||
".bmp",
|
||||
".pgm",
|
||||
".tif",
|
||||
".tiff",
|
||||
".webp",
|
||||
".mp4"
|
||||
)
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def cache_control(
|
||||
request: web.Request, handler: Callable[[web.Request], Awaitable[web.Response]]
|
||||
) -> web.Response:
|
||||
"""Cache control middleware that sets appropriate cache headers based on file type and response status"""
|
||||
response: web.Response = await handler(request)
|
||||
|
||||
if (
|
||||
request.path.endswith(".js")
|
||||
or request.path.endswith(".css")
|
||||
or request.path.endswith("index.json")
|
||||
):
|
||||
response.headers.setdefault("Cache-Control", "no-cache")
|
||||
return response
|
||||
|
||||
# Early return for non-image files - no cache headers needed
|
||||
if not request.path.lower().endswith(IMG_EXTENSIONS):
|
||||
return response
|
||||
|
||||
# Handle image files
|
||||
if response.status == 404:
|
||||
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_HOUR}")
|
||||
elif response.status in (200, 201, 202, 203, 204, 205, 206, 301, 308):
|
||||
# Success responses and permanent redirects - cache for 1 day
|
||||
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_DAY}")
|
||||
elif response.status in (302, 303, 307):
|
||||
# Temporary redirects - no cache
|
||||
response.headers.setdefault("Cache-Control", "no-cache")
|
||||
# Note: 304 Not Modified falls through - no cache headers set
|
||||
|
||||
return response
|
||||
45
py/nodes/debug_metadata.py
Normal file
45
py/nodes/debug_metadata.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import logging
|
||||
from server import PromptServer # type: ignore
|
||||
from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DebugMetadata:
|
||||
NAME = "Debug Metadata (LoraManager)"
|
||||
CATEGORY = "Lora Manager/utils"
|
||||
DESCRIPTION = "Debug node to verify metadata_processor functionality"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"images": ("IMAGE",),
|
||||
},
|
||||
"hidden": {
|
||||
"id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "process_metadata"
|
||||
|
||||
def process_metadata(self, images, id):
|
||||
try:
|
||||
# Get the current execution context's metadata
|
||||
from ..metadata_collector import get_metadata
|
||||
metadata = get_metadata()
|
||||
|
||||
# Use the MetadataProcessor to convert it to JSON string
|
||||
metadata_json = MetadataProcessor.to_json(metadata, id)
|
||||
|
||||
# Send metadata to frontend for display
|
||||
PromptServer.instance.send_sync("metadata_update", {
|
||||
"id": id,
|
||||
"metadata": metadata_json
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing metadata: {e}")
|
||||
|
||||
return ()
|
||||
@@ -1,26 +1,24 @@
|
||||
import logging
|
||||
import re
|
||||
from nodes import LoraLoader
|
||||
from comfy.comfy_types import IO # type: ignore
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
import asyncio
|
||||
import os
|
||||
from .utils import FlexibleOptionalInputType, any_type
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraManagerLoader:
|
||||
NAME = "Lora Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
# "clip": ("CLIP",),
|
||||
"text": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"text": ("STRING", {
|
||||
"multiline": True,
|
||||
"pysssss.autocomplete": False,
|
||||
"dynamicPrompts": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
"placeholder": "LoRA syntax input: <lora:name:strength>"
|
||||
@@ -29,51 +27,9 @@ class LoraManagerLoader:
|
||||
"optional": FlexibleOptionalInputType(any_type),
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING)
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING")
|
||||
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
|
||||
FUNCTION = "load_loras"
|
||||
|
||||
async def get_lora_info(self, lora_name):
|
||||
"""Get the lora path and trigger words from cache"""
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
file_path = item.get('file_path')
|
||||
if file_path:
|
||||
for root in config.loras_roots:
|
||||
root = root.replace(os.sep, '/')
|
||||
if file_path.startswith(root):
|
||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||
# Get trigger words from civitai metadata
|
||||
civitai = item.get('civitai', {})
|
||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||
return relative_path, trigger_words
|
||||
return lora_name, [] # Fallback if not found
|
||||
|
||||
def extract_lora_name(self, lora_path):
|
||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||
# Get the basename without extension
|
||||
basename = os.path.basename(lora_path)
|
||||
return os.path.splitext(basename)[0]
|
||||
|
||||
def _get_loras_list(self, kwargs):
|
||||
"""Helper to extract loras list from either old or new kwargs format"""
|
||||
if 'loras' not in kwargs:
|
||||
return []
|
||||
|
||||
loras_data = kwargs['loras']
|
||||
# Handle new format: {'loras': {'__value__': [...]}}
|
||||
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
||||
return loras_data['__value__']
|
||||
# Handle old format: {'loras': [...]}
|
||||
elif isinstance(loras_data, list):
|
||||
return loras_data
|
||||
# Unexpected format
|
||||
else:
|
||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||
return []
|
||||
|
||||
def load_loras(self, model, text, **kwargs):
|
||||
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
|
||||
@@ -82,34 +38,71 @@ class LoraManagerLoader:
|
||||
|
||||
clip = kwargs.get('clip', None)
|
||||
lora_stack = kwargs.get('lora_stack', None)
|
||||
|
||||
# Check if model is a Nunchaku Flux model - simplified approach
|
||||
is_nunchaku_model = False
|
||||
|
||||
try:
|
||||
model_wrapper = model.model.diffusion_model
|
||||
# Check if model is a Nunchaku Flux model using only class name
|
||||
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
||||
is_nunchaku_model = True
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
except (AttributeError, TypeError):
|
||||
# Not a model with the expected structure
|
||||
pass
|
||||
|
||||
# First process lora_stack if available
|
||||
if lora_stack:
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Apply the LoRA using the provided path and strengths
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# Use our custom function for Flux models
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged for Nunchaku models
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Extract lora name for trigger words lookup
|
||||
lora_name = self.extract_lora_name(lora_path)
|
||||
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
all_trigger_words.extend(trigger_words)
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Then process loras from kwargs with support for both old and new formats
|
||||
loras_list = self._get_loras_list(kwargs)
|
||||
loras_list = get_loras_list(kwargs)
|
||||
for lora in loras_list:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
|
||||
lora_name = lora['name']
|
||||
strength = float(lora['strength'])
|
||||
model_strength = float(lora['strength'])
|
||||
# Get clip strength - use model strength as default if not specified
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Apply the LoRA using the resolved path
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength)
|
||||
loaded_loras.append(f"{lora_name}: {strength}")
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# For Nunchaku models, use our custom function
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Include clip strength in output if different from model strength and not a Nunchaku model
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
@@ -117,8 +110,160 @@ class LoraManagerLoader:
|
||||
# use ',, ' to separate trigger words for group mode
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Format loaded_loras as <lora:lora_name:strength> separated by spaces
|
||||
formatted_loras = " ".join([f"<lora:{name.split(':')[0].strip()}:{str(strength).strip()}>"
|
||||
for name, strength in [item.split(':') for item in loaded_loras]])
|
||||
# Format loaded_loras with support for both formats
|
||||
formatted_loras = []
|
||||
for item in loaded_loras:
|
||||
parts = item.split(":")
|
||||
lora_name = parts[0]
|
||||
strength_parts = parts[1].strip().split(",")
|
||||
|
||||
if len(strength_parts) > 1:
|
||||
# Different model and clip strengths
|
||||
model_str = strength_parts[0].strip()
|
||||
clip_str = strength_parts[1].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}:{clip_str}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
model_str = strength_parts[0].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}>")
|
||||
|
||||
formatted_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (model, clip, trigger_words_text, formatted_loras)
|
||||
return (model, clip, trigger_words_text, formatted_loras_text)
|
||||
|
||||
class LoraManagerTextLoader:
|
||||
NAME = "LoRA Text Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"lora_syntax": ("STRING", {
|
||||
"forceInput": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation"
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"clip": ("CLIP",),
|
||||
"lora_stack": ("LORA_STACK",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "STRING", "STRING")
|
||||
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
|
||||
FUNCTION = "load_loras_from_text"
|
||||
|
||||
def parse_lora_syntax(self, text):
|
||||
"""Parse LoRA syntax from text input."""
|
||||
# Pattern to match <lora:name:strength> or <lora:name:model_strength:clip_strength>
|
||||
pattern = r'<lora:([^:>]+):([^:>]+)(?::([^:>]+))?>'
|
||||
matches = re.findall(pattern, text, re.IGNORECASE)
|
||||
|
||||
loras = []
|
||||
for match in matches:
|
||||
lora_name = match[0]
|
||||
model_strength = float(match[1])
|
||||
clip_strength = float(match[2]) if match[2] else model_strength
|
||||
|
||||
loras.append({
|
||||
'name': lora_name,
|
||||
'model_strength': model_strength,
|
||||
'clip_strength': clip_strength
|
||||
})
|
||||
|
||||
return loras
|
||||
|
||||
def load_loras_from_text(self, model, lora_syntax, clip=None, lora_stack=None):
|
||||
"""Load LoRAs based on text syntax input."""
|
||||
loaded_loras = []
|
||||
all_trigger_words = []
|
||||
|
||||
# Check if model is a Nunchaku Flux model - simplified approach
|
||||
is_nunchaku_model = False
|
||||
|
||||
try:
|
||||
model_wrapper = model.model.diffusion_model
|
||||
# Check if model is a Nunchaku Flux model using only class name
|
||||
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
||||
is_nunchaku_model = True
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
except (AttributeError, TypeError):
|
||||
# Not a model with the expected structure
|
||||
pass
|
||||
|
||||
# First process lora_stack if available
|
||||
if lora_stack:
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# Use our custom function for Flux models
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged for Nunchaku models
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Extract lora name for trigger words lookup
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
all_trigger_words.extend(trigger_words)
|
||||
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Parse and process LoRAs from text syntax
|
||||
parsed_loras = self.parse_lora_syntax(lora_syntax)
|
||||
for lora in parsed_loras:
|
||||
lora_name = lora['name']
|
||||
model_strength = lora['model_strength']
|
||||
clip_strength = lora['clip_strength']
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# For Nunchaku models, use our custom function
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Include clip strength in output if different from model strength and not a Nunchaku model
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# use ',, ' to separate trigger words for group mode
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Format loaded_loras with support for both formats
|
||||
formatted_loras = []
|
||||
for item in loaded_loras:
|
||||
parts = item.split(":")
|
||||
lora_name = parts[0].strip()
|
||||
strength_parts = parts[1].strip().split(",")
|
||||
|
||||
if len(strength_parts) > 1:
|
||||
# Different model and clip strengths
|
||||
model_str = strength_parts[0].strip()
|
||||
clip_str = strength_parts[1].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}:{clip_str}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
model_str = strength_parts[0].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}>")
|
||||
|
||||
formatted_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (model, clip, trigger_words_text, formatted_loras_text)
|
||||
@@ -1,9 +1,7 @@
|
||||
from comfy.comfy_types import IO # type: ignore
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
import asyncio
|
||||
import os
|
||||
from .utils import FlexibleOptionalInputType, any_type
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -16,8 +14,9 @@ class LoraStacker:
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"text": (IO.STRING, {
|
||||
"text": ("STRING", {
|
||||
"multiline": True,
|
||||
"pysssss.autocomplete": False,
|
||||
"dynamicPrompts": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
"placeholder": "LoRA syntax input: <lora:name:strength>"
|
||||
@@ -26,51 +25,9 @@ class LoraStacker:
|
||||
"optional": FlexibleOptionalInputType(any_type),
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LORA_STACK", IO.STRING, IO.STRING)
|
||||
RETURN_TYPES = ("LORA_STACK", "STRING", "STRING")
|
||||
RETURN_NAMES = ("LORA_STACK", "trigger_words", "active_loras")
|
||||
FUNCTION = "stack_loras"
|
||||
|
||||
async def get_lora_info(self, lora_name):
|
||||
"""Get the lora path and trigger words from cache"""
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
file_path = item.get('file_path')
|
||||
if file_path:
|
||||
for root in config.loras_roots:
|
||||
root = root.replace(os.sep, '/')
|
||||
if file_path.startswith(root):
|
||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||
# Get trigger words from civitai metadata
|
||||
civitai = item.get('civitai', {})
|
||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||
return relative_path, trigger_words
|
||||
return lora_name, [] # Fallback if not found
|
||||
|
||||
def extract_lora_name(self, lora_path):
|
||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||
# Get the basename without extension
|
||||
basename = os.path.basename(lora_path)
|
||||
return os.path.splitext(basename)[0]
|
||||
|
||||
def _get_loras_list(self, kwargs):
|
||||
"""Helper to extract loras list from either old or new kwargs format"""
|
||||
if 'loras' not in kwargs:
|
||||
return []
|
||||
|
||||
loras_data = kwargs['loras']
|
||||
# Handle new format: {'loras': {'__value__': [...]}}
|
||||
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
||||
return loras_data['__value__']
|
||||
# Handle old format: {'loras': [...]}
|
||||
elif isinstance(loras_data, list):
|
||||
return loras_data
|
||||
# Unexpected format
|
||||
else:
|
||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||
return []
|
||||
|
||||
def stack_loras(self, text, **kwargs):
|
||||
"""Stacks multiple LoRAs based on the kwargs input without loading them."""
|
||||
@@ -80,39 +37,49 @@ class LoraStacker:
|
||||
|
||||
# Process existing lora_stack if available
|
||||
lora_stack = kwargs.get('lora_stack', None)
|
||||
if lora_stack:
|
||||
if (lora_stack):
|
||||
stack.extend(lora_stack)
|
||||
# Get trigger words from existing stack entries
|
||||
for lora_path, _, _ in lora_stack:
|
||||
lora_name = self.extract_lora_name(lora_path)
|
||||
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Process loras from kwargs with support for both old and new formats
|
||||
loras_list = self._get_loras_list(kwargs)
|
||||
loras_list = get_loras_list(kwargs)
|
||||
for lora in loras_list:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
|
||||
lora_name = lora['name']
|
||||
model_strength = float(lora['strength'])
|
||||
clip_strength = model_strength # Using same strength for both as in the original loader
|
||||
# Get clip strength - use model strength as default if not specified
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Add to stack without loading
|
||||
# replace '/' with os.sep to avoid different OS path format
|
||||
stack.append((lora_path.replace('/', os.sep), model_strength, clip_strength))
|
||||
active_loras.append((lora_name, model_strength))
|
||||
active_loras.append((lora_name, model_strength, clip_strength))
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# use ',, ' to separate trigger words for group mode
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
# Format active_loras as <lora:lora_name:strength> separated by spaces
|
||||
active_loras_text = " ".join([f"<lora:{name}:{str(strength).strip()}>"
|
||||
for name, strength in active_loras])
|
||||
|
||||
# Format active_loras with support for both formats
|
||||
formatted_loras = []
|
||||
for name, model_strength, clip_strength in active_loras:
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
# Different model and clip strengths
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")
|
||||
|
||||
active_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (stack, trigger_words_text, active_loras_text)
|
||||
|
||||
59
py/nodes/prompt.py
Normal file
59
py/nodes/prompt.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
class PromptLoraManager:
|
||||
"""Encodes text (and optional trigger words) into CLIP conditioning."""
|
||||
|
||||
NAME = "Prompt (LoraManager)"
|
||||
CATEGORY = "Lora Manager/conditioning"
|
||||
DESCRIPTION = (
|
||||
"Encodes a text prompt using a CLIP model into an embedding that can be used "
|
||||
"to guide the diffusion model towards generating specific images."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"text": (
|
||||
'STRING',
|
||||
{
|
||||
"multiline": True,
|
||||
"pysssss.autocomplete": False,
|
||||
"dynamicPrompts": True,
|
||||
"tooltip": "The text to be encoded.",
|
||||
},
|
||||
),
|
||||
"clip": (
|
||||
'CLIP',
|
||||
{"tooltip": "The CLIP model used for encoding the text."},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"trigger_words": (
|
||||
'STRING',
|
||||
{
|
||||
"forceInput": True,
|
||||
"tooltip": (
|
||||
"Optional trigger words to prepend to the text before "
|
||||
"encoding."
|
||||
)
|
||||
},
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ('CONDITIONING', 'STRING',)
|
||||
RETURN_NAMES = ('CONDITIONING', 'PROMPT',)
|
||||
OUTPUT_TOOLTIPS = (
|
||||
"A conditioning containing the embedded text used to guide the diffusion model.",
|
||||
)
|
||||
FUNCTION = "encode"
|
||||
|
||||
def encode(self, text: str, clip: Any, trigger_words: Optional[str] = None):
|
||||
prompt = text
|
||||
if trigger_words:
|
||||
prompt = ", ".join([trigger_words, text])
|
||||
|
||||
from nodes import CLIPTextEncode # type: ignore
|
||||
conditioning = CLIPTextEncode().encode(clip, prompt)[0]
|
||||
return (conditioning, prompt,)
|
||||
@@ -1,14 +1,13 @@
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
import re
|
||||
import numpy as np
|
||||
import folder_paths # type: ignore
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..workflow.parser import WorkflowParser
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||
from ..metadata_collector import get_metadata
|
||||
from PIL import Image, PngImagePlugin
|
||||
import piexif
|
||||
from io import BytesIO
|
||||
|
||||
class SaveImage:
|
||||
NAME = "Save Image (LoraManager)"
|
||||
@@ -30,17 +29,36 @@ class SaveImage:
|
||||
return {
|
||||
"required": {
|
||||
"images": ("IMAGE",),
|
||||
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
|
||||
"file_format": (["png", "jpeg", "webp"],),
|
||||
"filename_prefix": ("STRING", {
|
||||
"default": "ComfyUI",
|
||||
"tooltip": "Base filename for saved images. Supports format patterns like %seed%, %width%, %height%, %model%, etc."
|
||||
}),
|
||||
"file_format": (["png", "jpeg", "webp"], {
|
||||
"tooltip": "Image format to save as. PNG preserves quality, JPEG is smaller, WebP balances size and quality."
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"custom_prompt": ("STRING", {"default": "", "forceInput": True}),
|
||||
"lossless_webp": ("BOOLEAN", {"default": True}),
|
||||
"quality": ("INT", {"default": 100, "min": 1, "max": 100}),
|
||||
"embed_workflow": ("BOOLEAN", {"default": False}),
|
||||
"add_counter_to_filename": ("BOOLEAN", {"default": True}),
|
||||
"lossless_webp": ("BOOLEAN", {
|
||||
"default": False,
|
||||
"tooltip": "When enabled, saves WebP images with lossless compression. Results in larger files but no quality loss."
|
||||
}),
|
||||
"quality": ("INT", {
|
||||
"default": 100,
|
||||
"min": 1,
|
||||
"max": 100,
|
||||
"tooltip": "Compression quality for JPEG and lossy WebP formats (1-100). Higher values mean better quality but larger files."
|
||||
}),
|
||||
"embed_workflow": ("BOOLEAN", {
|
||||
"default": False,
|
||||
"tooltip": "Embeds the complete workflow data into the image metadata. Only works with PNG and WebP formats."
|
||||
}),
|
||||
"add_counter_to_filename": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images."
|
||||
}),
|
||||
},
|
||||
"hidden": {
|
||||
"id": "UNIQUE_ID",
|
||||
"prompt": "PROMPT",
|
||||
"extra_pnginfo": "EXTRA_PNGINFO",
|
||||
},
|
||||
@@ -51,31 +69,51 @@ class SaveImage:
|
||||
FUNCTION = "process_image"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
async def get_lora_hash(self, lora_name):
|
||||
def get_lora_hash(self, lora_name):
|
||||
"""Get the lora hash from cache"""
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cache = await scanner.get_cached_data()
|
||||
scanner = ServiceRegistry.get_service_sync("lora_scanner")
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
return item.get('sha256')
|
||||
# Use the new direct filename lookup method
|
||||
hash_value = scanner.get_hash_by_filename(lora_name)
|
||||
if hash_value:
|
||||
return hash_value
|
||||
|
||||
return None
|
||||
|
||||
async def format_metadata(self, parsed_workflow, custom_prompt=None):
|
||||
def get_checkpoint_hash(self, checkpoint_path):
|
||||
"""Get the checkpoint hash from cache"""
|
||||
scanner = ServiceRegistry.get_service_sync("checkpoint_scanner")
|
||||
|
||||
if not checkpoint_path:
|
||||
return None
|
||||
|
||||
# Extract basename without extension
|
||||
checkpoint_name = os.path.basename(checkpoint_path)
|
||||
checkpoint_name = os.path.splitext(checkpoint_name)[0]
|
||||
|
||||
# Try direct filename lookup first
|
||||
hash_value = scanner.get_hash_by_filename(checkpoint_name)
|
||||
if hash_value:
|
||||
return hash_value
|
||||
|
||||
return None
|
||||
|
||||
def format_metadata(self, metadata_dict):
|
||||
"""Format metadata in the requested format similar to userComment example"""
|
||||
if not parsed_workflow:
|
||||
if not metadata_dict:
|
||||
return ""
|
||||
|
||||
# Extract the prompt and negative prompt
|
||||
prompt = parsed_workflow.get('prompt', '')
|
||||
negative_prompt = parsed_workflow.get('negative_prompt', '')
|
||||
# Helper function to only add parameter if value is not None
|
||||
def add_param_if_not_none(param_list, label, value):
|
||||
if value is not None:
|
||||
param_list.append(f"{label}: {value}")
|
||||
|
||||
# Override prompt with custom_prompt if provided
|
||||
if custom_prompt:
|
||||
prompt = custom_prompt
|
||||
# Extract the prompt and negative prompt
|
||||
prompt = metadata_dict.get('prompt', '')
|
||||
negative_prompt = metadata_dict.get('negative_prompt', '')
|
||||
|
||||
# Extract loras from the prompt if present
|
||||
loras_text = parsed_workflow.get('loras', '')
|
||||
loras_text = metadata_dict.get('loras', '')
|
||||
lora_hashes = {}
|
||||
|
||||
# If loras are found, add them on a new line after the prompt
|
||||
@@ -87,7 +125,7 @@ class SaveImage:
|
||||
|
||||
# Get hash for each lora
|
||||
for lora_name, strength in lora_matches:
|
||||
hash_value = await self.get_lora_hash(lora_name)
|
||||
hash_value = self.get_lora_hash(lora_name)
|
||||
if hash_value:
|
||||
lora_hashes[lora_name] = hash_value
|
||||
else:
|
||||
@@ -104,11 +142,15 @@ class SaveImage:
|
||||
params = []
|
||||
|
||||
# Add standard parameters in the correct order
|
||||
if 'steps' in parsed_workflow:
|
||||
params.append(f"Steps: {parsed_workflow.get('steps')}")
|
||||
if 'steps' in metadata_dict:
|
||||
add_param_if_not_none(params, "Steps", metadata_dict.get('steps'))
|
||||
|
||||
if 'sampler' in parsed_workflow:
|
||||
sampler = parsed_workflow.get('sampler')
|
||||
# Combine sampler and scheduler information
|
||||
sampler_name = None
|
||||
scheduler_name = None
|
||||
|
||||
if 'sampler' in metadata_dict:
|
||||
sampler = metadata_dict.get('sampler')
|
||||
# Convert ComfyUI sampler names to user-friendly names
|
||||
sampler_mapping = {
|
||||
'euler': 'Euler',
|
||||
@@ -128,10 +170,9 @@ class SaveImage:
|
||||
'ddim': 'DDIM'
|
||||
}
|
||||
sampler_name = sampler_mapping.get(sampler, sampler)
|
||||
params.append(f"Sampler: {sampler_name}")
|
||||
|
||||
if 'scheduler' in parsed_workflow:
|
||||
scheduler = parsed_workflow.get('scheduler')
|
||||
if 'scheduler' in metadata_dict:
|
||||
scheduler = metadata_dict.get('scheduler')
|
||||
scheduler_mapping = {
|
||||
'normal': 'Simple',
|
||||
'karras': 'Karras',
|
||||
@@ -140,35 +181,54 @@ class SaveImage:
|
||||
'sgm_quadratic': 'SGM Quadratic'
|
||||
}
|
||||
scheduler_name = scheduler_mapping.get(scheduler, scheduler)
|
||||
params.append(f"Schedule type: {scheduler_name}")
|
||||
|
||||
# CFG scale (cfg in parsed_workflow)
|
||||
if 'cfg_scale' in parsed_workflow:
|
||||
params.append(f"CFG scale: {parsed_workflow.get('cfg_scale')}")
|
||||
elif 'cfg' in parsed_workflow:
|
||||
params.append(f"CFG scale: {parsed_workflow.get('cfg')}")
|
||||
# Add combined sampler and scheduler information
|
||||
if sampler_name:
|
||||
if scheduler_name:
|
||||
params.append(f"Sampler: {sampler_name} {scheduler_name}")
|
||||
else:
|
||||
params.append(f"Sampler: {sampler_name}")
|
||||
|
||||
# CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg)
|
||||
if 'guidance' in metadata_dict:
|
||||
add_param_if_not_none(params, "CFG scale", metadata_dict.get('guidance'))
|
||||
elif 'cfg_scale' in metadata_dict:
|
||||
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg_scale'))
|
||||
elif 'cfg' in metadata_dict:
|
||||
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg'))
|
||||
|
||||
# Seed
|
||||
if 'seed' in parsed_workflow:
|
||||
params.append(f"Seed: {parsed_workflow.get('seed')}")
|
||||
if 'seed' in metadata_dict:
|
||||
add_param_if_not_none(params, "Seed", metadata_dict.get('seed'))
|
||||
|
||||
# Size
|
||||
if 'size' in parsed_workflow:
|
||||
params.append(f"Size: {parsed_workflow.get('size')}")
|
||||
if 'size' in metadata_dict:
|
||||
add_param_if_not_none(params, "Size", metadata_dict.get('size'))
|
||||
|
||||
# Model info
|
||||
if 'checkpoint' in parsed_workflow:
|
||||
# Extract basename without path
|
||||
checkpoint = os.path.basename(parsed_workflow.get('checkpoint', ''))
|
||||
# Remove extension if present
|
||||
checkpoint = os.path.splitext(checkpoint)[0]
|
||||
params.append(f"Model: {checkpoint}")
|
||||
if 'checkpoint' in metadata_dict:
|
||||
# Ensure checkpoint is a string before processing
|
||||
checkpoint = metadata_dict.get('checkpoint')
|
||||
if checkpoint is not None:
|
||||
# Get model hash
|
||||
model_hash = self.get_checkpoint_hash(checkpoint)
|
||||
|
||||
# Extract basename without path
|
||||
checkpoint_name = os.path.basename(checkpoint)
|
||||
# Remove extension if present
|
||||
checkpoint_name = os.path.splitext(checkpoint_name)[0]
|
||||
|
||||
# Add model hash if available
|
||||
if model_hash:
|
||||
params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}")
|
||||
else:
|
||||
params.append(f"Model: {checkpoint_name}")
|
||||
|
||||
# Add LoRA hashes if available
|
||||
if lora_hashes:
|
||||
lora_hash_parts = []
|
||||
for lora_name, hash_value in lora_hashes.items():
|
||||
lora_hash_parts.append(f"{lora_name}: {hash_value}")
|
||||
lora_hash_parts.append(f"{lora_name}: {hash_value[:10]}")
|
||||
|
||||
if lora_hash_parts:
|
||||
params.append(f"Lora hashes: \"{', '.join(lora_hash_parts)}\"")
|
||||
@@ -181,9 +241,9 @@ class SaveImage:
|
||||
|
||||
# credit to nkchocoai
|
||||
# Add format_filename method to handle pattern substitution
|
||||
def format_filename(self, filename, parsed_workflow):
|
||||
def format_filename(self, filename, metadata_dict):
|
||||
"""Format filename with metadata values"""
|
||||
if not parsed_workflow:
|
||||
if not metadata_dict:
|
||||
return filename
|
||||
|
||||
result = re.findall(self.pattern_format, filename)
|
||||
@@ -191,31 +251,37 @@ class SaveImage:
|
||||
parts = segment.replace("%", "").split(":")
|
||||
key = parts[0]
|
||||
|
||||
if key == "seed" and 'seed' in parsed_workflow:
|
||||
filename = filename.replace(segment, str(parsed_workflow.get('seed', '')))
|
||||
elif key == "width" and 'size' in parsed_workflow:
|
||||
size = parsed_workflow.get('size', 'x')
|
||||
if key == "seed" and 'seed' in metadata_dict:
|
||||
filename = filename.replace(segment, str(metadata_dict.get('seed', '')))
|
||||
elif key == "width" and 'size' in metadata_dict:
|
||||
size = metadata_dict.get('size', 'x')
|
||||
w = size.split('x')[0] if isinstance(size, str) else size[0]
|
||||
filename = filename.replace(segment, str(w))
|
||||
elif key == "height" and 'size' in parsed_workflow:
|
||||
size = parsed_workflow.get('size', 'x')
|
||||
elif key == "height" and 'size' in metadata_dict:
|
||||
size = metadata_dict.get('size', 'x')
|
||||
h = size.split('x')[1] if isinstance(size, str) else size[1]
|
||||
filename = filename.replace(segment, str(h))
|
||||
elif key == "pprompt" and 'prompt' in parsed_workflow:
|
||||
prompt = parsed_workflow.get('prompt', '').replace("\n", " ")
|
||||
elif key == "pprompt" and 'prompt' in metadata_dict:
|
||||
prompt = metadata_dict.get('prompt', '').replace("\n", " ")
|
||||
if len(parts) >= 2:
|
||||
length = int(parts[1])
|
||||
prompt = prompt[:length]
|
||||
filename = filename.replace(segment, prompt.strip())
|
||||
elif key == "nprompt" and 'negative_prompt' in parsed_workflow:
|
||||
prompt = parsed_workflow.get('negative_prompt', '').replace("\n", " ")
|
||||
elif key == "nprompt" and 'negative_prompt' in metadata_dict:
|
||||
prompt = metadata_dict.get('negative_prompt', '').replace("\n", " ")
|
||||
if len(parts) >= 2:
|
||||
length = int(parts[1])
|
||||
prompt = prompt[:length]
|
||||
filename = filename.replace(segment, prompt.strip())
|
||||
elif key == "model" and 'checkpoint' in parsed_workflow:
|
||||
model = parsed_workflow.get('checkpoint', '')
|
||||
model = os.path.splitext(os.path.basename(model))[0]
|
||||
elif key == "model":
|
||||
model_value = metadata_dict.get('checkpoint')
|
||||
if isinstance(model_value, (bytes, os.PathLike)):
|
||||
model_value = str(model_value)
|
||||
|
||||
if not isinstance(model_value, str) or not model_value:
|
||||
model = "model_unavailable"
|
||||
else:
|
||||
model = os.path.splitext(os.path.basename(model_value))[0]
|
||||
if len(parts) >= 2:
|
||||
length = int(parts[1])
|
||||
model = model[:length]
|
||||
@@ -224,12 +290,13 @@ class SaveImage:
|
||||
from datetime import datetime
|
||||
now = datetime.now()
|
||||
date_table = {
|
||||
"yyyy": str(now.year),
|
||||
"MM": str(now.month).zfill(2),
|
||||
"dd": str(now.day).zfill(2),
|
||||
"hh": str(now.hour).zfill(2),
|
||||
"mm": str(now.minute).zfill(2),
|
||||
"ss": str(now.second).zfill(2),
|
||||
"yyyy": f"{now.year:04d}",
|
||||
"yy": f"{now.year % 100:02d}",
|
||||
"MM": f"{now.month:02d}",
|
||||
"dd": f"{now.day:02d}",
|
||||
"hh": f"{now.hour:02d}",
|
||||
"mm": f"{now.minute:02d}",
|
||||
"ss": f"{now.second:02d}",
|
||||
}
|
||||
if len(parts) >= 2:
|
||||
date_format = parts[1]
|
||||
@@ -244,24 +311,19 @@ class SaveImage:
|
||||
|
||||
return filename
|
||||
|
||||
def save_images(self, images, filename_prefix, file_format, prompt=None, extra_pnginfo=None,
|
||||
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True,
|
||||
custom_prompt=None):
|
||||
def save_images(self, images, filename_prefix, file_format, id, prompt=None, extra_pnginfo=None,
|
||||
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
|
||||
"""Save images with metadata"""
|
||||
results = []
|
||||
|
||||
# Parse the workflow using the WorkflowParser
|
||||
parser = WorkflowParser()
|
||||
if prompt:
|
||||
parsed_workflow = parser.parse_workflow(prompt)
|
||||
else:
|
||||
parsed_workflow = {}
|
||||
|
||||
# Get metadata using the metadata collector
|
||||
raw_metadata = get_metadata()
|
||||
metadata_dict = MetadataProcessor.to_dict(raw_metadata, id)
|
||||
|
||||
# Get or create metadata asynchronously
|
||||
metadata = asyncio.run(self.format_metadata(parsed_workflow, custom_prompt))
|
||||
metadata = self.format_metadata(metadata_dict)
|
||||
|
||||
# Process filename_prefix with pattern substitution
|
||||
filename_prefix = self.format_filename(filename_prefix, parsed_workflow)
|
||||
filename_prefix = self.format_filename(filename_prefix, metadata_dict)
|
||||
|
||||
# Get initial save path info once for the batch
|
||||
full_output_folder, filename, counter, subfolder, processed_prefix = folder_paths.get_save_image_path(
|
||||
@@ -283,13 +345,14 @@ class SaveImage:
|
||||
if add_counter_to_filename:
|
||||
# Use counter + i to ensure unique filenames for all images in batch
|
||||
current_counter = counter + i
|
||||
base_filename += f"_{current_counter:05}"
|
||||
base_filename += f"_{current_counter:05}_"
|
||||
|
||||
# Set file extension and prepare saving parameters
|
||||
if file_format == "png":
|
||||
file = base_filename + ".png"
|
||||
file_extension = ".png"
|
||||
save_kwargs = {"optimize": True, "compress_level": self.compress_level}
|
||||
# Remove "optimize": True to match built-in node behavior
|
||||
save_kwargs = {"compress_level": self.compress_level}
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
elif file_format == "jpeg":
|
||||
file = base_filename + ".jpg"
|
||||
@@ -298,7 +361,8 @@ class SaveImage:
|
||||
elif file_format == "webp":
|
||||
file = base_filename + ".webp"
|
||||
file_extension = ".webp"
|
||||
save_kwargs = {"quality": quality, "lossless": lossless_webp}
|
||||
# Add optimization param to control performance
|
||||
save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0}
|
||||
|
||||
# Full save path
|
||||
file_path = os.path.join(full_output_folder, file)
|
||||
@@ -324,14 +388,23 @@ class SaveImage:
|
||||
print(f"Error adding EXIF data: {e}")
|
||||
img.save(file_path, format="JPEG", **save_kwargs)
|
||||
elif file_format == "webp":
|
||||
# For WebP, also use piexif for metadata
|
||||
if metadata:
|
||||
try:
|
||||
exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}}
|
||||
exif_bytes = piexif.dump(exif_dict)
|
||||
save_kwargs["exif"] = exif_bytes
|
||||
except Exception as e:
|
||||
print(f"Error adding EXIF data: {e}")
|
||||
try:
|
||||
# For WebP, use piexif for metadata
|
||||
exif_dict = {}
|
||||
|
||||
if metadata:
|
||||
exif_dict['Exif'] = {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}
|
||||
|
||||
# Add workflow if needed
|
||||
if embed_workflow and extra_pnginfo is not None:
|
||||
workflow_json = json.dumps(extra_pnginfo["workflow"])
|
||||
exif_dict['0th'] = {piexif.ImageIFD.ImageDescription: "Workflow:" + workflow_json}
|
||||
|
||||
exif_bytes = piexif.dump(exif_dict)
|
||||
save_kwargs["exif"] = exif_bytes
|
||||
except Exception as e:
|
||||
print(f"Error adding EXIF data: {e}")
|
||||
|
||||
img.save(file_path, format="WEBP", **save_kwargs)
|
||||
|
||||
results.append({
|
||||
@@ -345,31 +418,34 @@ class SaveImage:
|
||||
|
||||
return results
|
||||
|
||||
def process_image(self, images, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None,
|
||||
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True,
|
||||
custom_prompt=""):
|
||||
def process_image(self, images, id, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None,
|
||||
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
|
||||
"""Process and save image with metadata"""
|
||||
# Make sure the output directory exists
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
|
||||
# Ensure images is always a list of images
|
||||
if len(images.shape) == 3: # Single image (height, width, channels)
|
||||
images = [images]
|
||||
else: # Multiple images (batch, height, width, channels)
|
||||
images = [img for img in images]
|
||||
# If images is already a list or array of images, do nothing; otherwise, convert to list
|
||||
if isinstance(images, (list, np.ndarray)):
|
||||
pass
|
||||
else:
|
||||
# Ensure images is always a list of images
|
||||
if len(images.shape) == 3: # Single image (height, width, channels)
|
||||
images = [images]
|
||||
else: # Multiple images (batch, height, width, channels)
|
||||
images = [img for img in images]
|
||||
|
||||
# Save all images
|
||||
results = self.save_images(
|
||||
images,
|
||||
filename_prefix,
|
||||
file_format,
|
||||
id,
|
||||
prompt,
|
||||
extra_pnginfo,
|
||||
lossless_webp,
|
||||
quality,
|
||||
embed_workflow,
|
||||
add_counter_to_filename,
|
||||
custom_prompt if custom_prompt.strip() else None
|
||||
add_counter_to_filename
|
||||
)
|
||||
|
||||
return (images,)
|
||||
return (images,)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import re
|
||||
from server import PromptServer # type: ignore
|
||||
from .utils import FlexibleOptionalInputType, any_type
|
||||
import logging
|
||||
|
||||
@@ -16,11 +15,22 @@ class TriggerWordToggle:
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"group_mode": ("BOOLEAN", {"default": True}),
|
||||
"group_mode": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "When enabled, treats each group of trigger words as a single toggleable unit."
|
||||
}),
|
||||
"default_active": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Sets the default initial state (active or inactive) when trigger words are added."
|
||||
}),
|
||||
"allow_strength_adjustment": ("BOOLEAN", {
|
||||
"default": False,
|
||||
"tooltip": "Enable mouse wheel adjustment of each trigger word's strength."
|
||||
}),
|
||||
},
|
||||
"optional": FlexibleOptionalInputType(any_type),
|
||||
"hidden": {
|
||||
"id": "UNIQUE_ID", # 会被 ComfyUI 自动替换为唯一ID
|
||||
"id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -41,17 +51,18 @@ class TriggerWordToggle:
|
||||
else:
|
||||
return data
|
||||
|
||||
def process_trigger_words(self, id, group_mode, **kwargs):
|
||||
def process_trigger_words(
|
||||
self,
|
||||
id,
|
||||
group_mode,
|
||||
default_active,
|
||||
allow_strength_adjustment=False,
|
||||
**kwargs,
|
||||
):
|
||||
# Handle both old and new formats for trigger_words
|
||||
trigger_words_data = self._get_toggle_data(kwargs, 'trigger_words')
|
||||
trigger_words_data = self._get_toggle_data(kwargs, 'orinalMessage')
|
||||
trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else ""
|
||||
|
||||
# Send trigger words to frontend
|
||||
PromptServer.instance.send_sync("trigger_word_update", {
|
||||
"id": id,
|
||||
"message": trigger_words
|
||||
})
|
||||
|
||||
filtered_triggers = trigger_words
|
||||
|
||||
# Get toggle data with support for both formats
|
||||
@@ -63,27 +74,89 @@ class TriggerWordToggle:
|
||||
trigger_data = json.loads(trigger_data)
|
||||
|
||||
# Create dictionaries to track active state of words or groups
|
||||
active_state = {item['text']: item.get('active', False) for item in trigger_data}
|
||||
# Also track strength values for each trigger word
|
||||
active_state = {}
|
||||
strength_map = {}
|
||||
|
||||
if group_mode:
|
||||
# Split by two or more consecutive commas to get groups
|
||||
groups = re.split(r',{2,}', trigger_words)
|
||||
# Remove leading/trailing whitespace from each group
|
||||
groups = [group.strip() for group in groups]
|
||||
|
||||
# Filter groups: keep those not in toggle_trigger_words or those that are active
|
||||
filtered_groups = [group for group in groups if group not in active_state or active_state[group]]
|
||||
|
||||
if filtered_groups:
|
||||
filtered_triggers = ', '.join(filtered_groups)
|
||||
for item in trigger_data:
|
||||
text = item['text']
|
||||
active = item.get('active', False)
|
||||
# Extract strength if it's in the format "(word:strength)"
|
||||
strength_match = re.match(r'\((.+):([\d.]+)\)', text)
|
||||
if strength_match:
|
||||
original_word = strength_match.group(1).strip()
|
||||
strength = float(strength_match.group(2))
|
||||
active_state[original_word] = active
|
||||
if allow_strength_adjustment:
|
||||
strength_map[original_word] = strength
|
||||
else:
|
||||
filtered_triggers = ""
|
||||
active_state[text.strip()] = active
|
||||
|
||||
if group_mode:
|
||||
if isinstance(trigger_data, list):
|
||||
filtered_groups = []
|
||||
for item in trigger_data:
|
||||
text = (item.get('text') or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
if item.get('active', False):
|
||||
filtered_groups.append(text)
|
||||
|
||||
if filtered_groups:
|
||||
filtered_triggers = ', '.join(filtered_groups)
|
||||
else:
|
||||
filtered_triggers = ""
|
||||
else:
|
||||
# Split by two or more consecutive commas to get groups
|
||||
groups = re.split(r',{2,}', trigger_words)
|
||||
# Remove leading/trailing whitespace from each group
|
||||
groups = [group.strip() for group in groups]
|
||||
|
||||
# Process groups: keep those not in toggle_trigger_words or those that are active
|
||||
filtered_groups = []
|
||||
for group in groups:
|
||||
# Check if this group contains any words that are in the active_state
|
||||
group_words = [word.strip() for word in group.split(',')]
|
||||
active_group_words = []
|
||||
|
||||
for word in group_words:
|
||||
word_comparison = re.sub(r'\((.+):([\d.]+)\)', r'\1', word).strip()
|
||||
|
||||
if word_comparison not in active_state or active_state[word_comparison]:
|
||||
active_group_words.append(
|
||||
self._format_word_output(
|
||||
word_comparison,
|
||||
strength_map,
|
||||
allow_strength_adjustment,
|
||||
)
|
||||
)
|
||||
|
||||
if active_group_words:
|
||||
filtered_groups.append(', '.join(active_group_words))
|
||||
|
||||
if filtered_groups:
|
||||
filtered_triggers = ', '.join(filtered_groups)
|
||||
else:
|
||||
filtered_triggers = ""
|
||||
else:
|
||||
# Original behavior for individual words mode
|
||||
# Normal mode: split by commas and treat each word as a separate tag
|
||||
original_words = [word.strip() for word in trigger_words.split(',')]
|
||||
# Filter out empty strings
|
||||
original_words = [word for word in original_words if word]
|
||||
filtered_words = [word for word in original_words if word not in active_state or active_state[word]]
|
||||
|
||||
filtered_words = []
|
||||
for word in original_words:
|
||||
# Remove any existing strength formatting for comparison
|
||||
word_comparison = re.sub(r'\((.+):([\d.]+)\)', r'\1', word).strip()
|
||||
|
||||
if word_comparison not in active_state or active_state[word_comparison]:
|
||||
filtered_words.append(
|
||||
self._format_word_output(
|
||||
word_comparison,
|
||||
strength_map,
|
||||
allow_strength_adjustment,
|
||||
)
|
||||
)
|
||||
|
||||
if filtered_words:
|
||||
filtered_triggers = ', '.join(filtered_words)
|
||||
@@ -93,4 +166,9 @@ class TriggerWordToggle:
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing trigger words: {e}")
|
||||
|
||||
return (filtered_triggers,)
|
||||
return (filtered_triggers,)
|
||||
|
||||
def _format_word_output(self, base_word, strength_map, allow_strength_adjustment):
|
||||
if allow_strength_adjustment and base_word in strength_map:
|
||||
return f"({base_word}:{strength_map[base_word]:.2f})"
|
||||
return base_word
|
||||
|
||||
@@ -30,4 +30,105 @@ class FlexibleOptionalInputType(dict):
|
||||
return True
|
||||
|
||||
|
||||
any_type = AnyType("*")
|
||||
any_type = AnyType("*")
|
||||
|
||||
# Common methods extracted from lora_loader.py and lora_stacker.py
|
||||
import os
|
||||
import logging
|
||||
import copy
|
||||
import folder_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def extract_lora_name(lora_path):
|
||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||
# Get the basename without extension
|
||||
basename = os.path.basename(lora_path)
|
||||
return os.path.splitext(basename)[0]
|
||||
|
||||
def get_loras_list(kwargs):
|
||||
"""Helper to extract loras list from either old or new kwargs format"""
|
||||
if 'loras' not in kwargs:
|
||||
return []
|
||||
|
||||
loras_data = kwargs['loras']
|
||||
# Handle new format: {'loras': {'__value__': [...]}}
|
||||
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
||||
return loras_data['__value__']
|
||||
# Handle old format: {'loras': [...]}
|
||||
elif isinstance(loras_data, list):
|
||||
return loras_data
|
||||
# Unexpected format
|
||||
else:
|
||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||
return []
|
||||
|
||||
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
|
||||
"""Simplified version of load_state_dict_in_safetensors that just loads from a local path"""
|
||||
import safetensors.torch
|
||||
|
||||
state_dict = {}
|
||||
with safetensors.torch.safe_open(path, framework="pt", device=device) as f:
|
||||
for k in f.keys():
|
||||
if filter_prefix and not k.startswith(filter_prefix):
|
||||
continue
|
||||
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
|
||||
return state_dict
|
||||
|
||||
def to_diffusers(input_lora):
|
||||
"""Simplified version of to_diffusers for Flux LoRA conversion"""
|
||||
import torch
|
||||
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
|
||||
from diffusers.loaders import FluxLoraLoaderMixin
|
||||
|
||||
if isinstance(input_lora, str):
|
||||
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
|
||||
else:
|
||||
tensors = {k: v for k, v in input_lora.items()}
|
||||
|
||||
# Convert FP8 tensors to BF16
|
||||
for k, v in tensors.items():
|
||||
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
|
||||
tensors[k] = v.to(torch.bfloat16)
|
||||
|
||||
new_tensors = FluxLoraLoaderMixin.lora_state_dict(tensors)
|
||||
new_tensors = convert_unet_state_dict_to_peft(new_tensors)
|
||||
|
||||
return new_tensors
|
||||
|
||||
def nunchaku_load_lora(model, lora_name, lora_strength):
|
||||
"""Load a Flux LoRA for Nunchaku model"""
|
||||
model_wrapper = model.model.diffusion_model
|
||||
transformer = model_wrapper.model
|
||||
|
||||
# Save the transformer temporarily
|
||||
model_wrapper.model = None
|
||||
ret_model = copy.deepcopy(model) # copy everything except the model
|
||||
ret_model_wrapper = ret_model.model.diffusion_model
|
||||
|
||||
# Restore the model and set it for the copy
|
||||
model_wrapper.model = transformer
|
||||
ret_model_wrapper.model = transformer
|
||||
|
||||
# Get full path to the LoRA file. Allow both direct paths and registered LoRA names.
|
||||
lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name)
|
||||
if not lora_path or not os.path.isfile(lora_path):
|
||||
logger.warning("Skipping LoRA '%s' because it could not be found", lora_name)
|
||||
return model
|
||||
|
||||
ret_model_wrapper.loras.append((lora_path, lora_strength))
|
||||
|
||||
# Convert the LoRA to diffusers format
|
||||
sd = to_diffusers(lora_path)
|
||||
|
||||
# Handle embedding adjustment if needed
|
||||
if "transformer.x_embedder.lora_A.weight" in sd:
|
||||
new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1]
|
||||
assert new_in_channels % 4 == 0
|
||||
new_in_channels = new_in_channels // 4
|
||||
|
||||
old_in_channels = ret_model.model.model_config.unet_config["in_channels"]
|
||||
if old_in_channels < new_in_channels:
|
||||
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
|
||||
|
||||
return ret_model
|
||||
97
py/nodes/wanvideo_lora_select.py
Normal file
97
py/nodes/wanvideo_lora_select.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import folder_paths # type: ignore
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class WanVideoLoraSelect:
|
||||
NAME = "WanVideo Lora Select (LoraManager)"
|
||||
CATEGORY = "Lora Manager/stackers"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load LORA models with less VRAM usage, slower loading. This affects ALL LoRAs, not just the current ones. No effect if merge_loras is False"}),
|
||||
"merge_loras": ("BOOLEAN", {"default": True, "tooltip": "Merge LoRAs into the model, otherwise they are loaded on the fly. Always disabled for GGUF and scaled fp8 models. This affects ALL LoRAs, not just the current one"}),
|
||||
"text": ("STRING", {
|
||||
"multiline": True,
|
||||
"pysssss.autocomplete": False,
|
||||
"dynamicPrompts": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
"placeholder": "LoRA syntax input: <lora:name:strength>"
|
||||
}),
|
||||
},
|
||||
"optional": FlexibleOptionalInputType(any_type),
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("WANVIDLORA", "STRING", "STRING")
|
||||
RETURN_NAMES = ("lora", "trigger_words", "active_loras")
|
||||
FUNCTION = "process_loras"
|
||||
|
||||
def process_loras(self, text, low_mem_load=False, merge_loras=True, **kwargs):
|
||||
loras_list = []
|
||||
all_trigger_words = []
|
||||
active_loras = []
|
||||
|
||||
# Process existing prev_lora if available
|
||||
prev_lora = kwargs.get('prev_lora', None)
|
||||
if prev_lora is not None:
|
||||
loras_list.extend(prev_lora)
|
||||
|
||||
if not merge_loras:
|
||||
low_mem_load = False # Unmerged LoRAs don't need low_mem_load
|
||||
|
||||
# Get blocks if available
|
||||
blocks = kwargs.get('blocks', {})
|
||||
selected_blocks = blocks.get("selected_blocks", {})
|
||||
layer_filter = blocks.get("layer_filter", "")
|
||||
|
||||
# Process loras from kwargs with support for both old and new formats
|
||||
loras_from_widget = get_loras_list(kwargs)
|
||||
for lora in loras_from_widget:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
|
||||
lora_name = lora['name']
|
||||
model_strength = float(lora['strength'])
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Create lora item for WanVideo format
|
||||
lora_item = {
|
||||
"path": folder_paths.get_full_path("loras", lora_path),
|
||||
"strength": model_strength,
|
||||
"name": lora_path.split(".")[0],
|
||||
"blocks": selected_blocks,
|
||||
"layer_filter": layer_filter,
|
||||
"low_mem_load": low_mem_load,
|
||||
"merge_loras": merge_loras,
|
||||
}
|
||||
|
||||
# Add to list and collect active loras
|
||||
loras_list.append(lora_item)
|
||||
active_loras.append((lora_name, model_strength, clip_strength))
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Format trigger_words for output
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Format active_loras for output
|
||||
formatted_loras = []
|
||||
for name, model_strength, clip_strength in active_loras:
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
# Different model and clip strengths
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")
|
||||
|
||||
active_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (loras_list, trigger_words_text, active_loras_text)
|
||||
125
py/nodes/wanvideo_lora_select_from_text.py
Normal file
125
py/nodes/wanvideo_lora_select_from_text.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import folder_paths # type: ignore
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import any_type
|
||||
import logging
|
||||
|
||||
# 初始化日志记录器
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 定义新节点的类
|
||||
class WanVideoLoraSelectFromText:
|
||||
# 节点在UI中显示的名称
|
||||
NAME = "WanVideo Lora Select From Text (LoraManager)"
|
||||
# 节点所属的分类
|
||||
CATEGORY = "Lora Manager/stackers"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load LORA models with less VRAM usage, slower loading. This affects ALL LoRAs, not just the current ones. No effect if merge_loras is False"}),
|
||||
"merge_lora": ("BOOLEAN", {"default": True, "tooltip": "Merge LoRAs into the model, otherwise they are loaded on the fly. Always disabled for GGUF and scaled fp8 models. This affects ALL LoRAs, not just the current one"}),
|
||||
"lora_syntax": ("STRING", {
|
||||
"multiline": True,
|
||||
"forceInput": True,
|
||||
"tooltip": "Connect a TEXT output for LoRA syntax: <lora:name:strength>"
|
||||
}),
|
||||
},
|
||||
|
||||
"optional": {
|
||||
"prev_lora": ("WANVIDLORA",),
|
||||
"blocks": ("BLOCKS",)
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("WANVIDLORA", "STRING", "STRING")
|
||||
RETURN_NAMES = ("lora", "trigger_words", "active_loras")
|
||||
|
||||
FUNCTION = "process_loras_from_syntax"
|
||||
|
||||
def process_loras_from_syntax(self, lora_syntax, low_mem_load=False, merge_lora=True, **kwargs):
|
||||
text_to_process = lora_syntax
|
||||
|
||||
blocks = kwargs.get('blocks', {})
|
||||
selected_blocks = blocks.get("selected_blocks", {})
|
||||
layer_filter = blocks.get("layer_filter", "")
|
||||
|
||||
loras_list = []
|
||||
all_trigger_words = []
|
||||
active_loras = []
|
||||
|
||||
prev_lora = kwargs.get('prev_lora', None)
|
||||
if prev_lora is not None:
|
||||
loras_list.extend(prev_lora)
|
||||
|
||||
if not merge_lora:
|
||||
low_mem_load = False
|
||||
|
||||
parts = text_to_process.split('<lora:')
|
||||
for part in parts[1:]:
|
||||
end_index = part.find('>')
|
||||
if end_index == -1:
|
||||
continue
|
||||
|
||||
content = part[:end_index]
|
||||
lora_parts = content.split(':')
|
||||
|
||||
lora_name_raw = ""
|
||||
model_strength = 1.0
|
||||
clip_strength = 1.0
|
||||
|
||||
if len(lora_parts) == 2:
|
||||
lora_name_raw = lora_parts[0].strip()
|
||||
try:
|
||||
model_strength = float(lora_parts[1])
|
||||
clip_strength = model_strength
|
||||
except (ValueError, IndexError):
|
||||
logger.warning(f"Invalid strength for LoRA '{lora_name_raw}'. Skipping.")
|
||||
continue
|
||||
elif len(lora_parts) >= 3:
|
||||
lora_name_raw = lora_parts[0].strip()
|
||||
try:
|
||||
model_strength = float(lora_parts[1])
|
||||
clip_strength = float(lora_parts[2])
|
||||
except (ValueError, IndexError):
|
||||
logger.warning(f"Invalid strengths for LoRA '{lora_name_raw}'. Skipping.")
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
lora_path, trigger_words = get_lora_info(lora_name_raw)
|
||||
|
||||
lora_item = {
|
||||
"path": folder_paths.get_full_path("loras", lora_path),
|
||||
"strength": model_strength,
|
||||
"name": lora_path.split(".")[0],
|
||||
"blocks": selected_blocks,
|
||||
"layer_filter": layer_filter,
|
||||
"low_mem_load": low_mem_load,
|
||||
"merge_loras": merge_lora,
|
||||
}
|
||||
|
||||
loras_list.append(lora_item)
|
||||
active_loras.append((lora_name_raw, model_strength, clip_strength))
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
formatted_loras = []
|
||||
for name, model_strength, clip_strength in active_loras:
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
|
||||
else:
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")
|
||||
|
||||
active_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (loras_list, trigger_words_text, active_loras_text)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"WanVideoLoraSelectFromText": WanVideoLoraSelectFromText
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"WanVideoLoraSelectFromText": "WanVideo Lora Select From Text (LoraManager)"
|
||||
}
|
||||
24
py/recipes/__init__.py
Normal file
24
py/recipes/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Recipe metadata parser package for ComfyUI-Lora-Manager."""
|
||||
|
||||
from .base import RecipeMetadataParser
|
||||
from .factory import RecipeParserFactory
|
||||
from .constants import GEN_PARAM_KEYS, VALID_LORA_TYPES
|
||||
from .parsers import (
|
||||
RecipeFormatParser,
|
||||
ComfyMetadataParser,
|
||||
MetaFormatParser,
|
||||
AutomaticMetadataParser,
|
||||
CivitaiApiMetadataParser
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'RecipeMetadataParser',
|
||||
'RecipeParserFactory',
|
||||
'GEN_PARAM_KEYS',
|
||||
'VALID_LORA_TYPES',
|
||||
'RecipeFormatParser',
|
||||
'ComfyMetadataParser',
|
||||
'MetaFormatParser',
|
||||
'AutomaticMetadataParser',
|
||||
'CivitaiApiMetadataParser'
|
||||
]
|
||||
214
py/recipes/base.py
Normal file
214
py/recipes/base.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""Base classes for recipe parsers."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
from ..config import config
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
from ..utils.civitai_utils import rewrite_preview_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RecipeMetadataParser(ABC):
|
||||
"""Interface for parsing recipe metadata from image user comments"""
|
||||
|
||||
METADATA_MARKER = None
|
||||
|
||||
@abstractmethod
|
||||
def is_metadata_matching(self, user_comment: str) -> bool:
|
||||
"""Check if the user comment matches the metadata format"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse metadata from user comment and return structured recipe data
|
||||
|
||||
Args:
|
||||
user_comment: The EXIF UserComment string from the image
|
||||
recipe_scanner: Optional recipe scanner instance for local LoRA lookup
|
||||
civitai_client: Optional Civitai client for fetching model information
|
||||
|
||||
Returns:
|
||||
Dict containing parsed recipe data with standardized format
|
||||
"""
|
||||
pass
|
||||
|
||||
async def populate_lora_from_civitai(self, lora_entry: Dict[str, Any], civitai_info_tuple: Tuple[Dict[str, Any], Optional[str]],
|
||||
recipe_scanner=None, base_model_counts=None, hash_value=None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Populate a lora entry with information from Civitai API response
|
||||
|
||||
Args:
|
||||
lora_entry: The lora entry to populate
|
||||
civitai_info_tuple: The response tuple from Civitai API (data, error_msg)
|
||||
recipe_scanner: Optional recipe scanner for local file lookup
|
||||
base_model_counts: Optional dict to track base model counts
|
||||
hash_value: Optional hash value to use if not available in civitai_info
|
||||
|
||||
Returns:
|
||||
The populated lora_entry dict if type is valid, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Unpack the tuple to get the actual data
|
||||
civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
||||
|
||||
if not civitai_info or error_msg == "Model not found":
|
||||
# Model not found or deleted
|
||||
lora_entry['isDeleted'] = True
|
||||
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'
|
||||
return lora_entry
|
||||
|
||||
# Get model type and validate
|
||||
model_type = civitai_info.get('model', {}).get('type', '').lower()
|
||||
lora_entry['type'] = model_type
|
||||
if model_type not in VALID_LORA_TYPES:
|
||||
logger.debug(f"Skipping non-LoRA model type: {model_type}")
|
||||
return None
|
||||
|
||||
# Check if this is an early access lora
|
||||
if civitai_info.get('earlyAccessEndsAt'):
|
||||
# Convert earlyAccessEndsAt to a human-readable date
|
||||
early_access_date = civitai_info.get('earlyAccessEndsAt', '')
|
||||
lora_entry['isEarlyAccess'] = True
|
||||
lora_entry['earlyAccessEndsAt'] = early_access_date
|
||||
|
||||
# Update model name if available
|
||||
if 'model' in civitai_info and 'name' in civitai_info['model']:
|
||||
lora_entry['name'] = civitai_info['model']['name']
|
||||
|
||||
lora_entry['id'] = civitai_info.get('id')
|
||||
lora_entry['modelId'] = civitai_info.get('modelId')
|
||||
|
||||
# Update version if available
|
||||
if 'name' in civitai_info:
|
||||
lora_entry['version'] = civitai_info.get('name', '')
|
||||
|
||||
# Get thumbnail URL from first image
|
||||
if 'images' in civitai_info and civitai_info['images']:
|
||||
image_url = civitai_info['images'][0].get('url')
|
||||
if image_url:
|
||||
rewritten_image_url, _ = rewrite_preview_url(image_url, media_type='image')
|
||||
lora_entry['thumbnailUrl'] = rewritten_image_url or image_url
|
||||
|
||||
# Get base model
|
||||
current_base_model = civitai_info.get('baseModel', '')
|
||||
lora_entry['baseModel'] = current_base_model
|
||||
|
||||
# Update base model counts if tracking them
|
||||
if base_model_counts is not None and current_base_model:
|
||||
base_model_counts[current_base_model] = base_model_counts.get(current_base_model, 0) + 1
|
||||
|
||||
# Get download URL
|
||||
lora_entry['downloadUrl'] = civitai_info.get('downloadUrl', '')
|
||||
|
||||
# Process file information if available
|
||||
if 'files' in civitai_info:
|
||||
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||
model_file = next((file for file in civitai_info.get('files', [])
|
||||
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||
|
||||
if model_file:
|
||||
# Get size
|
||||
lora_entry['size'] = model_file.get('sizeKB', 0) * 1024
|
||||
|
||||
# Get SHA256 hash
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256', hash_value)
|
||||
if sha256:
|
||||
lora_entry['hash'] = sha256.lower()
|
||||
|
||||
# Check if exists locally
|
||||
if recipe_scanner and lora_entry['hash']:
|
||||
lora_scanner = recipe_scanner._lora_scanner
|
||||
exists_locally = lora_scanner.has_hash(lora_entry['hash'])
|
||||
if exists_locally:
|
||||
try:
|
||||
local_path = lora_scanner.get_path_by_hash(lora_entry['hash'])
|
||||
lora_entry['existsLocally'] = True
|
||||
lora_entry['localPath'] = local_path
|
||||
lora_entry['file_name'] = os.path.splitext(os.path.basename(local_path))[0]
|
||||
|
||||
# Get thumbnail from local preview if available
|
||||
lora_cache = await lora_scanner.get_cached_data()
|
||||
lora_item = next((item for item in lora_cache.raw_data
|
||||
if item['sha256'].lower() == lora_entry['hash'].lower()), None)
|
||||
if lora_item and 'preview_url' in lora_item:
|
||||
lora_entry['thumbnailUrl'] = config.get_preview_static_url(lora_item['preview_url'])
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting local lora path: {e}")
|
||||
else:
|
||||
# For missing LoRAs, get file_name from model_file.name
|
||||
file_name = model_file.get('name', '')
|
||||
lora_entry['file_name'] = os.path.splitext(file_name)[0] if file_name else ''
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error populating lora from Civitai info: {e}")
|
||||
|
||||
return lora_entry
|
||||
|
||||
async def populate_checkpoint_from_civitai(self, checkpoint: Dict[str, Any], civitai_info: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Populate checkpoint information from Civitai API response
|
||||
|
||||
Args:
|
||||
checkpoint: The checkpoint entry to populate
|
||||
civitai_info: The response from Civitai API or a (data, error_msg) tuple
|
||||
|
||||
Returns:
|
||||
The populated checkpoint dict
|
||||
"""
|
||||
try:
|
||||
civitai_data, error_msg = (
|
||||
(civitai_info, None)
|
||||
if not isinstance(civitai_info, tuple)
|
||||
else civitai_info
|
||||
)
|
||||
|
||||
if not civitai_data or error_msg == "Model not found":
|
||||
checkpoint['isDeleted'] = True
|
||||
return checkpoint
|
||||
|
||||
if 'model' in civitai_data and 'name' in civitai_data['model']:
|
||||
checkpoint['name'] = civitai_data['model']['name']
|
||||
|
||||
if 'name' in civitai_data:
|
||||
checkpoint['version'] = civitai_data.get('name', '')
|
||||
|
||||
if 'images' in civitai_data and civitai_data['images']:
|
||||
image_url = civitai_data['images'][0].get('url')
|
||||
if image_url:
|
||||
rewritten_image_url, _ = rewrite_preview_url(image_url, media_type='image')
|
||||
checkpoint['thumbnailUrl'] = rewritten_image_url or image_url
|
||||
|
||||
checkpoint['baseModel'] = civitai_data.get('baseModel', '')
|
||||
checkpoint['downloadUrl'] = civitai_data.get('downloadUrl', '')
|
||||
|
||||
checkpoint['modelId'] = civitai_data.get('modelId', checkpoint.get('modelId', 0))
|
||||
|
||||
if 'files' in civitai_data:
|
||||
model_file = next(
|
||||
(
|
||||
file
|
||||
for file in civitai_data.get('files', [])
|
||||
if file.get('type') == 'Model'
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if model_file:
|
||||
checkpoint['size'] = model_file.get('sizeKB', 0) * 1024
|
||||
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
checkpoint['hash'] = sha256.lower()
|
||||
|
||||
file_name = model_file.get('name', '')
|
||||
if file_name:
|
||||
checkpoint['file_name'] = os.path.splitext(file_name)[0]
|
||||
except Exception as e:
|
||||
logger.error(f"Error populating checkpoint from Civitai info: {e}")
|
||||
|
||||
return checkpoint
|
||||
16
py/recipes/constants.py
Normal file
16
py/recipes/constants.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Constants used across recipe parsers."""
|
||||
|
||||
# Import VALID_LORA_TYPES from utils.constants
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
|
||||
# Constants for generation parameters
|
||||
GEN_PARAM_KEYS = [
|
||||
'prompt',
|
||||
'negative_prompt',
|
||||
'steps',
|
||||
'sampler',
|
||||
'cfg_scale',
|
||||
'seed',
|
||||
'size',
|
||||
'clip_skip',
|
||||
]
|
||||
64
py/recipes/factory.py
Normal file
64
py/recipes/factory.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Factory for creating recipe metadata parsers."""
|
||||
|
||||
import logging
|
||||
from .parsers import (
|
||||
RecipeFormatParser,
|
||||
ComfyMetadataParser,
|
||||
MetaFormatParser,
|
||||
AutomaticMetadataParser,
|
||||
CivitaiApiMetadataParser
|
||||
)
|
||||
from .base import RecipeMetadataParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RecipeParserFactory:
|
||||
"""Factory for creating recipe metadata parsers"""
|
||||
|
||||
@staticmethod
|
||||
def create_parser(metadata) -> RecipeMetadataParser:
|
||||
"""
|
||||
Create appropriate parser based on the metadata content
|
||||
|
||||
Args:
|
||||
metadata: The metadata from the image (dict or str)
|
||||
|
||||
Returns:
|
||||
Appropriate RecipeMetadataParser implementation
|
||||
"""
|
||||
# First, try CivitaiApiMetadataParser for dict input
|
||||
if isinstance(metadata, dict):
|
||||
try:
|
||||
if CivitaiApiMetadataParser().is_metadata_matching(metadata):
|
||||
return CivitaiApiMetadataParser()
|
||||
except Exception as e:
|
||||
logger.debug(f"CivitaiApiMetadataParser check failed: {e}")
|
||||
pass
|
||||
|
||||
# Convert dict to string for other parsers that expect string input
|
||||
try:
|
||||
import json
|
||||
metadata_str = json.dumps(metadata)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to convert dict to JSON string: {e}")
|
||||
return None
|
||||
else:
|
||||
metadata_str = metadata
|
||||
|
||||
# Try ComfyMetadataParser which requires valid JSON
|
||||
try:
|
||||
if ComfyMetadataParser().is_metadata_matching(metadata_str):
|
||||
return ComfyMetadataParser()
|
||||
except Exception:
|
||||
# If JSON parsing fails, move on to other parsers
|
||||
pass
|
||||
|
||||
# Check other parsers that expect string input
|
||||
if RecipeFormatParser().is_metadata_matching(metadata_str):
|
||||
return RecipeFormatParser()
|
||||
elif AutomaticMetadataParser().is_metadata_matching(metadata_str):
|
||||
return AutomaticMetadataParser()
|
||||
elif MetaFormatParser().is_metadata_matching(metadata_str):
|
||||
return MetaFormatParser()
|
||||
else:
|
||||
return None
|
||||
15
py/recipes/parsers/__init__.py
Normal file
15
py/recipes/parsers/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Recipe parsers package."""
|
||||
|
||||
from .recipe_format import RecipeFormatParser
|
||||
from .comfy import ComfyMetadataParser
|
||||
from .meta_format import MetaFormatParser
|
||||
from .automatic import AutomaticMetadataParser
|
||||
from .civitai_image import CivitaiApiMetadataParser
|
||||
|
||||
__all__ = [
|
||||
'RecipeFormatParser',
|
||||
'ComfyMetadataParser',
|
||||
'MetaFormatParser',
|
||||
'AutomaticMetadataParser',
|
||||
'CivitaiApiMetadataParser',
|
||||
]
|
||||
441
py/recipes/parsers/automatic.py
Normal file
441
py/recipes/parsers/automatic.py
Normal file
@@ -0,0 +1,441 @@
|
||||
"""Parser for Automatic1111 metadata format."""
|
||||
|
||||
import re
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from ..base import RecipeMetadataParser
|
||||
from ..constants import GEN_PARAM_KEYS
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AutomaticMetadataParser(RecipeMetadataParser):
|
||||
"""Parser for Automatic1111 metadata format"""
|
||||
|
||||
METADATA_MARKER = r"Steps: \d+"
|
||||
|
||||
# Regular expressions for extracting specific metadata
|
||||
HASHES_REGEX = r', Hashes:\s*({[^}]+})'
|
||||
LORA_HASHES_REGEX = r', Lora hashes:\s*"([^"]+)"'
|
||||
CIVITAI_RESOURCES_REGEX = r', Civitai resources:\s*(\[\{.*?\}\])'
|
||||
CIVITAI_METADATA_REGEX = r', Civitai metadata:\s*(\{.*?\})'
|
||||
EXTRANETS_REGEX = r'<(lora|hypernet):([^:]+):(-?[0-9.]+)>'
|
||||
MODEL_HASH_PATTERN = r'Model hash: ([a-zA-Z0-9]+)'
|
||||
MODEL_NAME_PATTERN = r'Model: ([^,]+)'
|
||||
VAE_HASH_PATTERN = r'VAE hash: ([a-zA-Z0-9]+)'
|
||||
|
||||
def is_metadata_matching(self, user_comment: str) -> bool:
|
||||
"""Check if the user comment matches the Automatic1111 format"""
|
||||
return re.search(self.METADATA_MARKER, user_comment) is not None
|
||||
|
||||
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""Parse metadata from Automatic1111 format"""
|
||||
try:
|
||||
# Get metadata provider instead of using civitai_client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
# Split on Negative prompt if it exists
|
||||
if "Negative prompt:" in user_comment:
|
||||
parts = user_comment.split('Negative prompt:', 1)
|
||||
prompt = parts[0].strip()
|
||||
negative_and_params = parts[1] if len(parts) > 1 else ""
|
||||
else:
|
||||
# No negative prompt section
|
||||
param_start = re.search(self.METADATA_MARKER, user_comment)
|
||||
if param_start:
|
||||
prompt = user_comment[:param_start.start()].strip()
|
||||
negative_and_params = user_comment[param_start.start():]
|
||||
else:
|
||||
prompt = user_comment.strip()
|
||||
negative_and_params = ""
|
||||
|
||||
# Initialize metadata
|
||||
metadata = {
|
||||
"prompt": prompt,
|
||||
"loras": []
|
||||
}
|
||||
|
||||
# Extract negative prompt and parameters
|
||||
if negative_and_params:
|
||||
# If we split on "Negative prompt:", check for params section
|
||||
if "Negative prompt:" in user_comment:
|
||||
param_start = re.search(r'Steps: ', negative_and_params)
|
||||
if param_start:
|
||||
neg_prompt = negative_and_params[:param_start.start()].strip()
|
||||
metadata["negative_prompt"] = neg_prompt
|
||||
params_section = negative_and_params[param_start.start():]
|
||||
else:
|
||||
metadata["negative_prompt"] = negative_and_params.strip()
|
||||
params_section = ""
|
||||
else:
|
||||
# No negative prompt, entire section is params
|
||||
params_section = negative_and_params
|
||||
|
||||
# Extract generation parameters
|
||||
if params_section:
|
||||
# Extract Civitai resources
|
||||
civitai_resources_match = re.search(self.CIVITAI_RESOURCES_REGEX, params_section)
|
||||
if civitai_resources_match:
|
||||
try:
|
||||
civitai_resources = json.loads(civitai_resources_match.group(1))
|
||||
metadata["civitai_resources"] = civitai_resources
|
||||
params_section = params_section.replace(civitai_resources_match.group(0), '')
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Error parsing Civitai resources JSON")
|
||||
|
||||
# Extract Hashes
|
||||
hashes_match = re.search(self.HASHES_REGEX, params_section)
|
||||
if hashes_match:
|
||||
try:
|
||||
hashes = json.loads(hashes_match.group(1))
|
||||
# Process hash keys
|
||||
processed_hashes = {}
|
||||
for key, value in hashes.items():
|
||||
# Convert Model: or LORA: prefix to lowercase if present
|
||||
if ':' in key:
|
||||
prefix, name = key.split(':', 1)
|
||||
prefix = prefix.lower()
|
||||
else:
|
||||
prefix = ''
|
||||
name = key
|
||||
|
||||
# Clean up the name part
|
||||
if '/' in name:
|
||||
name = name.split('/')[-1] # Get last part after /
|
||||
if '.safetensors' in name:
|
||||
name = name.split('.safetensors')[0] # Remove .safetensors
|
||||
|
||||
# Reconstruct the key
|
||||
new_key = f"{prefix}:{name}" if prefix else name
|
||||
processed_hashes[new_key] = value
|
||||
|
||||
metadata["hashes"] = processed_hashes
|
||||
# Remove hashes from params section to not interfere with other parsing
|
||||
params_section = params_section.replace(hashes_match.group(0), '')
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Error parsing hashes JSON")
|
||||
|
||||
# Pick up model hash from parsed hashes if available
|
||||
if "hashes" in metadata and not metadata.get("model_hash"):
|
||||
model_hash_from_hashes = metadata["hashes"].get("model")
|
||||
if model_hash_from_hashes:
|
||||
metadata["model_hash"] = model_hash_from_hashes
|
||||
|
||||
# Extract Lora hashes in alternative format
|
||||
lora_hashes_match = re.search(self.LORA_HASHES_REGEX, params_section)
|
||||
if not hashes_match and lora_hashes_match:
|
||||
try:
|
||||
lora_hashes_str = lora_hashes_match.group(1)
|
||||
lora_hash_entries = lora_hashes_str.split(', ')
|
||||
|
||||
# Initialize hashes dict if it doesn't exist
|
||||
if "hashes" not in metadata:
|
||||
metadata["hashes"] = {}
|
||||
|
||||
# Parse each lora hash entry (format: "name: hash")
|
||||
for entry in lora_hash_entries:
|
||||
if ': ' in entry:
|
||||
lora_name, lora_hash = entry.split(': ', 1)
|
||||
# Add as lora type in the same format as regular hashes
|
||||
metadata["hashes"][f"lora:{lora_name}"] = lora_hash.strip()
|
||||
|
||||
# Remove lora hashes from params section
|
||||
params_section = params_section.replace(lora_hashes_match.group(0), '')
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing Lora hashes: {e}")
|
||||
|
||||
# Extract checkpoint model hash/name when provided outside Civitai resources
|
||||
model_hash_match = re.search(self.MODEL_HASH_PATTERN, params_section)
|
||||
if model_hash_match:
|
||||
metadata["model_hash"] = model_hash_match.group(1).strip()
|
||||
params_section = params_section.replace(model_hash_match.group(0), '')
|
||||
|
||||
model_name_match = re.search(self.MODEL_NAME_PATTERN, params_section)
|
||||
if model_name_match:
|
||||
metadata["model_name"] = model_name_match.group(1).strip()
|
||||
params_section = params_section.replace(model_name_match.group(0), '')
|
||||
|
||||
# Extract basic parameters
|
||||
param_pattern = r'([A-Za-z\s]+): ([^,]+)'
|
||||
params = re.findall(param_pattern, params_section)
|
||||
gen_params = {}
|
||||
|
||||
for key, value in params:
|
||||
clean_key = key.strip().lower().replace(' ', '_')
|
||||
|
||||
# Skip if not in recognized gen param keys
|
||||
if clean_key not in GEN_PARAM_KEYS:
|
||||
continue
|
||||
|
||||
# Convert numeric values
|
||||
if clean_key in ['steps', 'seed']:
|
||||
try:
|
||||
gen_params[clean_key] = int(value.strip())
|
||||
except ValueError:
|
||||
gen_params[clean_key] = value.strip()
|
||||
elif clean_key in ['cfg_scale']:
|
||||
try:
|
||||
gen_params[clean_key] = float(value.strip())
|
||||
except ValueError:
|
||||
gen_params[clean_key] = value.strip()
|
||||
else:
|
||||
gen_params[clean_key] = value.strip()
|
||||
|
||||
# Extract size if available and add to gen_params if a recognized key
|
||||
size_match = re.search(r'Size: (\d+)x(\d+)', params_section)
|
||||
if size_match and 'size' in GEN_PARAM_KEYS:
|
||||
width, height = size_match.groups()
|
||||
gen_params['size'] = f"{width}x{height}"
|
||||
|
||||
# Add prompt and negative_prompt to gen_params if they're in GEN_PARAM_KEYS
|
||||
if 'prompt' in GEN_PARAM_KEYS and 'prompt' in metadata:
|
||||
gen_params['prompt'] = metadata['prompt']
|
||||
if 'negative_prompt' in GEN_PARAM_KEYS and 'negative_prompt' in metadata:
|
||||
gen_params['negative_prompt'] = metadata['negative_prompt']
|
||||
|
||||
metadata["gen_params"] = gen_params
|
||||
|
||||
# Extract LoRA and checkpoint information
|
||||
loras = []
|
||||
base_model_counts = {}
|
||||
checkpoint = None
|
||||
|
||||
# First use Civitai resources if available (more reliable source)
|
||||
if metadata.get("civitai_resources"):
|
||||
for resource in metadata.get("civitai_resources", []):
|
||||
# --- Added: Parse 'air' field if present ---
|
||||
air = resource.get("air")
|
||||
if air:
|
||||
# Format: urn:air:sdxl:lora:civitai:1221007@1375651
|
||||
# Or: urn:air:sdxl:checkpoint:civitai:623891@2019115
|
||||
air_pattern = r"urn:air:[^:]+:(?P<type>[^:]+):civitai:(?P<modelId>\d+)@(?P<modelVersionId>\d+)"
|
||||
air_match = re.match(air_pattern, air)
|
||||
if air_match:
|
||||
air_type = air_match.group("type")
|
||||
air_modelId = int(air_match.group("modelId"))
|
||||
air_modelVersionId = int(air_match.group("modelVersionId"))
|
||||
# checkpoint/lycoris/lora/hypernet
|
||||
resource["type"] = air_type
|
||||
resource["modelId"] = air_modelId
|
||||
resource["modelVersionId"] = air_modelVersionId
|
||||
# --- End added ---
|
||||
|
||||
if resource.get("type") == "checkpoint" and resource.get("modelVersionId"):
|
||||
version_id = resource.get("modelVersionId")
|
||||
version_id_str = str(version_id)
|
||||
checkpoint_entry = {
|
||||
'id': version_id,
|
||||
'modelId': resource.get("modelId", 0),
|
||||
'name': resource.get("modelName", "Unknown Checkpoint"),
|
||||
'version': resource.get("modelVersionName", resource.get("versionName", "")),
|
||||
'type': resource.get("type", "checkpoint"),
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': resource.get("modelName", ""),
|
||||
'hash': resource.get("hash", "") or "",
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
if metadata_provider:
|
||||
try:
|
||||
civitai_info = await metadata_provider.get_model_version_info(version_id_str)
|
||||
checkpoint_entry = await self.populate_checkpoint_from_civitai(
|
||||
checkpoint_entry,
|
||||
civitai_info
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error fetching Civitai info for checkpoint version %s: %s",
|
||||
version_id,
|
||||
e,
|
||||
)
|
||||
|
||||
# Prefer the first checkpoint found
|
||||
if checkpoint_entry.get("baseModel"):
|
||||
base_model_value = checkpoint_entry["baseModel"]
|
||||
base_model_counts[base_model_value] = base_model_counts.get(base_model_value, 0) + 1
|
||||
|
||||
if checkpoint is None:
|
||||
checkpoint = checkpoint_entry
|
||||
|
||||
continue
|
||||
|
||||
if resource.get("type") in ["lora", "lycoris", "hypernet"] and resource.get("modelVersionId"):
|
||||
# Initialize lora entry
|
||||
lora_entry = {
|
||||
'id': resource.get("modelVersionId", 0),
|
||||
'modelId': resource.get("modelId", 0),
|
||||
'name': resource.get("modelName", "Unknown LoRA"),
|
||||
'version': resource.get("modelVersionName", resource.get("versionName", "")),
|
||||
'type': resource.get("type", "lora"),
|
||||
'weight': round(float(resource.get("weight", 1.0)), 2),
|
||||
'existsLocally': False,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Get additional info from Civitai
|
||||
if metadata_provider:
|
||||
try:
|
||||
civitai_info = await metadata_provider.get_model_version_info(resource.get("modelVersionId"))
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA {lora_entry['name']}: {e}")
|
||||
|
||||
loras.append(lora_entry)
|
||||
|
||||
# Fallback checkpoint parsing from generic "Model" and "Model hash" fields
|
||||
if checkpoint is None:
|
||||
model_hash = metadata.get("model_hash")
|
||||
if not model_hash and metadata.get("hashes"):
|
||||
model_hash = metadata["hashes"].get("model")
|
||||
|
||||
model_name = metadata.get("model_name")
|
||||
file_name = ""
|
||||
if model_name:
|
||||
cleaned_name = re.split(r"[\\\\/]", model_name)[-1]
|
||||
file_name = os.path.splitext(cleaned_name)[0]
|
||||
|
||||
if model_hash or model_name:
|
||||
checkpoint_entry = {
|
||||
'id': 0,
|
||||
'modelId': 0,
|
||||
'name': model_name or "Unknown Checkpoint",
|
||||
'version': '',
|
||||
'type': 'checkpoint',
|
||||
'hash': model_hash or "",
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': file_name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
if metadata_provider and model_hash:
|
||||
try:
|
||||
civitai_info = await metadata_provider.get_model_by_hash(model_hash)
|
||||
checkpoint_entry = await self.populate_checkpoint_from_civitai(
|
||||
checkpoint_entry,
|
||||
civitai_info
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for checkpoint hash {model_hash}: {e}")
|
||||
|
||||
if checkpoint_entry.get("baseModel"):
|
||||
base_model_value = checkpoint_entry["baseModel"]
|
||||
base_model_counts[base_model_value] = base_model_counts.get(base_model_value, 0) + 1
|
||||
|
||||
checkpoint = checkpoint_entry
|
||||
|
||||
# If no LoRAs from Civitai resources or to supplement, extract from metadata["hashes"]
|
||||
if not loras or len(loras) == 0:
|
||||
# Extract lora weights from extranet tags in prompt (for later use)
|
||||
lora_weights = {}
|
||||
lora_matches = re.findall(self.EXTRANETS_REGEX, prompt)
|
||||
for lora_type, lora_name, lora_weight in lora_matches:
|
||||
key = f"{lora_type}:{lora_name}"
|
||||
lora_weights[key] = round(float(lora_weight), 2)
|
||||
|
||||
# Use hashes from metadata as the primary source
|
||||
if metadata.get("hashes"):
|
||||
for hash_key, lora_hash in metadata.get("hashes", {}).items():
|
||||
# Only process lora or hypernet types
|
||||
if not hash_key.startswith(("lora:", "hypernet:")):
|
||||
continue
|
||||
|
||||
lora_type, lora_name = hash_key.split(':', 1)
|
||||
|
||||
# Get weight from extranet tags if available, else default to 1.0
|
||||
weight = lora_weights.get(hash_key, 1.0)
|
||||
|
||||
# Initialize lora entry
|
||||
lora_entry = {
|
||||
'name': lora_name,
|
||||
'type': lora_type, # 'lora' or 'hypernet'
|
||||
'weight': weight,
|
||||
'hash': lora_hash,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': lora_name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Try to get info from Civitai
|
||||
if metadata_provider:
|
||||
try:
|
||||
if lora_hash:
|
||||
# If we have hash, use it for lookup
|
||||
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
|
||||
else:
|
||||
civitai_info = None
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts,
|
||||
lora_hash
|
||||
)
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA {lora_name}: {e}")
|
||||
|
||||
loras.append(lora_entry)
|
||||
|
||||
# Try to get base model from resources or make educated guess
|
||||
base_model = None
|
||||
if checkpoint and checkpoint.get("baseModel"):
|
||||
base_model = checkpoint.get("baseModel")
|
||||
elif base_model_counts:
|
||||
# Use the most common base model from the loras
|
||||
base_model = max(base_model_counts.items(), key=lambda x: x[1])[0]
|
||||
|
||||
# Prepare final result structure
|
||||
# Make sure gen_params only contains recognized keys
|
||||
filtered_gen_params = {}
|
||||
for key in GEN_PARAM_KEYS:
|
||||
if key in metadata.get("gen_params", {}):
|
||||
filtered_gen_params[key] = metadata["gen_params"][key]
|
||||
|
||||
result = {
|
||||
'base_model': base_model,
|
||||
'loras': loras,
|
||||
'gen_params': filtered_gen_params,
|
||||
'from_automatic_metadata': True
|
||||
}
|
||||
|
||||
if checkpoint:
|
||||
result['checkpoint'] = checkpoint
|
||||
result['model'] = checkpoint
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing Automatic1111 metadata: {e}", exc_info=True)
|
||||
return {"error": str(e), "loras": []}
|
||||
500
py/recipes/parsers/civitai_image.py
Normal file
500
py/recipes/parsers/civitai_image.py
Normal file
@@ -0,0 +1,500 @@
|
||||
"""Parser for Civitai image metadata format."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Union
|
||||
from ..base import RecipeMetadataParser
|
||||
from ..constants import GEN_PARAM_KEYS
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
"""Parser for Civitai image metadata format"""
|
||||
|
||||
def is_metadata_matching(self, metadata) -> bool:
|
||||
"""Check if the metadata matches the Civitai image metadata format
|
||||
|
||||
Args:
|
||||
metadata: The metadata from the image (dict)
|
||||
|
||||
Returns:
|
||||
bool: True if this parser can handle the metadata
|
||||
"""
|
||||
if not metadata or not isinstance(metadata, dict):
|
||||
return False
|
||||
|
||||
def has_markers(payload: Dict[str, Any]) -> bool:
|
||||
# Check for common CivitAI image metadata fields
|
||||
civitai_image_fields = (
|
||||
"resources",
|
||||
"civitaiResources",
|
||||
"additionalResources",
|
||||
"hashes",
|
||||
"prompt",
|
||||
"negativePrompt",
|
||||
"steps",
|
||||
"sampler",
|
||||
"cfgScale",
|
||||
"seed",
|
||||
"width",
|
||||
"height",
|
||||
"Model",
|
||||
"Model hash"
|
||||
)
|
||||
return any(key in payload for key in civitai_image_fields)
|
||||
|
||||
# Check the main metadata object
|
||||
if has_markers(metadata):
|
||||
return True
|
||||
|
||||
# Check for LoRA hash patterns
|
||||
hashes = metadata.get("hashes")
|
||||
if isinstance(hashes, dict) and any(str(key).lower().startswith("lora:") for key in hashes):
|
||||
return True
|
||||
|
||||
# Check nested meta object (common in CivitAI image responses)
|
||||
nested_meta = metadata.get("meta")
|
||||
if isinstance(nested_meta, dict):
|
||||
if has_markers(nested_meta):
|
||||
return True
|
||||
|
||||
# Also check for LoRA hash patterns in nested meta
|
||||
hashes = nested_meta.get("hashes")
|
||||
if isinstance(hashes, dict) and any(str(key).lower().startswith("lora:") for key in hashes):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def parse_metadata(self, metadata, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""Parse metadata from Civitai image format
|
||||
|
||||
Args:
|
||||
metadata: The metadata from the image (dict)
|
||||
recipe_scanner: Optional recipe scanner service
|
||||
civitai_client: Optional Civitai API client (deprecated, use metadata_provider instead)
|
||||
|
||||
Returns:
|
||||
Dict containing parsed recipe data
|
||||
"""
|
||||
try:
|
||||
# Get metadata provider instead of using civitai_client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
# Civitai image responses may wrap the actual metadata inside a "meta" key
|
||||
if (
|
||||
isinstance(metadata, dict)
|
||||
and "meta" in metadata
|
||||
and isinstance(metadata["meta"], dict)
|
||||
):
|
||||
inner_meta = metadata["meta"]
|
||||
if any(
|
||||
key in inner_meta
|
||||
for key in (
|
||||
"resources",
|
||||
"civitaiResources",
|
||||
"additionalResources",
|
||||
"hashes",
|
||||
"prompt",
|
||||
"negativePrompt",
|
||||
)
|
||||
):
|
||||
metadata = inner_meta
|
||||
|
||||
# Initialize result structure
|
||||
result = {
|
||||
'base_model': None,
|
||||
'loras': [],
|
||||
'model': None,
|
||||
'gen_params': {},
|
||||
'from_civitai_image': True
|
||||
}
|
||||
|
||||
# Track already added LoRAs to prevent duplicates
|
||||
added_loras = {} # key: model_version_id or hash, value: index in result["loras"]
|
||||
|
||||
# Extract hash information from hashes field for LoRA matching
|
||||
lora_hashes = {}
|
||||
if "hashes" in metadata and isinstance(metadata["hashes"], dict):
|
||||
for key, hash_value in metadata["hashes"].items():
|
||||
key_str = str(key)
|
||||
if key_str.lower().startswith("lora:"):
|
||||
lora_name = key_str.split(":", 1)[1]
|
||||
lora_hashes[lora_name] = hash_value
|
||||
|
||||
# Extract prompt and negative prompt
|
||||
if "prompt" in metadata:
|
||||
result["gen_params"]["prompt"] = metadata["prompt"]
|
||||
|
||||
if "negativePrompt" in metadata:
|
||||
result["gen_params"]["negative_prompt"] = metadata["negativePrompt"]
|
||||
|
||||
# Extract other generation parameters
|
||||
param_mapping = {
|
||||
"steps": "steps",
|
||||
"sampler": "sampler",
|
||||
"cfgScale": "cfg_scale",
|
||||
"seed": "seed",
|
||||
"Size": "size",
|
||||
"clipSkip": "clip_skip",
|
||||
}
|
||||
|
||||
for civitai_key, our_key in param_mapping.items():
|
||||
if civitai_key in metadata and our_key in GEN_PARAM_KEYS:
|
||||
result["gen_params"][our_key] = metadata[civitai_key]
|
||||
|
||||
# Extract base model information - directly if available
|
||||
if "baseModel" in metadata:
|
||||
result["base_model"] = metadata["baseModel"]
|
||||
elif "Model hash" in metadata and metadata_provider:
|
||||
model_hash = metadata["Model hash"]
|
||||
model_info, error = await metadata_provider.get_model_by_hash(model_hash)
|
||||
if model_info:
|
||||
result["base_model"] = model_info.get("baseModel", "")
|
||||
elif "Model" in metadata and isinstance(metadata.get("resources"), list):
|
||||
# Try to find base model in resources
|
||||
for resource in metadata.get("resources", []):
|
||||
if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"):
|
||||
# This is likely the checkpoint model
|
||||
if metadata_provider and resource.get("hash"):
|
||||
model_info, error = await metadata_provider.get_model_by_hash(resource.get("hash"))
|
||||
if model_info:
|
||||
result["base_model"] = model_info.get("baseModel", "")
|
||||
|
||||
base_model_counts = {}
|
||||
|
||||
# Process standard resources array
|
||||
if "resources" in metadata and isinstance(metadata["resources"], list):
|
||||
for resource in metadata["resources"]:
|
||||
# Modified to process resources without a type field as potential LoRAs
|
||||
if resource.get("type", "lora") == "lora":
|
||||
lora_hash = resource.get("hash", "")
|
||||
|
||||
# Try to get hash from the hashes field if not present in resource
|
||||
if not lora_hash and resource.get("name"):
|
||||
lora_hash = lora_hashes.get(resource["name"], "")
|
||||
|
||||
# Skip LoRAs without proper identification (hash or modelVersionId)
|
||||
if not lora_hash and not resource.get("modelVersionId"):
|
||||
logger.debug(f"Skipping LoRA resource '{resource.get('name', 'Unknown')}' - no hash or modelVersionId")
|
||||
continue
|
||||
|
||||
# Skip if we've already added this LoRA by hash
|
||||
if lora_hash and lora_hash in added_loras:
|
||||
continue
|
||||
|
||||
lora_entry = {
|
||||
'name': resource.get("name", "Unknown LoRA"),
|
||||
'type': "lora",
|
||||
'weight': float(resource.get("weight", 1.0)),
|
||||
'hash': lora_hash,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': resource.get("name", "Unknown"),
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Try to get info from Civitai if hash is available
|
||||
if lora_entry['hash'] and metadata_provider:
|
||||
try:
|
||||
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts,
|
||||
lora_hash
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
# If we have a version ID from Civitai, track it for deduplication
|
||||
if 'id' in lora_entry and lora_entry['id']:
|
||||
added_loras[str(lora_entry['id'])] = len(result["loras"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
|
||||
|
||||
# Track by hash if we have it
|
||||
if lora_hash:
|
||||
added_loras[lora_hash] = len(result["loras"])
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# Process civitaiResources array
|
||||
if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list):
|
||||
for resource in metadata["civitaiResources"]:
|
||||
# Get resource type and identifier
|
||||
resource_type = str(resource.get("type") or "").lower()
|
||||
version_id = str(resource.get("modelVersionId", ""))
|
||||
|
||||
if resource_type == "checkpoint":
|
||||
checkpoint_entry = {
|
||||
'id': resource.get("modelVersionId", 0),
|
||||
'modelId': resource.get("modelId", 0),
|
||||
'name': resource.get("modelName", "Unknown Checkpoint"),
|
||||
'version': resource.get("modelVersionName", ""),
|
||||
'type': resource.get("type", "checkpoint"),
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': resource.get("modelName", ""),
|
||||
'hash': resource.get("hash", "") or "",
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
if version_id and metadata_provider:
|
||||
try:
|
||||
civitai_info = await metadata_provider.get_model_version_info(version_id)
|
||||
|
||||
checkpoint_entry = await self.populate_checkpoint_from_civitai(
|
||||
checkpoint_entry,
|
||||
civitai_info
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for checkpoint version {version_id}: {e}")
|
||||
|
||||
if result["model"] is None:
|
||||
result["model"] = checkpoint_entry
|
||||
|
||||
continue
|
||||
|
||||
# Skip if we've already added this LoRA
|
||||
if version_id and version_id in added_loras:
|
||||
continue
|
||||
|
||||
# Initialize lora entry
|
||||
lora_entry = {
|
||||
'id': resource.get("modelVersionId", 0),
|
||||
'modelId': resource.get("modelId", 0),
|
||||
'name': resource.get("modelName", "Unknown LoRA"),
|
||||
'version': resource.get("modelVersionName", ""),
|
||||
'type': resource.get("type", "lora"),
|
||||
'weight': round(float(resource.get("weight", 1.0)), 2),
|
||||
'existsLocally': False,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Try to get info from Civitai if modelVersionId is available
|
||||
if version_id and metadata_provider:
|
||||
try:
|
||||
# Use get_model_version_info instead of get_model_version
|
||||
civitai_info = await metadata_provider.get_model_version_info(version_id)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model version {version_id}: {e}")
|
||||
|
||||
# Track this LoRA in our deduplication dict
|
||||
if version_id:
|
||||
added_loras[version_id] = len(result["loras"])
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# Process additionalResources array
|
||||
if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list):
|
||||
for resource in metadata["additionalResources"]:
|
||||
# Skip resources that aren't LoRAs or LyCORIS
|
||||
if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource:
|
||||
continue
|
||||
|
||||
lora_type = resource.get("type", "lora")
|
||||
name = resource.get("name", "")
|
||||
|
||||
# Extract ID from URN format if available
|
||||
version_id = None
|
||||
if name and "civitai:" in name:
|
||||
parts = name.split("@")
|
||||
if len(parts) > 1:
|
||||
version_id = parts[1]
|
||||
|
||||
# Skip if we've already added this LoRA
|
||||
if version_id in added_loras:
|
||||
continue
|
||||
|
||||
lora_entry = {
|
||||
'name': name,
|
||||
'type': lora_type,
|
||||
'weight': float(resource.get("strength", 1.0)),
|
||||
'hash': "",
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# If we have a version ID and metadata provider, try to get more info
|
||||
if version_id and metadata_provider:
|
||||
try:
|
||||
# Use get_model_version_info with the version ID
|
||||
civitai_info = await metadata_provider.get_model_version_info(version_id)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
# Track this LoRA for deduplication
|
||||
if version_id:
|
||||
added_loras[version_id] = len(result["loras"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}")
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# If we found LoRA hashes in the metadata but haven't already
|
||||
# populated entries for them, fall back to creating LoRAs from
|
||||
# the hashes section. Some Civitai image responses only include
|
||||
# LoRA information here without explicit resources entries.
|
||||
for lora_name, lora_hash in lora_hashes.items():
|
||||
if not lora_hash:
|
||||
continue
|
||||
|
||||
# Skip LoRAs we've already added via resources or other fields
|
||||
if lora_hash in added_loras:
|
||||
continue
|
||||
|
||||
lora_entry = {
|
||||
'name': lora_name,
|
||||
'type': "lora",
|
||||
'weight': 1.0,
|
||||
'hash': lora_hash,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': lora_name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
if metadata_provider:
|
||||
try:
|
||||
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts,
|
||||
lora_hash
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
if 'id' in lora_entry and lora_entry['id']:
|
||||
added_loras[str(lora_entry['id'])] = len(result["loras"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA hash {lora_hash}: {e}")
|
||||
|
||||
added_loras[lora_hash] = len(result["loras"])
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# Check for LoRA info in the format "Lora_0 Model hash", "Lora_0 Model name", etc.
|
||||
lora_index = 0
|
||||
while f"Lora_{lora_index} Model hash" in metadata and f"Lora_{lora_index} Model name" in metadata:
|
||||
lora_hash = metadata[f"Lora_{lora_index} Model hash"]
|
||||
lora_name = metadata[f"Lora_{lora_index} Model name"]
|
||||
lora_strength_model = float(metadata.get(f"Lora_{lora_index} Strength model", 1.0))
|
||||
|
||||
# Skip if we've already added this LoRA by hash
|
||||
if lora_hash and lora_hash in added_loras:
|
||||
lora_index += 1
|
||||
continue
|
||||
|
||||
lora_entry = {
|
||||
'name': lora_name,
|
||||
'type': "lora",
|
||||
'weight': lora_strength_model,
|
||||
'hash': lora_hash,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': lora_name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Try to get info from Civitai if hash is available
|
||||
if lora_entry['hash'] and metadata_provider:
|
||||
try:
|
||||
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts,
|
||||
lora_hash
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
lora_index += 1
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
# If we have a version ID from Civitai, track it for deduplication
|
||||
if 'id' in lora_entry and lora_entry['id']:
|
||||
added_loras[str(lora_entry['id'])] = len(result["loras"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
|
||||
|
||||
# Track by hash if we have it
|
||||
if lora_hash:
|
||||
added_loras[lora_hash] = len(result["loras"])
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
lora_index += 1
|
||||
|
||||
# If base model wasn't found earlier, use the most common one from LoRAs
|
||||
if not result["base_model"] and base_model_counts:
|
||||
result["base_model"] = max(base_model_counts.items(), key=lambda x: x[1])[0]
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing Civitai image metadata: {e}", exc_info=True)
|
||||
return {"error": str(e), "loras": []}
|
||||
220
py/recipes/parsers/comfy.py
Normal file
220
py/recipes/parsers/comfy.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Parser for ComfyUI metadata format."""
|
||||
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from ..base import RecipeMetadataParser
|
||||
from ..constants import GEN_PARAM_KEYS
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ComfyMetadataParser(RecipeMetadataParser):
|
||||
"""Parser for Civitai ComfyUI metadata JSON format"""
|
||||
|
||||
METADATA_MARKER = r"class_type"
|
||||
|
||||
def is_metadata_matching(self, user_comment: str) -> bool:
|
||||
"""Check if the user comment matches the ComfyUI metadata format"""
|
||||
try:
|
||||
data = json.loads(user_comment)
|
||||
# Check if it contains class_type nodes typical of ComfyUI workflow
|
||||
return isinstance(data, dict) and any(isinstance(v, dict) and 'class_type' in v for v in data.values())
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return False
|
||||
|
||||
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""Parse metadata from Civitai ComfyUI metadata format"""
|
||||
try:
|
||||
# Get metadata provider instead of using civitai_client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
data = json.loads(user_comment)
|
||||
loras = []
|
||||
|
||||
# Find all LoraLoader nodes
|
||||
lora_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and v.get('class_type') == 'LoraLoader'}
|
||||
|
||||
if not lora_nodes:
|
||||
return {"error": "No LoRA information found in this ComfyUI workflow", "loras": []}
|
||||
|
||||
# Process each LoraLoader node
|
||||
for node_id, node in lora_nodes.items():
|
||||
if 'inputs' not in node or 'lora_name' not in node['inputs']:
|
||||
continue
|
||||
|
||||
lora_name = node['inputs'].get('lora_name', '')
|
||||
|
||||
# Parse the URN to extract model ID and version ID
|
||||
# Format: "urn:air:sdxl:lora:civitai:1107767@1253442"
|
||||
lora_id_match = re.search(r'civitai:(\d+)@(\d+)', lora_name)
|
||||
if not lora_id_match:
|
||||
continue
|
||||
|
||||
model_id = lora_id_match.group(1)
|
||||
model_version_id = lora_id_match.group(2)
|
||||
|
||||
# Get strength from node inputs
|
||||
weight = node['inputs'].get('strength_model', 1.0)
|
||||
|
||||
# Initialize lora entry with default values
|
||||
lora_entry = {
|
||||
'id': model_version_id,
|
||||
'modelId': model_id,
|
||||
'name': f"Lora {model_id}", # Default name
|
||||
'version': '',
|
||||
'type': 'lora',
|
||||
'weight': weight,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': '',
|
||||
'hash': '',
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Get additional info from Civitai if metadata provider is available
|
||||
if metadata_provider:
|
||||
try:
|
||||
civitai_info_tuple = await metadata_provider.get_model_version_info(model_version_id)
|
||||
# Populate lora entry with Civitai info
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info_tuple,
|
||||
recipe_scanner
|
||||
)
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA: {e}")
|
||||
|
||||
loras.append(lora_entry)
|
||||
|
||||
# Find checkpoint info
|
||||
checkpoint_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and v.get('class_type') == 'CheckpointLoaderSimple'}
|
||||
checkpoint = None
|
||||
checkpoint_id = None
|
||||
checkpoint_version_id = None
|
||||
|
||||
if checkpoint_nodes:
|
||||
# Get the first checkpoint node
|
||||
checkpoint_node = next(iter(checkpoint_nodes.values()))
|
||||
if 'inputs' in checkpoint_node and 'ckpt_name' in checkpoint_node['inputs']:
|
||||
checkpoint_name = checkpoint_node['inputs']['ckpt_name']
|
||||
# Parse checkpoint URN
|
||||
checkpoint_match = re.search(r'civitai:(\d+)@(\d+)', checkpoint_name)
|
||||
if checkpoint_match:
|
||||
checkpoint_id = checkpoint_match.group(1)
|
||||
checkpoint_version_id = checkpoint_match.group(2)
|
||||
checkpoint = {
|
||||
'id': checkpoint_version_id,
|
||||
'modelId': checkpoint_id,
|
||||
'name': f"Checkpoint {checkpoint_id}",
|
||||
'version': '',
|
||||
'type': 'checkpoint'
|
||||
}
|
||||
|
||||
# Get additional checkpoint info from Civitai
|
||||
if metadata_provider:
|
||||
try:
|
||||
civitai_info_tuple = await metadata_provider.get_model_version_info(checkpoint_version_id)
|
||||
civitai_info, _ = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
||||
# Populate checkpoint with Civitai info
|
||||
checkpoint = await self.populate_checkpoint_from_civitai(checkpoint, civitai_info)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for checkpoint: {e}")
|
||||
|
||||
# Extract generation parameters
|
||||
gen_params = {}
|
||||
|
||||
# First try to get from extraMetadata
|
||||
if 'extraMetadata' in data:
|
||||
try:
|
||||
# extraMetadata is a JSON string that needs to be parsed
|
||||
extra_metadata = json.loads(data['extraMetadata'])
|
||||
|
||||
# Map fields from extraMetadata to our standard format
|
||||
mapping = {
|
||||
'prompt': 'prompt',
|
||||
'negativePrompt': 'negative_prompt',
|
||||
'steps': 'steps',
|
||||
'sampler': 'sampler',
|
||||
'cfgScale': 'cfg_scale',
|
||||
'seed': 'seed'
|
||||
}
|
||||
|
||||
for src_key, dest_key in mapping.items():
|
||||
if src_key in extra_metadata:
|
||||
gen_params[dest_key] = extra_metadata[src_key]
|
||||
|
||||
# If size info is available, format as "width x height"
|
||||
if 'width' in extra_metadata and 'height' in extra_metadata:
|
||||
gen_params['size'] = f"{extra_metadata['width']}x{extra_metadata['height']}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing extraMetadata: {e}")
|
||||
|
||||
# If extraMetadata doesn't have all the info, try to get from nodes
|
||||
if not gen_params or len(gen_params) < 3: # At least we want prompt, negative_prompt, and steps
|
||||
# Find positive prompt node
|
||||
positive_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and
|
||||
v.get('class_type', '').endswith('CLIPTextEncode') and
|
||||
v.get('_meta', {}).get('title') == 'Positive'}
|
||||
|
||||
if positive_nodes:
|
||||
positive_node = next(iter(positive_nodes.values()))
|
||||
if 'inputs' in positive_node and 'text' in positive_node['inputs']:
|
||||
gen_params['prompt'] = positive_node['inputs']['text']
|
||||
|
||||
# Find negative prompt node
|
||||
negative_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and
|
||||
v.get('class_type', '').endswith('CLIPTextEncode') and
|
||||
v.get('_meta', {}).get('title') == 'Negative'}
|
||||
|
||||
if negative_nodes:
|
||||
negative_node = next(iter(negative_nodes.values()))
|
||||
if 'inputs' in negative_node and 'text' in negative_node['inputs']:
|
||||
gen_params['negative_prompt'] = negative_node['inputs']['text']
|
||||
|
||||
# Find KSampler node for other parameters
|
||||
ksampler_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and v.get('class_type') == 'KSampler'}
|
||||
|
||||
if ksampler_nodes:
|
||||
ksampler_node = next(iter(ksampler_nodes.values()))
|
||||
if 'inputs' in ksampler_node:
|
||||
inputs = ksampler_node['inputs']
|
||||
if 'sampler_name' in inputs:
|
||||
gen_params['sampler'] = inputs['sampler_name']
|
||||
if 'steps' in inputs:
|
||||
gen_params['steps'] = inputs['steps']
|
||||
if 'cfg' in inputs:
|
||||
gen_params['cfg_scale'] = inputs['cfg']
|
||||
if 'seed' in inputs:
|
||||
gen_params['seed'] = inputs['seed']
|
||||
|
||||
# Determine base model from loras info
|
||||
base_model = None
|
||||
if loras:
|
||||
# Use the most common base model from loras
|
||||
base_models = [lora['baseModel'] for lora in loras if lora.get('baseModel')]
|
||||
if base_models:
|
||||
from collections import Counter
|
||||
base_model_counts = Counter(base_models)
|
||||
base_model = base_model_counts.most_common(1)[0][0]
|
||||
|
||||
return {
|
||||
'base_model': base_model,
|
||||
'loras': loras,
|
||||
'checkpoint': checkpoint,
|
||||
'gen_params': gen_params,
|
||||
'from_comfy_metadata': True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing ComfyUI metadata: {e}", exc_info=True)
|
||||
return {"error": str(e), "loras": []}
|
||||
219
py/recipes/parsers/meta_format.py
Normal file
219
py/recipes/parsers/meta_format.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""Parser for meta format (Lora_N Model hash) metadata."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from ..base import RecipeMetadataParser
|
||||
from ..constants import GEN_PARAM_KEYS
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MetaFormatParser(RecipeMetadataParser):
|
||||
"""Parser for images with meta format metadata (Lora_N Model hash format)"""
|
||||
|
||||
METADATA_MARKER = r'Lora_\d+ Model hash:'
|
||||
|
||||
def is_metadata_matching(self, user_comment: str) -> bool:
|
||||
"""Check if the user comment matches the metadata format"""
|
||||
return re.search(self.METADATA_MARKER, user_comment, re.IGNORECASE | re.DOTALL) is not None
|
||||
|
||||
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""Parse metadata from images with meta format metadata (Lora_N Model hash format)"""
|
||||
try:
|
||||
# Get metadata provider instead of using civitai_client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
# Extract prompt and negative prompt
|
||||
parts = user_comment.split('Negative prompt:', 1)
|
||||
prompt = parts[0].strip()
|
||||
|
||||
# Initialize metadata
|
||||
metadata = {"prompt": prompt, "loras": []}
|
||||
|
||||
# Extract negative prompt and parameters if available
|
||||
if len(parts) > 1:
|
||||
negative_and_params = parts[1]
|
||||
|
||||
# Extract negative prompt - everything until the first parameter (usually "Steps:")
|
||||
param_start = re.search(r'([A-Za-z]+): ', negative_and_params)
|
||||
if param_start:
|
||||
neg_prompt = negative_and_params[:param_start.start()].strip()
|
||||
metadata["negative_prompt"] = neg_prompt
|
||||
params_section = negative_and_params[param_start.start():]
|
||||
else:
|
||||
params_section = negative_and_params
|
||||
|
||||
# Extract key-value parameters (Steps, Sampler, Seed, etc.)
|
||||
param_pattern = r'([A-Za-z_0-9 ]+): ([^,]+)'
|
||||
params = re.findall(param_pattern, params_section)
|
||||
for key, value in params:
|
||||
clean_key = key.strip().lower().replace(' ', '_')
|
||||
metadata[clean_key] = value.strip()
|
||||
|
||||
# Extract LoRA information
|
||||
# Pattern to match lora entries: Lora_0 Model name: ArtVador I.safetensors, Lora_0 Model hash: 08f7133a58, etc.
|
||||
lora_pattern = r'Lora_(\d+) Model name: ([^,]+), Lora_\1 Model hash: ([^,]+), Lora_\1 Strength model: ([^,]+), Lora_\1 Strength clip: ([^,]+)'
|
||||
lora_matches = re.findall(lora_pattern, user_comment)
|
||||
|
||||
# If the regular pattern doesn't match, try a more flexible approach
|
||||
if not lora_matches:
|
||||
# First find all Lora indices
|
||||
lora_indices = set(re.findall(r'Lora_(\d+)', user_comment))
|
||||
|
||||
# For each index, extract the information
|
||||
for idx in lora_indices:
|
||||
lora_info = {}
|
||||
|
||||
# Extract model name
|
||||
name_match = re.search(f'Lora_{idx} Model name: ([^,]+)', user_comment)
|
||||
if name_match:
|
||||
lora_info['name'] = name_match.group(1).strip()
|
||||
|
||||
# Extract model hash
|
||||
hash_match = re.search(f'Lora_{idx} Model hash: ([^,]+)', user_comment)
|
||||
if hash_match:
|
||||
lora_info['hash'] = hash_match.group(1).strip()
|
||||
|
||||
# Extract strength model
|
||||
strength_model_match = re.search(f'Lora_{idx} Strength model: ([^,]+)', user_comment)
|
||||
if strength_model_match:
|
||||
lora_info['strength_model'] = float(strength_model_match.group(1).strip())
|
||||
|
||||
# Extract strength clip
|
||||
strength_clip_match = re.search(f'Lora_{idx} Strength clip: ([^,]+)', user_comment)
|
||||
if strength_clip_match:
|
||||
lora_info['strength_clip'] = float(strength_clip_match.group(1).strip())
|
||||
|
||||
# Only add if we have at least name and hash
|
||||
if 'name' in lora_info and 'hash' in lora_info:
|
||||
lora_matches.append((idx, lora_info['name'], lora_info['hash'],
|
||||
str(lora_info.get('strength_model', 1.0)),
|
||||
str(lora_info.get('strength_clip', 1.0))))
|
||||
|
||||
# Process LoRAs
|
||||
base_model_counts = {}
|
||||
loras = []
|
||||
|
||||
for match in lora_matches:
|
||||
if len(match) == 5: # Regular pattern match
|
||||
idx, name, hash_value, strength_model, strength_clip = match
|
||||
else: # Flexible approach match
|
||||
continue # Should not happen now
|
||||
|
||||
# Clean up the values
|
||||
name = name.strip()
|
||||
if name.endswith('.safetensors'):
|
||||
name = name[:-12] # Remove .safetensors extension
|
||||
|
||||
hash_value = hash_value.strip()
|
||||
weight = float(strength_model) # Use model strength as weight
|
||||
|
||||
# Initialize lora entry with default values
|
||||
lora_entry = {
|
||||
'name': name,
|
||||
'type': 'lora',
|
||||
'weight': weight,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': name,
|
||||
'hash': hash_value,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Get info from Civitai by hash if available
|
||||
if metadata_provider and hash_value:
|
||||
try:
|
||||
civitai_info = await metadata_provider.get_model_by_hash(hash_value)
|
||||
# Populate lora entry with Civitai info
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts,
|
||||
hash_value
|
||||
)
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA hash {hash_value}: {e}")
|
||||
|
||||
loras.append(lora_entry)
|
||||
|
||||
# Extract checkpoint information from generic Model/Model hash fields
|
||||
checkpoint = None
|
||||
model_hash = metadata.get("model_hash")
|
||||
model_name = metadata.get("model")
|
||||
|
||||
if model_hash or model_name:
|
||||
cleaned_name = None
|
||||
if model_name:
|
||||
cleaned_name = re.split(r"[\\\\/]", model_name)[-1]
|
||||
cleaned_name = os.path.splitext(cleaned_name)[0]
|
||||
|
||||
checkpoint_entry = {
|
||||
'id': 0,
|
||||
'modelId': 0,
|
||||
'name': model_name or "Unknown Checkpoint",
|
||||
'version': '',
|
||||
'type': 'checkpoint',
|
||||
'hash': model_hash or "",
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': cleaned_name or (model_name or ""),
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
if metadata_provider and model_hash:
|
||||
try:
|
||||
civitai_info = await metadata_provider.get_model_by_hash(model_hash)
|
||||
checkpoint_entry = await self.populate_checkpoint_from_civitai(
|
||||
checkpoint_entry,
|
||||
civitai_info
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for checkpoint hash {model_hash}: {e}")
|
||||
|
||||
if checkpoint_entry.get("baseModel"):
|
||||
base_model_value = checkpoint_entry["baseModel"]
|
||||
base_model_counts[base_model_value] = base_model_counts.get(base_model_value, 0) + 1
|
||||
|
||||
checkpoint = checkpoint_entry
|
||||
|
||||
# Set base_model to the most common one from civitai_info or checkpoint
|
||||
base_model = checkpoint["baseModel"] if checkpoint and checkpoint.get("baseModel") else None
|
||||
if not base_model and base_model_counts:
|
||||
base_model = max(base_model_counts.items(), key=lambda x: x[1])[0]
|
||||
|
||||
# Extract generation parameters for recipe metadata
|
||||
gen_params = {}
|
||||
for key in GEN_PARAM_KEYS:
|
||||
if key in metadata:
|
||||
gen_params[key] = metadata.get(key, '')
|
||||
|
||||
# Try to extract size information if available
|
||||
if 'width' in metadata and 'height' in metadata:
|
||||
gen_params['size'] = f"{metadata['width']}x{metadata['height']}"
|
||||
|
||||
return {
|
||||
'base_model': base_model,
|
||||
'loras': loras,
|
||||
'gen_params': gen_params,
|
||||
'raw_metadata': metadata,
|
||||
'from_meta_format': True,
|
||||
**({'checkpoint': checkpoint, 'model': checkpoint} if checkpoint else {})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing meta format metadata: {e}", exc_info=True)
|
||||
return {"error": str(e), "loras": []}
|
||||
202
py/recipes/parsers/recipe_format.py
Normal file
202
py/recipes/parsers/recipe_format.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""Parser for dedicated recipe metadata format."""
|
||||
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from ...config import config
|
||||
from ..base import RecipeMetadataParser
|
||||
from ..constants import GEN_PARAM_KEYS
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RecipeFormatParser(RecipeMetadataParser):
|
||||
"""Parser for images with dedicated recipe metadata format"""
|
||||
|
||||
# Regular expression pattern for extracting recipe metadata
|
||||
METADATA_MARKER = r'Recipe metadata: (\{.*\})'
|
||||
|
||||
async def _get_lora_from_version_index(self, recipe_scanner, model_version_id: Any) -> Optional[Dict[str, Any]]:
|
||||
"""Return a cached LoRA entry by modelVersionId if available."""
|
||||
|
||||
if not recipe_scanner or not getattr(recipe_scanner, "_lora_scanner", None):
|
||||
return None
|
||||
|
||||
try:
|
||||
normalized_id = int(model_version_id)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
try:
|
||||
cache = await recipe_scanner._lora_scanner.get_cached_data()
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.debug("Unable to load lora cache for version lookup: %s", exc)
|
||||
return None
|
||||
|
||||
if not cache or not getattr(cache, "version_index", None):
|
||||
return None
|
||||
|
||||
return cache.version_index.get(normalized_id)
|
||||
|
||||
def is_metadata_matching(self, user_comment: str) -> bool:
|
||||
"""Check if the user comment matches the metadata format"""
|
||||
return re.search(self.METADATA_MARKER, user_comment, re.IGNORECASE | re.DOTALL) is not None
|
||||
|
||||
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""Parse metadata from images with dedicated recipe metadata format"""
|
||||
try:
|
||||
# Get metadata provider instead of using civitai_client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
# Extract recipe metadata from user comment
|
||||
try:
|
||||
# Look for recipe metadata section
|
||||
recipe_match = re.search(self.METADATA_MARKER, user_comment, re.IGNORECASE | re.DOTALL)
|
||||
if not recipe_match:
|
||||
recipe_metadata = None
|
||||
else:
|
||||
recipe_json = recipe_match.group(1)
|
||||
recipe_metadata = json.loads(recipe_json)
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting recipe metadata: {e}")
|
||||
recipe_metadata = None
|
||||
if not recipe_metadata:
|
||||
return {"error": "No recipe metadata found", "loras": []}
|
||||
|
||||
# Process the recipe metadata
|
||||
loras = []
|
||||
for lora in recipe_metadata.get('loras', []):
|
||||
# Convert recipe lora format to frontend format
|
||||
lora_entry = {
|
||||
'id': int(lora.get('modelVersionId', 0)),
|
||||
'name': lora.get('modelName', ''),
|
||||
'version': lora.get('modelVersionName', ''),
|
||||
'type': 'lora',
|
||||
'weight': lora.get('strength', 1.0),
|
||||
'file_name': lora.get('file_name', ''),
|
||||
'hash': lora.get('hash', ''),
|
||||
'existsLocally': False,
|
||||
'inLibrary': False,
|
||||
'localPath': None,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'size': 0
|
||||
}
|
||||
|
||||
# Check if this LoRA exists locally by SHA256 hash
|
||||
if recipe_scanner:
|
||||
lora_scanner = recipe_scanner._lora_scanner
|
||||
|
||||
if lora.get('hash'):
|
||||
exists_locally = lora_scanner.has_hash(lora['hash'])
|
||||
if exists_locally:
|
||||
lora_cache = await lora_scanner.get_cached_data()
|
||||
lora_item = next((item for item in lora_cache.raw_data if item['sha256'].lower() == lora['hash'].lower()), None)
|
||||
if lora_item:
|
||||
lora_entry['existsLocally'] = True
|
||||
lora_entry['inLibrary'] = True
|
||||
lora_entry['localPath'] = lora_item['file_path']
|
||||
lora_entry['file_name'] = lora_item['file_name']
|
||||
lora_entry['size'] = lora_item['size']
|
||||
lora_entry['thumbnailUrl'] = config.get_preview_static_url(lora_item['preview_url'])
|
||||
|
||||
else:
|
||||
lora_entry['existsLocally'] = False
|
||||
lora_entry['inLibrary'] = False
|
||||
lora_entry['localPath'] = None
|
||||
|
||||
# If we still don't have a local match, try matching by modelVersionId
|
||||
if not lora_entry['existsLocally'] and lora.get('modelVersionId') is not None:
|
||||
cached_lora = await self._get_lora_from_version_index(recipe_scanner, lora.get('modelVersionId'))
|
||||
if cached_lora:
|
||||
lora_entry['existsLocally'] = True
|
||||
lora_entry['inLibrary'] = True
|
||||
lora_entry['localPath'] = cached_lora.get('file_path')
|
||||
lora_entry['file_name'] = cached_lora.get('file_name') or lora_entry['file_name']
|
||||
lora_entry['size'] = cached_lora.get('size', lora_entry['size'])
|
||||
if cached_lora.get('sha256'):
|
||||
lora_entry['hash'] = cached_lora['sha256']
|
||||
preview_url = cached_lora.get('preview_url')
|
||||
if preview_url:
|
||||
lora_entry['thumbnailUrl'] = config.get_preview_static_url(preview_url)
|
||||
|
||||
# Try to get additional info from Civitai if we have a model version ID and still missing locally
|
||||
if not lora_entry['existsLocally'] and lora.get('modelVersionId') and metadata_provider:
|
||||
try:
|
||||
civitai_info_tuple = await metadata_provider.get_model_version_info(lora['modelVersionId'])
|
||||
# Populate lora entry with Civitai info
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info_tuple,
|
||||
recipe_scanner,
|
||||
None, # No need to track base model counts
|
||||
lora_entry.get('hash', '')
|
||||
)
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA: {e}")
|
||||
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'
|
||||
|
||||
loras.append(lora_entry)
|
||||
|
||||
logger.info(f"Found {len(loras)} loras in recipe metadata")
|
||||
|
||||
# Process checkpoint information if present
|
||||
checkpoint = None
|
||||
checkpoint_data = recipe_metadata.get('checkpoint') or {}
|
||||
if isinstance(checkpoint_data, dict) and checkpoint_data:
|
||||
version_id = checkpoint_data.get('modelVersionId') or checkpoint_data.get('id')
|
||||
checkpoint_entry = {
|
||||
'id': version_id or 0,
|
||||
'modelId': checkpoint_data.get('modelId', 0),
|
||||
'name': checkpoint_data.get('name', 'Unknown Checkpoint'),
|
||||
'version': checkpoint_data.get('version', ''),
|
||||
'type': checkpoint_data.get('type', 'checkpoint'),
|
||||
'hash': checkpoint_data.get('hash', ''),
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': checkpoint_data.get('file_name', ''),
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
if metadata_provider:
|
||||
try:
|
||||
civitai_info = None
|
||||
if version_id:
|
||||
civitai_info = await metadata_provider.get_model_version_info(str(version_id))
|
||||
elif checkpoint_entry.get('hash'):
|
||||
civitai_info = await metadata_provider.get_model_by_hash(checkpoint_entry['hash'])
|
||||
|
||||
if civitai_info:
|
||||
checkpoint_entry = await self.populate_checkpoint_from_civitai(checkpoint_entry, civitai_info)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for checkpoint in recipe metadata: {e}")
|
||||
|
||||
checkpoint = checkpoint_entry
|
||||
|
||||
# Filter gen_params to only include recognized keys
|
||||
filtered_gen_params = {}
|
||||
if 'gen_params' in recipe_metadata:
|
||||
for key, value in recipe_metadata['gen_params'].items():
|
||||
if key in GEN_PARAM_KEYS:
|
||||
filtered_gen_params[key] = value
|
||||
|
||||
return {
|
||||
'base_model': checkpoint['baseModel'] if checkpoint and checkpoint.get('baseModel') else recipe_metadata.get('base_model', ''),
|
||||
'loras': loras,
|
||||
'gen_params': filtered_gen_params,
|
||||
'tags': recipe_metadata.get('tags', []),
|
||||
'title': recipe_metadata.get('title', ''),
|
||||
'from_recipe_metadata': True,
|
||||
**({'checkpoint': checkpoint, 'model': checkpoint} if checkpoint else {})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing recipe format metadata: {e}", exc_info=True)
|
||||
return {"error": str(e), "loras": []}
|
||||
@@ -1,997 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from typing import Dict
|
||||
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
|
||||
from ..config import config
|
||||
from ..services.websocket_manager import ws_manager
|
||||
from ..services.settings_manager import settings
|
||||
import asyncio
|
||||
from .update_routes import UpdateRoutes
|
||||
from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ApiRoutes:
|
||||
"""API route handlers for LoRA management"""
|
||||
|
||||
def __init__(self):
|
||||
self.scanner = None # Will be initialized in setup_routes
|
||||
self.civitai_client = None # Will be initialized in setup_routes
|
||||
self.download_manager = None # Will be initialized in setup_routes
|
||||
self._download_lock = asyncio.Lock()
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
@classmethod
|
||||
def setup_routes(cls, app: web.Application):
|
||||
"""Register API routes"""
|
||||
routes = cls()
|
||||
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: routes.initialize_services())
|
||||
|
||||
app.router.add_post('/api/delete_model', routes.delete_model)
|
||||
app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
|
||||
app.router.add_post('/api/replace_preview', routes.replace_preview)
|
||||
app.router.add_get('/api/loras', routes.get_loras)
|
||||
app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai)
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Add new WebSocket route
|
||||
app.router.add_get('/api/lora-roots', routes.get_lora_roots)
|
||||
app.router.add_get('/api/folders', routes.get_folders)
|
||||
app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions)
|
||||
app.router.add_get('/api/civitai/model/{modelVersionId}', routes.get_civitai_model)
|
||||
app.router.add_get('/api/civitai/model/{hash}', routes.get_civitai_model)
|
||||
app.router.add_post('/api/download-lora', routes.download_lora)
|
||||
app.router.add_post('/api/settings', routes.update_settings)
|
||||
app.router.add_post('/api/move_model', routes.move_model)
|
||||
app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route
|
||||
app.router.add_post('/api/loras/save-metadata', routes.save_metadata)
|
||||
app.router.add_get('/api/lora-preview-url', routes.get_lora_preview_url) # Add new route
|
||||
app.router.add_post('/api/move_models_bulk', routes.move_models_bulk)
|
||||
app.router.add_get('/api/loras/top-tags', routes.get_top_tags) # Add new route for top tags
|
||||
app.router.add_get('/api/loras/base-models', routes.get_base_models) # Add new route for base models
|
||||
app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url) # Add new route for Civitai URL
|
||||
app.router.add_post('/api/rename_lora', routes.rename_lora) # Add new route for renaming LoRA files
|
||||
app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files
|
||||
|
||||
# Add update check routes
|
||||
UpdateRoutes.setup_routes(app)
|
||||
|
||||
async def delete_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model deletion request"""
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
|
||||
|
||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata fetch request"""
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
|
||||
|
||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||
"""Handle preview image replacement request"""
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
|
||||
|
||||
async def scan_loras(self, request: web.Request) -> web.Response:
|
||||
"""Force a rescan of LoRA files"""
|
||||
try:
|
||||
await self.scanner.get_cached_data(force_refresh=True)
|
||||
return web.json_response({"status": "success", "message": "LoRA scan completed"})
|
||||
except Exception as e:
|
||||
logger.error(f"Error in scan_loras: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_loras(self, request: web.Request) -> web.Response:
|
||||
"""Handle paginated LoRA data request"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Parse query parameters
|
||||
page = int(request.query.get('page', '1'))
|
||||
page_size = int(request.query.get('page_size', '20'))
|
||||
sort_by = request.query.get('sort_by', 'name')
|
||||
folder = request.query.get('folder', None)
|
||||
search = request.query.get('search', None)
|
||||
fuzzy_search = request.query.get('fuzzy', 'false').lower() == 'true'
|
||||
|
||||
# Parse search options
|
||||
search_options = {
|
||||
'filename': request.query.get('search_filename', 'true').lower() == 'true',
|
||||
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
|
||||
'tags': request.query.get('search_tags', 'false').lower() == 'true',
|
||||
'recursive': request.query.get('recursive', 'false').lower() == 'true'
|
||||
}
|
||||
|
||||
# Get filter parameters
|
||||
base_models = request.query.get('base_models', None)
|
||||
tags = request.query.get('tags', None)
|
||||
|
||||
# New parameters for recipe filtering
|
||||
lora_hash = request.query.get('lora_hash', None)
|
||||
lora_hashes = request.query.get('lora_hashes', None)
|
||||
|
||||
# Parse filter parameters
|
||||
filters = {}
|
||||
if base_models:
|
||||
filters['base_model'] = base_models.split(',')
|
||||
if tags:
|
||||
filters['tags'] = tags.split(',')
|
||||
|
||||
# Add lora hash filtering options
|
||||
hash_filters = {}
|
||||
if lora_hash:
|
||||
hash_filters['single_hash'] = lora_hash.lower()
|
||||
elif lora_hashes:
|
||||
hash_filters['multiple_hashes'] = [h.lower() for h in lora_hashes.split(',')]
|
||||
|
||||
# Get file data
|
||||
data = await self.scanner.get_paginated_data(
|
||||
page,
|
||||
page_size,
|
||||
sort_by=sort_by,
|
||||
folder=folder,
|
||||
search=search,
|
||||
fuzzy_search=fuzzy_search,
|
||||
base_models=filters.get('base_model', None),
|
||||
tags=filters.get('tags', None),
|
||||
search_options=search_options,
|
||||
hash_filters=hash_filters
|
||||
)
|
||||
|
||||
# Get all available folders from cache
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Convert output to match expected format
|
||||
result = {
|
||||
'items': [self._format_lora_response(lora) for lora in data['items']],
|
||||
'folders': cache.folders,
|
||||
'total': data['total'],
|
||||
'page': data['page'],
|
||||
'page_size': data['page_size'],
|
||||
'total_pages': data['total_pages']
|
||||
}
|
||||
|
||||
return web.json_response(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving loras: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
def _format_lora_response(self, lora: Dict) -> Dict:
|
||||
"""Format LoRA data for API response"""
|
||||
return {
|
||||
"model_name": lora["model_name"],
|
||||
"file_name": lora["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora["preview_url"]),
|
||||
"preview_nsfw_level": lora.get("preview_nsfw_level", 0),
|
||||
"base_model": lora["base_model"],
|
||||
"folder": lora["folder"],
|
||||
"sha256": lora["sha256"],
|
||||
"file_path": lora["file_path"].replace(os.sep, "/"),
|
||||
"file_size": lora["size"],
|
||||
"modified": lora["modified"],
|
||||
"tags": lora["tags"],
|
||||
"modelDescription": lora["modelDescription"],
|
||||
"from_civitai": lora.get("from_civitai", True),
|
||||
"usage_tips": lora.get("usage_tips", ""),
|
||||
"notes": lora.get("notes", ""),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(lora.get("civitai", {}))
|
||||
}
|
||||
|
||||
# Private helper methods
|
||||
async def _read_preview_file(self, reader) -> tuple[bytes, str]:
|
||||
"""Read preview file and content type from multipart request"""
|
||||
field = await reader.next()
|
||||
if field.name != 'preview_file':
|
||||
raise ValueError("Expected 'preview_file' field")
|
||||
content_type = field.headers.get('Content-Type', 'image/png')
|
||||
return await field.read(), content_type
|
||||
|
||||
async def _read_model_path(self, reader) -> str:
|
||||
"""Read model path from multipart request"""
|
||||
field = await reader.next()
|
||||
if field.name != 'model_path':
|
||||
raise ValueError("Expected 'model_path' field")
|
||||
return (await field.read()).decode()
|
||||
|
||||
async def _save_preview_file(self, model_path: str, preview_data: bytes, content_type: str) -> str:
|
||||
"""Save preview file and return its path"""
|
||||
base_name = os.path.splitext(os.path.basename(model_path))[0]
|
||||
folder = os.path.dirname(model_path)
|
||||
|
||||
# Determine if content is video or image
|
||||
if content_type.startswith('video/'):
|
||||
# For videos, keep original format and use .mp4 extension
|
||||
extension = '.mp4'
|
||||
optimized_data = preview_data
|
||||
else:
|
||||
# For images, optimize and convert to WebP
|
||||
optimized_data, _ = ExifUtils.optimize_image(
|
||||
image_data=preview_data,
|
||||
target_width=CARD_PREVIEW_WIDTH,
|
||||
format='webp',
|
||||
quality=85,
|
||||
preserve_metadata=True
|
||||
)
|
||||
extension = '.webp' # Use .webp without .preview part
|
||||
|
||||
preview_path = os.path.join(folder, base_name + extension).replace(os.sep, '/')
|
||||
|
||||
with open(preview_path, 'wb') as f:
|
||||
f.write(optimized_data)
|
||||
|
||||
return preview_path
|
||||
|
||||
async def _update_preview_metadata(self, model_path: str, preview_path: str):
|
||||
"""Update preview path in metadata"""
|
||||
metadata_path = os.path.splitext(model_path)[0] + '.metadata.json'
|
||||
if os.path.exists(metadata_path):
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
# Update preview_url directly in the metadata dict
|
||||
metadata['preview_url'] = preview_path
|
||||
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating metadata: {e}")
|
||||
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Fetch CivitAI metadata for all loras in the background"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
cache = await self.scanner.get_cached_data()
|
||||
total = len(cache.raw_data)
|
||||
processed = 0
|
||||
success = 0
|
||||
needs_resort = False
|
||||
|
||||
# Prepare loras to process
|
||||
to_process = [
|
||||
lora for lora in cache.raw_data
|
||||
if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai', True) # TODO: for lora not from CivitAI but added traineWords
|
||||
]
|
||||
total_to_process = len(to_process)
|
||||
|
||||
# Send initial progress
|
||||
await ws_manager.broadcast({
|
||||
'status': 'started',
|
||||
'total': total_to_process,
|
||||
'processed': 0,
|
||||
'success': 0
|
||||
})
|
||||
|
||||
for lora in to_process:
|
||||
try:
|
||||
original_name = lora.get('model_name')
|
||||
if await ModelRouteUtils.fetch_and_update_model(
|
||||
sha256=lora['sha256'],
|
||||
file_path=lora['file_path'],
|
||||
model_data=lora,
|
||||
update_cache_func=self.scanner.update_single_model_cache
|
||||
):
|
||||
success += 1
|
||||
if original_name != lora.get('model_name'):
|
||||
needs_resort = True
|
||||
|
||||
processed += 1
|
||||
|
||||
# Send progress update
|
||||
await ws_manager.broadcast({
|
||||
'status': 'processing',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success,
|
||||
'current_name': lora.get('model_name', 'Unknown')
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivitAI data for {lora['file_path']}: {e}")
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort(name_only=True)
|
||||
|
||||
# Send completion message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'completed',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Successfully updated {success} of {processed} processed loras (total: {total})"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
# Send error message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'error',
|
||||
'error': str(e)
|
||||
})
|
||||
logger.error(f"Error in fetch_all_civitai: {e}")
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_lora_roots(self, request: web.Request) -> web.Response:
|
||||
"""Get all configured LoRA root directories"""
|
||||
return web.json_response({
|
||||
'roots': config.loras_roots
|
||||
})
|
||||
|
||||
async def get_folders(self, request: web.Request) -> web.Response:
|
||||
"""Get all folders in the cache"""
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
cache = await self.scanner.get_cached_data()
|
||||
return web.json_response({
|
||||
'folders': cache.folders
|
||||
})
|
||||
|
||||
async def get_civitai_versions(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai model with local availability info"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
if self.civitai_client is None:
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
model_id = request.match_info['model_id']
|
||||
response = await self.civitai_client.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - should be LORA
|
||||
if model_type.lower() != 'lora':
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected LORA, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the model file (type="Model") in the files list
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model'), None)
|
||||
|
||||
if model_file:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.scanner.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.scanner.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
else:
|
||||
# No model file found in this version
|
||||
version['existsLocally'] = False
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
async def get_civitai_model(self, request: web.Request) -> web.Response:
|
||||
"""Get CivitAI model details by model version ID or hash"""
|
||||
try:
|
||||
if self.civitai_client is None:
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
model_version_id = request.match_info.get('modelVersionId')
|
||||
if not model_version_id:
|
||||
hash = request.match_info.get('hash')
|
||||
model = await self.civitai_client.get_model_by_hash(hash)
|
||||
return web.json_response(model)
|
||||
|
||||
# Get model details from Civitai API
|
||||
model = await self.civitai_client.get_model_version_info(model_version_id)
|
||||
return web.json_response(model)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model details: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
|
||||
async def download_lora(self, request: web.Request) -> web.Response:
|
||||
async with self._download_lock:
|
||||
try:
|
||||
if self.download_manager is None:
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
data = await request.json()
|
||||
|
||||
# Create progress callback
|
||||
async def progress_callback(progress):
|
||||
await ws_manager.broadcast({
|
||||
'status': 'progress',
|
||||
'progress': progress
|
||||
})
|
||||
|
||||
# Check which identifier is provided
|
||||
download_url = data.get('download_url')
|
||||
model_hash = data.get('model_hash')
|
||||
model_version_id = data.get('model_version_id')
|
||||
|
||||
# Validate that at least one identifier is provided
|
||||
if not any([download_url, model_hash, model_version_id]):
|
||||
return web.Response(
|
||||
status=400,
|
||||
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
|
||||
)
|
||||
|
||||
result = await self.download_manager.download_from_civitai(
|
||||
download_url=download_url,
|
||||
model_hash=model_hash,
|
||||
model_version_id=model_version_id,
|
||||
save_dir=data.get('lora_root'),
|
||||
relative_path=data.get('relative_path'),
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
|
||||
if not result.get('success', False):
|
||||
error_message = result.get('error', 'Unknown error')
|
||||
|
||||
# Return 401 for early access errors
|
||||
if 'early access' in error_message.lower():
|
||||
logger.warning(f"Early access download failed: {error_message}")
|
||||
return web.Response(
|
||||
status=401, # Use 401 status code to match Civitai's response
|
||||
text=f"Early Access Restriction: {error_message}"
|
||||
)
|
||||
|
||||
return web.Response(status=500, text=error_message)
|
||||
|
||||
return web.json_response(result)
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
|
||||
# Check if this might be an early access error
|
||||
if '401' in error_message:
|
||||
logger.warning(f"Early access error (401): {error_message}")
|
||||
return web.Response(
|
||||
status=401,
|
||||
text="Early Access Restriction: This LoRA requires purchase. Please buy early access on Civitai.com."
|
||||
)
|
||||
|
||||
logger.error(f"Error downloading LoRA: {error_message}")
|
||||
return web.Response(status=500, text=error_message)
|
||||
|
||||
async def update_settings(self, request: web.Request) -> web.Response:
|
||||
"""Update application settings"""
|
||||
try:
|
||||
data = await request.json()
|
||||
|
||||
# Validate and update settings
|
||||
if 'civitai_api_key' in data:
|
||||
settings.set('civitai_api_key', data['civitai_api_key'])
|
||||
if 'show_only_sfw' in data:
|
||||
settings.set('show_only_sfw', data['show_only_sfw'])
|
||||
|
||||
return web.json_response({'success': True})
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating settings: {e}", exc_info=True)
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
async def move_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model move request"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path') # full path of the model file, e.g. /path/to/model.safetensors
|
||||
target_path = data.get('target_path') # folder path to move the model to, e.g. /path/to/target_folder
|
||||
|
||||
if not file_path or not target_path:
|
||||
return web.Response(text='File path and target path are required', status=400)
|
||||
|
||||
# Check if source and destination are the same
|
||||
source_dir = os.path.dirname(file_path)
|
||||
if os.path.normpath(source_dir) == os.path.normpath(target_path):
|
||||
logger.info(f"Source and target directories are the same: {source_dir}")
|
||||
return web.json_response({'success': True, 'message': 'Source and target directories are the same'})
|
||||
|
||||
# Check if target file already exists
|
||||
file_name = os.path.basename(file_path)
|
||||
target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')
|
||||
|
||||
if os.path.exists(target_file_path):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f"Target file already exists: {target_file_path}"
|
||||
}, status=409) # 409 Conflict
|
||||
|
||||
# Call scanner to handle the move operation
|
||||
success = await self.scanner.move_model(file_path, target_path)
|
||||
|
||||
if success:
|
||||
return web.json_response({'success': True})
|
||||
else:
|
||||
return web.Response(text='Failed to move model', status=500)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving model: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
@classmethod
|
||||
async def cleanup(cls):
|
||||
"""Add cleanup method for application shutdown"""
|
||||
# Now we don't need to store an instance, as services are managed by ServiceRegistry
|
||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
if civitai_client:
|
||||
await civitai_client.close()
|
||||
|
||||
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||
"""Handle saving metadata updates"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
if not file_path:
|
||||
return web.Response(text='File path is required', status=400)
|
||||
|
||||
# Remove file path from data to avoid saving it
|
||||
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
|
||||
|
||||
# Get metadata file path
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
|
||||
# Load existing metadata
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
|
||||
# Handle nested updates (for civitai.trainedWords)
|
||||
for key, value in metadata_updates.items():
|
||||
if isinstance(value, dict) and key in metadata and isinstance(metadata[key], dict):
|
||||
# Deep update for nested dictionaries
|
||||
for nested_key, nested_value in value.items():
|
||||
metadata[key][nested_key] = nested_value
|
||||
else:
|
||||
# Regular update for top-level keys
|
||||
metadata[key] = value
|
||||
|
||||
# Save updated metadata
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Update cache
|
||||
await self.scanner.update_single_model_cache(file_path, file_path, metadata)
|
||||
|
||||
# If model_name was updated, resort the cache
|
||||
if 'model_name' in metadata_updates:
|
||||
cache = await self.scanner.get_cached_data()
|
||||
await cache.resort(name_only=True)
|
||||
|
||||
return web.json_response({'success': True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving metadata: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the static preview URL for a LoRA file"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Get lora file name from query parameters
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
# Get cache data
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Search for the lora in cache data
|
||||
for lora in cache.raw_data:
|
||||
file_name = lora['file_name']
|
||||
if file_name == lora_name:
|
||||
if preview_url := lora.get('preview_url'):
|
||||
# Convert preview path to static URL
|
||||
static_url = config.get_preview_static_url(preview_url)
|
||||
if static_url:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'preview_url': static_url
|
||||
})
|
||||
break
|
||||
|
||||
# If no preview URL found
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No preview URL found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the Civitai URL for a LoRA file"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Get lora file name from query parameters
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
# Get cache data
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Search for the lora in cache data
|
||||
for lora in cache.raw_data:
|
||||
file_name = lora['file_name']
|
||||
if file_name == lora_name:
|
||||
civitai_data = lora.get('civitai', {})
|
||||
model_id = civitai_data.get('modelId')
|
||||
version_id = civitai_data.get('id')
|
||||
|
||||
if model_id:
|
||||
civitai_url = f"https://civitai.com/models/{model_id}"
|
||||
if version_id:
|
||||
civitai_url += f"?modelVersionId={version_id}"
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'civitai_url': civitai_url,
|
||||
'model_id': model_id,
|
||||
'version_id': version_id
|
||||
})
|
||||
break
|
||||
|
||||
# If no Civitai data found
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No Civitai data found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def move_models_bulk(self, request: web.Request) -> web.Response:
|
||||
"""Handle bulk model move request"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
data = await request.json()
|
||||
file_paths = data.get('file_paths', []) # list of full paths of the model files, e.g. ["/path/to/model1.safetensors", "/path/to/model2.safetensors"]
|
||||
target_path = data.get('target_path') # folder path to move the models to, e.g. "/path/to/target_folder"
|
||||
|
||||
if not file_paths or not target_path:
|
||||
return web.Response(text='File paths and target path are required', status=400)
|
||||
|
||||
results = []
|
||||
for file_path in file_paths:
|
||||
# Check if source and destination are the same
|
||||
source_dir = os.path.dirname(file_path)
|
||||
if os.path.normpath(source_dir) == os.path.normpath(target_path):
|
||||
results.append({
|
||||
"path": file_path,
|
||||
"success": True,
|
||||
"message": "Source and target directories are the same"
|
||||
})
|
||||
continue
|
||||
|
||||
# Check if target file already exists
|
||||
file_name = os.path.basename(file_path)
|
||||
target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')
|
||||
|
||||
if os.path.exists(target_file_path):
|
||||
results.append({
|
||||
"path": file_path,
|
||||
"success": False,
|
||||
"message": f"Target file already exists: {target_file_path}"
|
||||
})
|
||||
continue
|
||||
|
||||
# Try to move the model
|
||||
success = await self.scanner.move_model(file_path, target_path)
|
||||
results.append({
|
||||
"path": file_path,
|
||||
"success": success,
|
||||
"message": "Success" if success else "Failed to move model"
|
||||
})
|
||||
|
||||
# Count successes and failures
|
||||
success_count = sum(1 for r in results if r["success"])
|
||||
failure_count = len(results) - success_count
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f'Moved {success_count} of {len(file_paths)} models',
|
||||
'results': results,
|
||||
'success_count': success_count,
|
||||
'failure_count': failure_count
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving models in bulk: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_lora_model_description(self, request: web.Request) -> web.Response:
|
||||
"""Get model description for a Lora model"""
|
||||
try:
|
||||
if self.civitai_client is None:
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Get parameters
|
||||
model_id = request.query.get('model_id')
|
||||
file_path = request.query.get('file_path')
|
||||
|
||||
if not model_id:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Model ID is required'
|
||||
}, status=400)
|
||||
|
||||
# Check if we already have the description stored in metadata
|
||||
description = None
|
||||
tags = []
|
||||
if file_path:
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
description = metadata.get('modelDescription')
|
||||
tags = metadata.get('tags', [])
|
||||
|
||||
# If description is not in metadata, fetch from CivitAI
|
||||
if not description:
|
||||
logger.info(f"Fetching model metadata for model ID: {model_id}")
|
||||
model_metadata, _ = await self.civitai_client.get_model_metadata(model_id)
|
||||
|
||||
if model_metadata:
|
||||
description = model_metadata.get('description')
|
||||
tags = model_metadata.get('tags', [])
|
||||
|
||||
# Save the metadata to file if we have a file path and got metadata
|
||||
if file_path:
|
||||
try:
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
|
||||
metadata['modelDescription'] = description
|
||||
metadata['tags'] = tags
|
||||
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved model metadata to file for {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving model metadata: {e}")
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'description': description or "<p>No model description available.</p>",
|
||||
'tags': tags
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model metadata: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
"""Handle request for top tags sorted by frequency"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
# Validate limit
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20 # Default to a reasonable limit
|
||||
|
||||
# Get top tags
|
||||
top_tags = await self.scanner.get_top_tags(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'tags': top_tags
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Internal server error'
|
||||
}, status=500)
|
||||
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
"""Get base models used in loras"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
# Validate limit
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20 # Default to a reasonable limit
|
||||
|
||||
# Get base models
|
||||
base_models = await self.scanner.get_base_models(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'base_models': base_models
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving base models: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def rename_lora(self, request: web.Request) -> web.Response:
|
||||
"""Handle renaming a LoRA file and its associated files"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
if self.download_manager is None:
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
new_file_name = data.get('new_file_name')
|
||||
|
||||
if not file_path or not new_file_name:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'File path and new file name are required'
|
||||
}, status=400)
|
||||
|
||||
# Validate the new file name (no path separators or invalid characters)
|
||||
invalid_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|']
|
||||
if any(char in new_file_name for char in invalid_chars):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Invalid characters in file name'
|
||||
}, status=400)
|
||||
|
||||
# Get the directory and current file name
|
||||
target_dir = os.path.dirname(file_path)
|
||||
old_file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
# Check if the target file already exists
|
||||
new_file_path = os.path.join(target_dir, f"{new_file_name}.safetensors").replace(os.sep, '/')
|
||||
if os.path.exists(new_file_path):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'A file with this name already exists'
|
||||
}, status=400)
|
||||
|
||||
# Define the patterns for associated files
|
||||
patterns = [
|
||||
f"{old_file_name}.safetensors", # Required
|
||||
f"{old_file_name}.metadata.json",
|
||||
]
|
||||
|
||||
# Add all preview file extensions
|
||||
for ext in PREVIEW_EXTENSIONS:
|
||||
patterns.append(f"{old_file_name}{ext}")
|
||||
|
||||
# Find all matching files
|
||||
existing_files = []
|
||||
for pattern in patterns:
|
||||
path = os.path.join(target_dir, pattern)
|
||||
if os.path.exists(path):
|
||||
existing_files.append((path, pattern))
|
||||
|
||||
# Get the hash from the main file to update hash index
|
||||
hash_value = None
|
||||
metadata = None
|
||||
metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json")
|
||||
|
||||
if os.path.exists(metadata_path):
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
hash_value = metadata.get('sha256')
|
||||
|
||||
# Rename all files
|
||||
renamed_files = []
|
||||
new_metadata_path = None
|
||||
|
||||
# Notify file monitor to ignore these events
|
||||
main_file_path = os.path.join(target_dir, f"{old_file_name}.safetensors")
|
||||
if os.path.exists(main_file_path):
|
||||
# Get lora monitor through ServiceRegistry instead of download_manager
|
||||
lora_monitor = await ServiceRegistry.get_lora_monitor()
|
||||
if lora_monitor:
|
||||
# Add old and new paths to ignore list
|
||||
file_size = os.path.getsize(main_file_path)
|
||||
lora_monitor.handler.add_ignore_path(main_file_path, file_size)
|
||||
lora_monitor.handler.add_ignore_path(new_file_path, file_size)
|
||||
|
||||
for old_path, pattern in existing_files:
|
||||
# Get the file extension like .safetensors or .metadata.json
|
||||
ext = ModelRouteUtils.get_multipart_ext(pattern)
|
||||
|
||||
# Create the new path
|
||||
new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/')
|
||||
|
||||
# Rename the file
|
||||
os.rename(old_path, new_path)
|
||||
renamed_files.append(new_path)
|
||||
|
||||
# Keep track of metadata path for later update
|
||||
if ext == '.metadata.json':
|
||||
new_metadata_path = new_path
|
||||
|
||||
# Update the metadata file with new file name and paths
|
||||
if new_metadata_path and metadata:
|
||||
# Update file_name, file_path and preview_url in metadata
|
||||
metadata['file_name'] = new_file_name
|
||||
metadata['file_path'] = new_file_path
|
||||
|
||||
# Update preview_url if it exists
|
||||
if 'preview_url' in metadata and metadata['preview_url']:
|
||||
old_preview = metadata['preview_url']
|
||||
ext = ModelRouteUtils.get_multipart_ext(old_preview)
|
||||
new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/')
|
||||
metadata['preview_url'] = new_preview
|
||||
|
||||
# Save updated metadata
|
||||
with open(new_metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Update the scanner cache
|
||||
if metadata:
|
||||
await self.scanner.update_single_model_cache(file_path, new_file_path, metadata)
|
||||
|
||||
# Update recipe files and cache if hash is available
|
||||
if hash_value:
|
||||
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
recipes_updated, cache_updated = await recipe_scanner.update_lora_filename_by_hash(hash_value, new_file_name)
|
||||
logger.info(f"Updated {recipes_updated} recipe files and {cache_updated} cache entries for renamed LoRA")
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'new_file_path': new_file_path,
|
||||
'renamed_files': renamed_files,
|
||||
'reload_required': False
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error renaming LoRA: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
300
py/routes/base_model_routes.py
Normal file
300
py/routes/base_model_routes.py
Normal file
@@ -0,0 +1,300 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Mapping
|
||||
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
|
||||
from ..config import config
|
||||
from ..services.download_coordinator import DownloadCoordinator
|
||||
from ..services.downloader import get_downloader
|
||||
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
||||
from ..services.metadata_sync_service import MetadataSyncService
|
||||
from ..services.model_file_service import ModelFileService, ModelMoveService
|
||||
from ..services.model_lifecycle_service import ModelLifecycleService
|
||||
from ..services.preview_asset_service import PreviewAssetService
|
||||
from ..services.server_i18n import server_i18n as default_server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from ..services.tag_update_service import TagUpdateService
|
||||
from ..services.websocket_manager import ws_manager as default_ws_manager
|
||||
from ..services.use_cases import (
|
||||
AutoOrganizeUseCase,
|
||||
BulkMetadataRefreshUseCase,
|
||||
DownloadModelUseCase,
|
||||
)
|
||||
from ..services.websocket_progress_callback import (
|
||||
WebSocketBroadcastCallback,
|
||||
WebSocketProgressCallback,
|
||||
)
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar
|
||||
from .handlers.model_handlers import (
|
||||
ModelAutoOrganizeHandler,
|
||||
ModelCivitaiHandler,
|
||||
ModelDownloadHandler,
|
||||
ModelHandlerSet,
|
||||
ModelListingHandler,
|
||||
ModelManagementHandler,
|
||||
ModelMoveHandler,
|
||||
ModelPageView,
|
||||
ModelQueryHandler,
|
||||
ModelUpdateHandler,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..services.model_update_service import ModelUpdateService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseModelRoutes(ABC):
|
||||
"""Base route controller for all model types."""
|
||||
|
||||
template_name: str | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
service=None,
|
||||
*,
|
||||
settings_service=None,
|
||||
ws_manager=default_ws_manager,
|
||||
server_i18n=default_server_i18n,
|
||||
metadata_provider_factory=get_default_metadata_provider,
|
||||
) -> None:
|
||||
self.service = None
|
||||
self.model_type = ""
|
||||
self._settings = settings_service or get_settings_manager()
|
||||
self._ws_manager = ws_manager
|
||||
self._server_i18n = server_i18n
|
||||
self._metadata_provider_factory = metadata_provider_factory
|
||||
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True,
|
||||
)
|
||||
|
||||
self.model_file_service: ModelFileService | None = None
|
||||
self.model_move_service: ModelMoveService | None = None
|
||||
self.model_lifecycle_service: ModelLifecycleService | None = None
|
||||
self.websocket_progress_callback = WebSocketProgressCallback()
|
||||
self.metadata_progress_callback = WebSocketBroadcastCallback()
|
||||
|
||||
self._handler_set: ModelHandlerSet | None = None
|
||||
self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None
|
||||
|
||||
self._preview_service = PreviewAssetService(
|
||||
metadata_manager=MetadataManager,
|
||||
downloader_factory=get_downloader,
|
||||
exif_utils=ExifUtils,
|
||||
)
|
||||
self._metadata_sync_service = MetadataSyncService(
|
||||
metadata_manager=MetadataManager,
|
||||
preview_service=self._preview_service,
|
||||
settings=self._settings,
|
||||
default_metadata_provider_factory=metadata_provider_factory,
|
||||
metadata_provider_selector=get_metadata_provider,
|
||||
)
|
||||
self._tag_update_service = TagUpdateService(metadata_manager=MetadataManager)
|
||||
self._download_coordinator = DownloadCoordinator(
|
||||
ws_manager=self._ws_manager,
|
||||
download_manager_factory=ServiceRegistry.get_download_manager,
|
||||
)
|
||||
self._model_update_service: ModelUpdateService | None = None
|
||||
|
||||
if service is not None:
|
||||
self.attach_service(service)
|
||||
|
||||
def set_model_update_service(self, service: "ModelUpdateService") -> None:
|
||||
"""Attach the model update tracking service."""
|
||||
|
||||
self._model_update_service = service
|
||||
self._handler_set = None
|
||||
self._handler_mapping = None
|
||||
|
||||
def attach_service(self, service) -> None:
|
||||
"""Attach a model service and rebuild handler dependencies."""
|
||||
self.service = service
|
||||
self.model_type = service.model_type
|
||||
self.model_file_service = ModelFileService(service.scanner, service.model_type)
|
||||
self.model_move_service = ModelMoveService(service.scanner)
|
||||
self.model_lifecycle_service = ModelLifecycleService(
|
||||
scanner=service.scanner,
|
||||
metadata_manager=MetadataManager,
|
||||
metadata_loader=self._metadata_sync_service.load_local_metadata,
|
||||
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
|
||||
update_service=self._model_update_service,
|
||||
)
|
||||
self._handler_set = None
|
||||
self._handler_mapping = None
|
||||
|
||||
def _ensure_handler_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
|
||||
if self._handler_mapping is None:
|
||||
handler_set = self._create_handler_set()
|
||||
self._handler_set = handler_set
|
||||
self._handler_mapping = handler_set.to_route_mapping()
|
||||
return self._handler_mapping
|
||||
|
||||
def _create_handler_set(self) -> ModelHandlerSet:
|
||||
service = self._ensure_service()
|
||||
update_service = self._ensure_model_update_service()
|
||||
page_view = ModelPageView(
|
||||
template_env=self.template_env,
|
||||
template_name=self.template_name or "",
|
||||
service=service,
|
||||
settings_service=self._settings,
|
||||
server_i18n=self._server_i18n,
|
||||
logger=logger,
|
||||
)
|
||||
listing = ModelListingHandler(
|
||||
service=service,
|
||||
parse_specific_params=self._parse_specific_params,
|
||||
logger=logger,
|
||||
)
|
||||
management = ModelManagementHandler(
|
||||
service=service,
|
||||
logger=logger,
|
||||
metadata_sync=self._metadata_sync_service,
|
||||
preview_service=self._preview_service,
|
||||
tag_update_service=self._tag_update_service,
|
||||
lifecycle_service=self._ensure_lifecycle_service(),
|
||||
)
|
||||
query = ModelQueryHandler(service=service, logger=logger)
|
||||
download_use_case = DownloadModelUseCase(download_coordinator=self._download_coordinator)
|
||||
download = ModelDownloadHandler(
|
||||
ws_manager=self._ws_manager,
|
||||
logger=logger,
|
||||
download_use_case=download_use_case,
|
||||
download_coordinator=self._download_coordinator,
|
||||
)
|
||||
metadata_refresh_use_case = BulkMetadataRefreshUseCase(
|
||||
service=service,
|
||||
metadata_sync=self._metadata_sync_service,
|
||||
settings_service=self._settings,
|
||||
logger=logger,
|
||||
)
|
||||
civitai = ModelCivitaiHandler(
|
||||
service=service,
|
||||
settings_service=self._settings,
|
||||
ws_manager=self._ws_manager,
|
||||
logger=logger,
|
||||
metadata_provider_factory=self._metadata_provider_factory,
|
||||
validate_model_type=self._validate_civitai_model_type,
|
||||
expected_model_types=self._get_expected_model_types,
|
||||
find_model_file=self._find_model_file,
|
||||
metadata_sync=self._metadata_sync_service,
|
||||
metadata_refresh_use_case=metadata_refresh_use_case,
|
||||
metadata_progress_callback=self.metadata_progress_callback,
|
||||
)
|
||||
move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger)
|
||||
auto_organize_use_case = AutoOrganizeUseCase(
|
||||
file_service=self._ensure_file_service(),
|
||||
lock_provider=self._ws_manager,
|
||||
)
|
||||
auto_organize = ModelAutoOrganizeHandler(
|
||||
use_case=auto_organize_use_case,
|
||||
progress_callback=self.websocket_progress_callback,
|
||||
ws_manager=self._ws_manager,
|
||||
logger=logger,
|
||||
)
|
||||
updates = ModelUpdateHandler(
|
||||
service=service,
|
||||
update_service=update_service,
|
||||
metadata_provider_selector=get_metadata_provider,
|
||||
logger=logger,
|
||||
)
|
||||
return ModelHandlerSet(
|
||||
page_view=page_view,
|
||||
listing=listing,
|
||||
management=management,
|
||||
query=query,
|
||||
download=download,
|
||||
civitai=civitai,
|
||||
move=move,
|
||||
auto_organize=auto_organize,
|
||||
updates=updates,
|
||||
)
|
||||
|
||||
@property
|
||||
def route_handlers(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
|
||||
return self._ensure_handler_mapping()
|
||||
|
||||
def setup_routes(self, app: web.Application, prefix: str) -> None:
|
||||
registrar = ModelRouteRegistrar(app)
|
||||
handler_lookup = {
|
||||
definition.handler_name: self._make_handler_proxy(definition.handler_name)
|
||||
for definition in COMMON_ROUTE_DEFINITIONS
|
||||
}
|
||||
registrar.register_common_routes(prefix, handler_lookup)
|
||||
self.setup_specific_routes(registrar, prefix)
|
||||
|
||||
@abstractmethod
|
||||
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str) -> None:
|
||||
"""Setup model-specific routes."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||
"""Parse model-specific parameters - to be overridden by subclasses."""
|
||||
return {}
|
||||
|
||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||
"""Validate CivitAI model type - to be overridden by subclasses."""
|
||||
return True
|
||||
|
||||
def _get_expected_model_types(self) -> str:
|
||||
"""Get expected model types string for error messages - to be overridden by subclasses."""
|
||||
return "any model type"
|
||||
|
||||
def _find_model_file(self, files):
|
||||
"""Find the appropriate model file from the files list - can be overridden by subclasses."""
|
||||
return next((file for file in files if file.get("type") == "Model" and file.get("primary") is True), None)
|
||||
|
||||
def get_handler(self, name: str) -> Callable[[web.Request], web.StreamResponse]:
|
||||
"""Expose handlers for subclasses or tests."""
|
||||
return self._ensure_handler_mapping()[name]
|
||||
|
||||
def _ensure_service(self):
|
||||
if self.service is None:
|
||||
raise RuntimeError("Model service has not been attached")
|
||||
return self.service
|
||||
|
||||
def _ensure_file_service(self) -> ModelFileService:
|
||||
if self.model_file_service is None:
|
||||
service = self._ensure_service()
|
||||
self.model_file_service = ModelFileService(service.scanner, service.model_type)
|
||||
return self.model_file_service
|
||||
|
||||
def _ensure_move_service(self) -> ModelMoveService:
|
||||
if self.model_move_service is None:
|
||||
service = self._ensure_service()
|
||||
self.model_move_service = ModelMoveService(service.scanner)
|
||||
return self.model_move_service
|
||||
|
||||
def _ensure_lifecycle_service(self) -> ModelLifecycleService:
|
||||
if self.model_lifecycle_service is None:
|
||||
service = self._ensure_service()
|
||||
self.model_lifecycle_service = ModelLifecycleService(
|
||||
scanner=service.scanner,
|
||||
metadata_manager=MetadataManager,
|
||||
metadata_loader=self._metadata_sync_service.load_local_metadata,
|
||||
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
|
||||
)
|
||||
return self.model_lifecycle_service
|
||||
|
||||
def _make_handler_proxy(self, name: str) -> Callable[[web.Request], web.StreamResponse]:
|
||||
async def proxy(request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
handler = self.get_handler(name)
|
||||
except RuntimeError:
|
||||
return web.json_response({"success": False, "error": "Service not ready"}, status=503)
|
||||
return await handler(request)
|
||||
|
||||
return proxy
|
||||
|
||||
def _ensure_model_update_service(self) -> "ModelUpdateService":
|
||||
if self._model_update_service is None:
|
||||
raise RuntimeError("Model update service has not been attached")
|
||||
return self._model_update_service
|
||||
218
py/routes/base_recipe_routes.py
Normal file
218
py/routes/base_recipe_routes.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""Base infrastructure shared across recipe routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, Mapping
|
||||
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
|
||||
from ..config import config
|
||||
from ..recipes import RecipeParserFactory
|
||||
from ..services.downloader import get_downloader
|
||||
from ..services.recipes import (
|
||||
RecipeAnalysisService,
|
||||
RecipePersistenceService,
|
||||
RecipeSharingService,
|
||||
)
|
||||
from ..services.server_i18n import server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from .handlers.recipe_handlers import (
|
||||
RecipeAnalysisHandler,
|
||||
RecipeHandlerSet,
|
||||
RecipeListingHandler,
|
||||
RecipeManagementHandler,
|
||||
RecipePageView,
|
||||
RecipeQueryHandler,
|
||||
RecipeSharingHandler,
|
||||
)
|
||||
from .recipe_route_registrar import ROUTE_DEFINITIONS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseRecipeRoutes:
|
||||
"""Common dependency and startup wiring for recipe routes."""
|
||||
|
||||
_HANDLER_NAMES: tuple[str, ...] = tuple(
|
||||
definition.handler_name for definition in ROUTE_DEFINITIONS
|
||||
)
|
||||
|
||||
template_name: str = "recipes.html"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.recipe_scanner = None
|
||||
self.lora_scanner = None
|
||||
self.civitai_client = None
|
||||
self.settings = get_settings_manager()
|
||||
self.server_i18n = server_i18n
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True,
|
||||
)
|
||||
|
||||
self._i18n_registered = False
|
||||
self._startup_hooks_registered = False
|
||||
self._handler_set: RecipeHandlerSet | None = None
|
||||
self._handler_mapping: dict[str, Callable] | None = None
|
||||
|
||||
async def attach_dependencies(self, app: web.Application | None = None) -> None:
|
||||
"""Resolve shared services from the registry."""
|
||||
|
||||
await self._ensure_services()
|
||||
self._ensure_i18n_filter()
|
||||
|
||||
async def ensure_dependencies_ready(self) -> None:
|
||||
"""Ensure dependencies are available for request handlers."""
|
||||
|
||||
if self.recipe_scanner is None or self.civitai_client is None:
|
||||
await self.attach_dependencies()
|
||||
|
||||
def register_startup_hooks(self, app: web.Application) -> None:
|
||||
"""Register startup hooks once for dependency wiring."""
|
||||
|
||||
if self._startup_hooks_registered:
|
||||
return
|
||||
|
||||
app.on_startup.append(self.attach_dependencies)
|
||||
app.on_startup.append(self.prewarm_cache)
|
||||
self._startup_hooks_registered = True
|
||||
|
||||
async def prewarm_cache(self, app: web.Application | None = None) -> None:
|
||||
"""Pre-load recipe and LoRA caches on startup."""
|
||||
|
||||
try:
|
||||
await self.attach_dependencies(app)
|
||||
|
||||
if self.lora_scanner is not None:
|
||||
await self.lora_scanner.get_cached_data()
|
||||
hash_index = getattr(self.lora_scanner, "_hash_index", None)
|
||||
if hash_index is not None and hasattr(hash_index, "_hash_to_path"):
|
||||
_ = len(hash_index._hash_to_path)
|
||||
|
||||
if self.recipe_scanner is not None:
|
||||
await self.recipe_scanner.get_cached_data(force_refresh=True)
|
||||
except Exception as exc:
|
||||
logger.error("Error pre-warming recipe cache: %s", exc, exc_info=True)
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable]:
|
||||
"""Return a mapping of handler name to coroutine for registrar binding."""
|
||||
|
||||
if self._handler_mapping is None:
|
||||
handler_set = self._create_handler_set()
|
||||
self._handler_set = handler_set
|
||||
self._handler_mapping = handler_set.to_route_mapping()
|
||||
return self._handler_mapping
|
||||
|
||||
# Internal helpers -------------------------------------------------
|
||||
|
||||
async def _ensure_services(self) -> None:
|
||||
if self.recipe_scanner is None:
|
||||
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
self.lora_scanner = getattr(self.recipe_scanner, "_lora_scanner", None)
|
||||
|
||||
if self.civitai_client is None:
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
def _ensure_i18n_filter(self) -> None:
|
||||
if not self._i18n_registered:
|
||||
self.template_env.filters["t"] = self.server_i18n.create_template_filter()
|
||||
self._i18n_registered = True
|
||||
|
||||
def get_handler_owner(self):
|
||||
"""Return the object supplying bound handler coroutines."""
|
||||
|
||||
if self._handler_set is None:
|
||||
self._handler_set = self._create_handler_set()
|
||||
return self._handler_set
|
||||
|
||||
def _create_handler_set(self) -> RecipeHandlerSet:
|
||||
recipe_scanner_getter = lambda: self.recipe_scanner
|
||||
civitai_client_getter = lambda: self.civitai_client
|
||||
|
||||
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
if not standalone_mode:
|
||||
from ..metadata_collector import get_metadata # type: ignore[import-not-found]
|
||||
from ..metadata_collector.metadata_processor import ( # type: ignore[import-not-found]
|
||||
MetadataProcessor,
|
||||
)
|
||||
from ..metadata_collector.metadata_registry import ( # type: ignore[import-not-found]
|
||||
MetadataRegistry,
|
||||
)
|
||||
else: # pragma: no cover - optional dependency path
|
||||
get_metadata = None # type: ignore[assignment]
|
||||
MetadataProcessor = None # type: ignore[assignment]
|
||||
MetadataRegistry = None # type: ignore[assignment]
|
||||
|
||||
analysis_service = RecipeAnalysisService(
|
||||
exif_utils=ExifUtils,
|
||||
recipe_parser_factory=RecipeParserFactory,
|
||||
downloader_factory=get_downloader,
|
||||
metadata_collector=get_metadata,
|
||||
metadata_processor_cls=MetadataProcessor,
|
||||
metadata_registry_cls=MetadataRegistry,
|
||||
standalone_mode=standalone_mode,
|
||||
logger=logger,
|
||||
)
|
||||
persistence_service = RecipePersistenceService(
|
||||
exif_utils=ExifUtils,
|
||||
card_preview_width=CARD_PREVIEW_WIDTH,
|
||||
logger=logger,
|
||||
)
|
||||
sharing_service = RecipeSharingService(logger=logger)
|
||||
|
||||
page_view = RecipePageView(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
settings_service=self.settings,
|
||||
server_i18n=self.server_i18n,
|
||||
template_env=self.template_env,
|
||||
template_name=self.template_name,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
logger=logger,
|
||||
)
|
||||
listing = RecipeListingHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
logger=logger,
|
||||
)
|
||||
query = RecipeQueryHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
format_recipe_file_url=listing.format_recipe_file_url,
|
||||
logger=logger,
|
||||
)
|
||||
management = RecipeManagementHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
logger=logger,
|
||||
persistence_service=persistence_service,
|
||||
analysis_service=analysis_service,
|
||||
downloader_factory=get_downloader,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
)
|
||||
analysis = RecipeAnalysisHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
logger=logger,
|
||||
analysis_service=analysis_service,
|
||||
)
|
||||
sharing = RecipeSharingHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
logger=logger,
|
||||
sharing_service=sharing_service,
|
||||
)
|
||||
|
||||
return RecipeHandlerSet(
|
||||
page_view=page_view,
|
||||
listing=listing,
|
||||
query=query,
|
||||
management=management,
|
||||
analysis=analysis,
|
||||
sharing=sharing,
|
||||
)
|
||||
112
py/routes/checkpoint_routes.py
Normal file
112
py/routes/checkpoint_routes.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import logging
|
||||
from typing import Dict
|
||||
from aiohttp import web
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from .model_route_registrar import ModelRouteRegistrar
|
||||
from ..services.checkpoint_service import CheckpointService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointRoutes(BaseModelRoutes):
|
||||
"""Checkpoint-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Checkpoint routes with Checkpoint service"""
|
||||
super().__init__()
|
||||
self.template_name = "checkpoints.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
update_service = await ServiceRegistry.get_model_update_service()
|
||||
self.service = CheckpointService(checkpoint_scanner, update_service=update_service)
|
||||
self.set_model_update_service(update_service)
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Setup Checkpoint routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Setup common routes with 'checkpoints' prefix (includes page route)
|
||||
super().setup_routes(app, 'checkpoints')
|
||||
|
||||
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
|
||||
"""Setup Checkpoint-specific routes"""
|
||||
# Checkpoint info by name
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/info/{name}', prefix, self.get_checkpoint_info)
|
||||
|
||||
# Checkpoint roots and Unet roots
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/checkpoints_roots', prefix, self.get_checkpoints_roots)
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/unet_roots', prefix, self.get_unet_roots)
|
||||
|
||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||
"""Validate CivitAI model type for Checkpoint"""
|
||||
return model_type.lower() == 'checkpoint'
|
||||
|
||||
def _get_expected_model_types(self) -> str:
|
||||
"""Get expected model types string for error messages"""
|
||||
return "Checkpoint"
|
||||
|
||||
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||
"""Parse Checkpoint-specific parameters"""
|
||||
params: Dict = {}
|
||||
|
||||
if 'checkpoint_hash' in request.query:
|
||||
params['hash_filters'] = {'single_hash': request.query['checkpoint_hash'].lower()}
|
||||
elif 'checkpoint_hashes' in request.query:
|
||||
params['hash_filters'] = {
|
||||
'multiple_hashes': [h.lower() for h in request.query['checkpoint_hashes'].split(',')]
|
||||
}
|
||||
|
||||
return params
|
||||
|
||||
async def get_checkpoint_info(self, request: web.Request) -> web.Response:
|
||||
"""Get detailed information for a specific checkpoint by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
checkpoint_info = await self.service.get_model_info_by_name(name)
|
||||
|
||||
if checkpoint_info:
|
||||
return web.json_response(checkpoint_info)
|
||||
else:
|
||||
return web.json_response({"error": "Checkpoint not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_checkpoints_roots(self, request: web.Request) -> web.Response:
|
||||
"""Return the list of checkpoint roots from config"""
|
||||
try:
|
||||
roots = config.checkpoints_roots
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_unet_roots(self, request: web.Request) -> web.Response:
|
||||
"""Return the list of unet roots from config"""
|
||||
try:
|
||||
roots = config.unet_roots
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting unet roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
@@ -1,678 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..services.websocket_manager import ws_manager
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..utils.utils import fuzzy_match
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointsRoutes:
|
||||
"""API routes for checkpoint management"""
|
||||
|
||||
def __init__(self):
|
||||
self.scanner = None # Will be initialized in setup_routes
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
self.download_manager = None # Will be initialized in setup_routes
|
||||
self._download_lock = asyncio.Lock()
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
def setup_routes(self, app):
|
||||
"""Register routes with the aiohttp app"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
|
||||
app.router.add_get('/api/checkpoints', self.get_checkpoints)
|
||||
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
|
||||
app.router.add_get('/api/checkpoints/base-models', self.get_base_models)
|
||||
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
|
||||
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
|
||||
app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info)
|
||||
app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots)
|
||||
app.router.add_get('/api/checkpoints/civitai/versions/{model_id}', self.get_civitai_versions) # Add new route
|
||||
|
||||
# Add new routes for model management similar to LoRA routes
|
||||
app.router.add_post('/api/checkpoints/delete', self.delete_model)
|
||||
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
|
||||
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
|
||||
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
|
||||
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
|
||||
|
||||
# Add new WebSocket endpoint for checkpoint progress
|
||||
app.router.add_get('/ws/checkpoint-progress', ws_manager.handle_checkpoint_connection)
|
||||
|
||||
async def get_checkpoints(self, request):
|
||||
"""Get paginated checkpoint data"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
page = int(request.query.get('page', '1'))
|
||||
page_size = min(int(request.query.get('page_size', '20')), 100)
|
||||
sort_by = request.query.get('sort_by', 'name')
|
||||
folder = request.query.get('folder', None)
|
||||
search = request.query.get('search', None)
|
||||
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
|
||||
base_models = request.query.getall('base_model', [])
|
||||
tags = request.query.getall('tag', [])
|
||||
|
||||
# Process search options
|
||||
search_options = {
|
||||
'filename': request.query.get('search_filename', 'true').lower() == 'true',
|
||||
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
|
||||
'tags': request.query.get('search_tags', 'false').lower() == 'true',
|
||||
'recursive': request.query.get('recursive', 'false').lower() == 'true',
|
||||
}
|
||||
|
||||
# Process hash filters if provided
|
||||
hash_filters = {}
|
||||
if 'hash' in request.query:
|
||||
hash_filters['single_hash'] = request.query['hash']
|
||||
elif 'hashes' in request.query:
|
||||
try:
|
||||
hash_list = json.loads(request.query['hashes'])
|
||||
if isinstance(hash_list, list):
|
||||
hash_filters['multiple_hashes'] = hash_list
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Get data from scanner
|
||||
result = await self.get_paginated_data(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
sort_by=sort_by,
|
||||
folder=folder,
|
||||
search=search,
|
||||
fuzzy_search=fuzzy_search,
|
||||
base_models=base_models,
|
||||
tags=tags,
|
||||
search_options=search_options,
|
||||
hash_filters=hash_filters
|
||||
)
|
||||
|
||||
# Format response items
|
||||
formatted_result = {
|
||||
'items': [self._format_checkpoint_response(cp) for cp in result['items']],
|
||||
'total': result['total'],
|
||||
'page': result['page'],
|
||||
'page_size': result['page_size'],
|
||||
'total_pages': result['total_pages']
|
||||
}
|
||||
|
||||
# Return as JSON
|
||||
return web.json_response(formatted_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_paginated_data(self, page, page_size, sort_by='name',
|
||||
folder=None, search=None, fuzzy_search=False,
|
||||
base_models=None, tags=None,
|
||||
search_options=None, hash_filters=None):
|
||||
"""Get paginated and filtered checkpoint data"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Get default search options if not provided
|
||||
if search_options is None:
|
||||
search_options = {
|
||||
'filename': True,
|
||||
'modelname': True,
|
||||
'tags': False,
|
||||
'recursive': False,
|
||||
}
|
||||
|
||||
# Get the base data set
|
||||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
||||
|
||||
# Apply hash filtering if provided (highest priority)
|
||||
if hash_filters:
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower() # Ensure lowercase for matching
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('sha256', '').lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
# Jump to pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
# Apply SFW filtering if enabled in settings
|
||||
if settings.get('show_only_sfw', False):
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||
]
|
||||
|
||||
# Apply folder filtering
|
||||
if folder is not None:
|
||||
if search_options.get('recursive', False):
|
||||
# Recursive folder filtering - include all subfolders
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp['folder'].startswith(folder)
|
||||
]
|
||||
else:
|
||||
# Exact folder filtering
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp['folder'] == folder
|
||||
]
|
||||
|
||||
# Apply base model filtering
|
||||
if base_models and len(base_models) > 0:
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('base_model') in base_models
|
||||
]
|
||||
|
||||
# Apply tag filtering
|
||||
if tags and len(tags) > 0:
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if any(tag in cp.get('tags', []) for tag in tags)
|
||||
]
|
||||
|
||||
# Apply search filtering
|
||||
if search:
|
||||
search_results = []
|
||||
|
||||
for cp in filtered_data:
|
||||
# Search by file name
|
||||
if search_options.get('filename', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(cp.get('file_name', ''), search):
|
||||
search_results.append(cp)
|
||||
continue
|
||||
elif search.lower() in cp.get('file_name', '').lower():
|
||||
search_results.append(cp)
|
||||
continue
|
||||
|
||||
# Search by model name
|
||||
if search_options.get('modelname', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(cp.get('model_name', ''), search):
|
||||
search_results.append(cp)
|
||||
continue
|
||||
elif search.lower() in cp.get('model_name', '').lower():
|
||||
search_results.append(cp)
|
||||
continue
|
||||
|
||||
# Search by tags
|
||||
if search_options.get('tags', False) and 'tags' in cp:
|
||||
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) for tag in cp['tags']):
|
||||
search_results.append(cp)
|
||||
continue
|
||||
|
||||
filtered_data = search_results
|
||||
|
||||
# Calculate pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def _format_checkpoint_response(self, checkpoint):
|
||||
"""Format checkpoint data for API response"""
|
||||
return {
|
||||
"model_name": checkpoint["model_name"],
|
||||
"file_name": checkpoint["file_name"],
|
||||
"preview_url": config.get_preview_static_url(checkpoint.get("preview_url", "")),
|
||||
"preview_nsfw_level": checkpoint.get("preview_nsfw_level", 0),
|
||||
"base_model": checkpoint.get("base_model", ""),
|
||||
"folder": checkpoint["folder"],
|
||||
"sha256": checkpoint.get("sha256", ""),
|
||||
"file_path": checkpoint["file_path"].replace(os.sep, "/"),
|
||||
"file_size": checkpoint.get("size", 0),
|
||||
"modified": checkpoint.get("modified", ""),
|
||||
"tags": checkpoint.get("tags", []),
|
||||
"modelDescription": checkpoint.get("modelDescription", ""),
|
||||
"from_civitai": checkpoint.get("from_civitai", True),
|
||||
"notes": checkpoint.get("notes", ""),
|
||||
"model_type": checkpoint.get("model_type", "checkpoint"),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {}))
|
||||
}
|
||||
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Fetch CivitAI metadata for all checkpoints in the background"""
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data()
|
||||
total = len(cache.raw_data)
|
||||
processed = 0
|
||||
success = 0
|
||||
needs_resort = False
|
||||
|
||||
# Prepare checkpoints to process
|
||||
to_process = [
|
||||
cp for cp in cache.raw_data
|
||||
if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True)
|
||||
]
|
||||
total_to_process = len(to_process)
|
||||
|
||||
# Send initial progress
|
||||
await ws_manager.broadcast({
|
||||
'status': 'started',
|
||||
'total': total_to_process,
|
||||
'processed': 0,
|
||||
'success': 0
|
||||
})
|
||||
|
||||
# Process each checkpoint
|
||||
for cp in to_process:
|
||||
try:
|
||||
original_name = cp.get('model_name')
|
||||
if await ModelRouteUtils.fetch_and_update_model(
|
||||
sha256=cp['sha256'],
|
||||
file_path=cp['file_path'],
|
||||
model_data=cp,
|
||||
update_cache_func=self.scanner.update_single_model_cache
|
||||
):
|
||||
success += 1
|
||||
if original_name != cp.get('model_name'):
|
||||
needs_resort = True
|
||||
|
||||
processed += 1
|
||||
|
||||
# Send progress update
|
||||
await ws_manager.broadcast({
|
||||
'status': 'processing',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success,
|
||||
'current_name': cp.get('model_name', 'Unknown')
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}")
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort(name_only=True)
|
||||
|
||||
# Send completion message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'completed',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
# Send error message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'error',
|
||||
'error': str(e)
|
||||
})
|
||||
logger.error(f"Error in fetch_all_civitai for checkpoints: {e}")
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
"""Handle request for top tags sorted by frequency"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
# Validate limit
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20 # Default to a reasonable limit
|
||||
|
||||
# Get top tags
|
||||
top_tags = await self.scanner.get_top_tags(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'tags': top_tags
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Internal server error'
|
||||
}, status=500)
|
||||
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
"""Get base models used in loras"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
# Validate limit
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20 # Default to a reasonable limit
|
||||
|
||||
# Get base models
|
||||
base_models = await self.scanner.get_base_models(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'base_models': base_models
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving base models: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def scan_checkpoints(self, request):
|
||||
"""Force a rescan of checkpoint files"""
|
||||
try:
|
||||
await self.scanner.get_cached_data(force_refresh=True)
|
||||
return web.json_response({"status": "success", "message": "Checkpoint scan completed"})
|
||||
except Exception as e:
|
||||
logger.error(f"Error in scan_checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_checkpoint_info(self, request):
|
||||
"""Get detailed information for a specific checkpoint by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
checkpoint_info = await self.scanner.get_checkpoint_info_by_name(name)
|
||||
|
||||
if checkpoint_info:
|
||||
return web.json_response(checkpoint_info)
|
||||
else:
|
||||
return web.json_response({"error": "Checkpoint not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /checkpoints request"""
|
||||
try:
|
||||
# Check if the CheckpointScanner is initializing
|
||||
# It's initializing if the cache object doesn't exist yet,
|
||||
# OR if the scanner explicitly says it's initializing (background task running).
|
||||
is_initializing = (
|
||||
self.scanner._cache is None or
|
||||
(hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing)
|
||||
)
|
||||
|
||||
if is_initializing:
|
||||
# If still initializing, return loading page
|
||||
template = self.template_env.get_template('checkpoints.html')
|
||||
rendered = template.render(
|
||||
folders=[], # 空文件夹列表
|
||||
is_initializing=True, # 新增标志
|
||||
settings=settings, # Pass settings to template
|
||||
request=request # Pass the request object to the template
|
||||
)
|
||||
|
||||
logger.info("Checkpoints page is initializing, returning loading page")
|
||||
else:
|
||||
# 正常流程 - 获取已经初始化好的缓存数据
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('checkpoints.html')
|
||||
rendered = template.render(
|
||||
folders=cache.folders,
|
||||
is_initializing=False,
|
||||
settings=settings, # Pass settings to template
|
||||
request=request # Pass the request object to the template
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading checkpoints cache data: {cache_error}")
|
||||
# 如果获取缓存失败,也显示初始化页面
|
||||
template = self.template_env.get_template('checkpoints.html')
|
||||
rendered = template.render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Checkpoints cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading checkpoints page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def delete_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle checkpoint model deletion request"""
|
||||
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
|
||||
|
||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata fetch request for checkpoints"""
|
||||
return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
|
||||
|
||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||
"""Handle preview image replacement for checkpoints"""
|
||||
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
|
||||
|
||||
async def download_checkpoint(self, request: web.Request) -> web.Response:
|
||||
"""Handle checkpoint download request"""
|
||||
async with self._download_lock:
|
||||
# Get the download manager from service registry if not already initialized
|
||||
if self.download_manager is None:
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
try:
|
||||
data = await request.json()
|
||||
|
||||
# Create progress callback that uses checkpoint-specific WebSocket
|
||||
async def progress_callback(progress):
|
||||
await ws_manager.broadcast_checkpoint_progress({
|
||||
'status': 'progress',
|
||||
'progress': progress
|
||||
})
|
||||
|
||||
# Check which identifier is provided
|
||||
download_url = data.get('download_url')
|
||||
model_hash = data.get('model_hash')
|
||||
model_version_id = data.get('model_version_id')
|
||||
|
||||
# Validate that at least one identifier is provided
|
||||
if not any([download_url, model_hash, model_version_id]):
|
||||
return web.Response(
|
||||
status=400,
|
||||
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
|
||||
)
|
||||
|
||||
result = await self.download_manager.download_from_civitai(
|
||||
download_url=download_url,
|
||||
model_hash=model_hash,
|
||||
model_version_id=model_version_id,
|
||||
save_dir=data.get('checkpoint_root'),
|
||||
relative_path=data.get('relative_path', ''),
|
||||
progress_callback=progress_callback,
|
||||
model_type="checkpoint"
|
||||
)
|
||||
|
||||
if not result.get('success', False):
|
||||
error_message = result.get('error', 'Unknown error')
|
||||
|
||||
# Return 401 for early access errors
|
||||
if 'early access' in error_message.lower():
|
||||
logger.warning(f"Early access download failed: {error_message}")
|
||||
return web.Response(
|
||||
status=401,
|
||||
text=f"Early Access Restriction: {error_message}"
|
||||
)
|
||||
|
||||
return web.Response(status=500, text=error_message)
|
||||
|
||||
return web.json_response(result)
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
|
||||
# Check if this might be an early access error
|
||||
if '401' in error_message:
|
||||
logger.warning(f"Early access error (401): {error_message}")
|
||||
return web.Response(
|
||||
status=401,
|
||||
text="Early Access Restriction: This model requires purchase. Please ensure you have purchased early access and are logged in to Civitai."
|
||||
)
|
||||
|
||||
logger.error(f"Error downloading checkpoint: {error_message}")
|
||||
return web.Response(status=500, text=error_message)
|
||||
|
||||
async def get_checkpoint_roots(self, request):
|
||||
"""Return the checkpoint root directories"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
roots = self.scanner.get_model_roots()
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||
"""Handle saving metadata updates for checkpoints"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
if not file_path:
|
||||
return web.Response(text='File path is required', status=400)
|
||||
|
||||
# Remove file path from data to avoid saving it
|
||||
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
|
||||
|
||||
# Get metadata file path
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
|
||||
# Load existing metadata
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
|
||||
# Update metadata
|
||||
metadata.update(metadata_updates)
|
||||
|
||||
# Save updated metadata
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Update cache
|
||||
await self.scanner.update_single_model_cache(file_path, file_path, metadata)
|
||||
|
||||
# If model_name was updated, resort the cache
|
||||
if 'model_name' in metadata_updates:
|
||||
cache = await self.scanner.get_cached_data()
|
||||
await cache.resort(name_only=True)
|
||||
|
||||
return web.json_response({'success': True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint metadata: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_civitai_versions(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai checkpoint model with local availability info"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# Get the civitai client from service registry
|
||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
model_id = request.match_info['model_id']
|
||||
response = await civitai_client.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - should be Checkpoint
|
||||
if model_type.lower() != 'checkpoint':
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||
|
||||
# If no primary file found, try to find any model file
|
||||
if not model_file:
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model'), None)
|
||||
|
||||
if model_file:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.scanner.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.scanner.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
else:
|
||||
# No model file found in this version
|
||||
version['existsLocally'] = False
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching checkpoint model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
63
py/routes/embedding_routes.py
Normal file
63
py/routes/embedding_routes.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import logging
|
||||
from aiohttp import web
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from .model_route_registrar import ModelRouteRegistrar
|
||||
from ..services.embedding_service import EmbeddingService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EmbeddingRoutes(BaseModelRoutes):
|
||||
"""Embedding-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Embedding routes with Embedding service"""
|
||||
super().__init__()
|
||||
self.template_name = "embeddings.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
update_service = await ServiceRegistry.get_model_update_service()
|
||||
self.service = EmbeddingService(embedding_scanner, update_service=update_service)
|
||||
self.set_model_update_service(update_service)
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Setup Embedding routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Setup common routes with 'embeddings' prefix (includes page route)
|
||||
super().setup_routes(app, 'embeddings')
|
||||
|
||||
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
|
||||
"""Setup Embedding-specific routes"""
|
||||
# Embedding info by name
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/info/{name}', prefix, self.get_embedding_info)
|
||||
|
||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||
"""Validate CivitAI model type for Embedding"""
|
||||
return model_type.lower() == 'textualinversion'
|
||||
|
||||
def _get_expected_model_types(self) -> str:
|
||||
"""Get expected model types string for error messages"""
|
||||
return "TextualInversion"
|
||||
|
||||
async def get_embedding_info(self, request: web.Request) -> web.Response:
|
||||
"""Get detailed information for a specific embedding by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
embedding_info = await self.service.get_model_info_by_name(name)
|
||||
|
||||
if embedding_info:
|
||||
return web.json_response(embedding_info)
|
||||
else:
|
||||
return web.json_response({"error": "Embedding not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_embedding_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
63
py/routes/example_images_route_registrar.py
Normal file
63
py/routes/example_images_route_registrar.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Route registrar for example image endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RouteDefinition:
|
||||
"""Declarative configuration for a HTTP route."""
|
||||
|
||||
method: str
|
||||
path: str
|
||||
handler_name: str
|
||||
|
||||
|
||||
ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("POST", "/api/lm/download-example-images", "download_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/import-example-images", "import_example_images"),
|
||||
RouteDefinition("GET", "/api/lm/example-images-status", "get_example_images_status"),
|
||||
RouteDefinition("POST", "/api/lm/pause-example-images", "pause_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/resume-example-images", "resume_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/stop-example-images", "stop_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/open-example-images-folder", "open_example_images_folder"),
|
||||
RouteDefinition("GET", "/api/lm/example-image-files", "get_example_image_files"),
|
||||
RouteDefinition("GET", "/api/lm/has-example-images", "has_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/delete-example-image", "delete_example_image"),
|
||||
RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/cleanup-example-image-folders", "cleanup_example_image_folders"),
|
||||
)
|
||||
|
||||
|
||||
class ExampleImagesRouteRegistrar:
|
||||
"""Bind declarative example image routes to an aiohttp router."""
|
||||
|
||||
_METHOD_MAP = {
|
||||
"GET": "add_get",
|
||||
"POST": "add_post",
|
||||
"PUT": "add_put",
|
||||
"DELETE": "add_delete",
|
||||
}
|
||||
|
||||
def __init__(self, app: web.Application) -> None:
|
||||
self._app = app
|
||||
|
||||
def register_routes(
|
||||
self,
|
||||
handler_lookup: Mapping[str, Callable[[web.Request], object]],
|
||||
*,
|
||||
definitions: Iterable[RouteDefinition] = ROUTE_DEFINITIONS,
|
||||
) -> None:
|
||||
"""Register each route definition using the supplied handlers."""
|
||||
|
||||
for definition in definitions:
|
||||
handler = handler_lookup[definition.handler_name]
|
||||
self._bind_route(definition.method, definition.path, handler)
|
||||
|
||||
def _bind_route(self, method: str, path: str, handler: Callable[[web.Request], object]) -> None:
|
||||
add_method_name = self._METHOD_MAP[method.upper()]
|
||||
add_method = getattr(self._app.router, add_method_name)
|
||||
add_method(path, handler)
|
||||
88
py/routes/example_images_routes.py
Normal file
88
py/routes/example_images_routes.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from .example_images_route_registrar import ExampleImagesRouteRegistrar
|
||||
from .handlers.example_images_handlers import (
|
||||
ExampleImagesDownloadHandler,
|
||||
ExampleImagesFileHandler,
|
||||
ExampleImagesHandlerSet,
|
||||
ExampleImagesManagementHandler,
|
||||
)
|
||||
from ..services.use_cases.example_images import (
|
||||
DownloadExampleImagesUseCase,
|
||||
ImportExampleImagesUseCase,
|
||||
)
|
||||
from ..utils.example_images_download_manager import (
|
||||
DownloadManager,
|
||||
get_default_download_manager,
|
||||
)
|
||||
from ..utils.example_images_file_manager import ExampleImagesFileManager
|
||||
from ..utils.example_images_processor import ExampleImagesProcessor
|
||||
from ..services.example_images_cleanup_service import ExampleImagesCleanupService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExampleImagesRoutes:
|
||||
"""Route controller for example image endpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ws_manager,
|
||||
download_manager: DownloadManager | None = None,
|
||||
processor=ExampleImagesProcessor,
|
||||
file_manager=ExampleImagesFileManager,
|
||||
cleanup_service: ExampleImagesCleanupService | None = None,
|
||||
) -> None:
|
||||
if ws_manager is None:
|
||||
raise ValueError("ws_manager is required")
|
||||
self._download_manager = download_manager or get_default_download_manager(ws_manager)
|
||||
self._processor = processor
|
||||
self._file_manager = file_manager
|
||||
self._cleanup_service = cleanup_service or ExampleImagesCleanupService()
|
||||
self._handler_set: ExampleImagesHandlerSet | None = None
|
||||
self._handler_mapping: Mapping[str, Callable[[web.Request], web.StreamResponse]] | None = None
|
||||
|
||||
@classmethod
|
||||
def setup_routes(cls, app: web.Application, *, ws_manager) -> None:
|
||||
"""Register routes on the given aiohttp application using default wiring."""
|
||||
|
||||
controller = cls(ws_manager=ws_manager)
|
||||
controller.register(app)
|
||||
|
||||
def register(self, app: web.Application) -> None:
|
||||
"""Bind the controller's handlers to the aiohttp router."""
|
||||
|
||||
registrar = ExampleImagesRouteRegistrar(app)
|
||||
registrar.register_routes(self.to_route_mapping())
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
|
||||
"""Return the registrar-compatible mapping of handler names to callables."""
|
||||
|
||||
if self._handler_mapping is None:
|
||||
handler_set = self._build_handler_set()
|
||||
self._handler_set = handler_set
|
||||
self._handler_mapping = handler_set.to_route_mapping()
|
||||
return self._handler_mapping
|
||||
|
||||
def _build_handler_set(self) -> ExampleImagesHandlerSet:
|
||||
logger.debug("Building ExampleImagesHandlerSet with %s, %s, %s", self._download_manager, self._processor, self._file_manager)
|
||||
download_use_case = DownloadExampleImagesUseCase(download_manager=self._download_manager)
|
||||
download_handler = ExampleImagesDownloadHandler(download_use_case, self._download_manager)
|
||||
import_use_case = ImportExampleImagesUseCase(processor=self._processor)
|
||||
management_handler = ExampleImagesManagementHandler(
|
||||
import_use_case,
|
||||
self._processor,
|
||||
self._cleanup_service,
|
||||
)
|
||||
file_handler = ExampleImagesFileHandler(self._file_manager)
|
||||
return ExampleImagesHandlerSet(
|
||||
download=download_handler,
|
||||
management=management_handler,
|
||||
files=file_handler,
|
||||
)
|
||||
167
py/routes/handlers/example_images_handlers.py
Normal file
167
py/routes/handlers/example_images_handlers.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""Handler set for example image routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ...services.use_cases.example_images import (
|
||||
DownloadExampleImagesConfigurationError,
|
||||
DownloadExampleImagesInProgressError,
|
||||
DownloadExampleImagesUseCase,
|
||||
ImportExampleImagesUseCase,
|
||||
ImportExampleImagesValidationError,
|
||||
)
|
||||
from ...utils.example_images_download_manager import (
|
||||
DownloadConfigurationError,
|
||||
DownloadInProgressError,
|
||||
DownloadNotRunningError,
|
||||
ExampleImagesDownloadError,
|
||||
)
|
||||
from ...utils.example_images_processor import ExampleImagesImportError
|
||||
|
||||
|
||||
class ExampleImagesDownloadHandler:
|
||||
"""HTTP adapters for download-related example image endpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
download_use_case: DownloadExampleImagesUseCase,
|
||||
download_manager,
|
||||
) -> None:
|
||||
self._download_use_case = download_use_case
|
||||
self._download_manager = download_manager
|
||||
|
||||
async def download_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
payload = await request.json()
|
||||
result = await self._download_use_case.execute(payload)
|
||||
return web.json_response(result)
|
||||
except DownloadExampleImagesInProgressError as exc:
|
||||
response = {
|
||||
'success': False,
|
||||
'error': str(exc),
|
||||
'status': exc.progress,
|
||||
}
|
||||
return web.json_response(response, status=400)
|
||||
except DownloadExampleImagesConfigurationError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
except ExampleImagesDownloadError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
async def get_example_images_status(self, request: web.Request) -> web.StreamResponse:
|
||||
result = await self._download_manager.get_status(request)
|
||||
return web.json_response(result)
|
||||
|
||||
async def pause_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
result = await self._download_manager.pause_download(request)
|
||||
return web.json_response(result)
|
||||
except DownloadNotRunningError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
|
||||
async def resume_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
result = await self._download_manager.resume_download(request)
|
||||
return web.json_response(result)
|
||||
except DownloadNotRunningError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
|
||||
async def stop_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
result = await self._download_manager.stop_download(request)
|
||||
return web.json_response(result)
|
||||
except DownloadNotRunningError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
|
||||
async def force_download_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
payload = await request.json()
|
||||
result = await self._download_manager.start_force_download(payload)
|
||||
return web.json_response(result)
|
||||
except DownloadInProgressError as exc:
|
||||
response = {
|
||||
'success': False,
|
||||
'error': str(exc),
|
||||
'status': exc.progress_snapshot,
|
||||
}
|
||||
return web.json_response(response, status=400)
|
||||
except DownloadConfigurationError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
except ExampleImagesDownloadError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
|
||||
class ExampleImagesManagementHandler:
|
||||
"""HTTP adapters for import/delete endpoints."""
|
||||
|
||||
def __init__(self, import_use_case: ImportExampleImagesUseCase, processor, cleanup_service) -> None:
|
||||
self._import_use_case = import_use_case
|
||||
self._processor = processor
|
||||
self._cleanup_service = cleanup_service
|
||||
|
||||
async def import_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
result = await self._import_use_case.execute(request)
|
||||
return web.json_response(result)
|
||||
except ImportExampleImagesValidationError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
except ExampleImagesImportError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
async def delete_example_image(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._processor.delete_custom_image(request)
|
||||
|
||||
async def cleanup_example_image_folders(self, request: web.Request) -> web.StreamResponse:
|
||||
result = await self._cleanup_service.cleanup_example_image_folders()
|
||||
|
||||
if result.get('success') or result.get('partial_success'):
|
||||
return web.json_response(result, status=200)
|
||||
|
||||
error_code = result.get('error_code')
|
||||
status = 400 if error_code in {'path_not_configured', 'path_not_found'} else 500
|
||||
return web.json_response(result, status=status)
|
||||
|
||||
|
||||
class ExampleImagesFileHandler:
|
||||
"""HTTP adapters for filesystem-centric endpoints."""
|
||||
|
||||
def __init__(self, file_manager) -> None:
|
||||
self._file_manager = file_manager
|
||||
|
||||
async def open_example_images_folder(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._file_manager.open_folder(request)
|
||||
|
||||
async def get_example_image_files(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._file_manager.get_files(request)
|
||||
|
||||
async def has_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._file_manager.has_images(request)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExampleImagesHandlerSet:
|
||||
"""Aggregate of handlers exposed to the registrar."""
|
||||
|
||||
download: ExampleImagesDownloadHandler
|
||||
management: ExampleImagesManagementHandler
|
||||
files: ExampleImagesFileHandler
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
|
||||
"""Flatten handler methods into the registrar mapping."""
|
||||
|
||||
return {
|
||||
"download_example_images": self.download.download_example_images,
|
||||
"get_example_images_status": self.download.get_example_images_status,
|
||||
"pause_example_images": self.download.pause_example_images,
|
||||
"resume_example_images": self.download.resume_example_images,
|
||||
"stop_example_images": self.download.stop_example_images,
|
||||
"force_download_example_images": self.download.force_download_example_images,
|
||||
"import_example_images": self.management.import_example_images,
|
||||
"delete_example_image": self.management.delete_example_image,
|
||||
"cleanup_example_image_folders": self.management.cleanup_example_image_folders,
|
||||
"open_example_images_folder": self.files.open_example_images_folder,
|
||||
"get_example_image_files": self.files.get_example_image_files,
|
||||
"has_example_images": self.files.has_example_images,
|
||||
}
|
||||
1114
py/routes/handlers/misc_handlers.py
Normal file
1114
py/routes/handlers/misc_handlers.py
Normal file
File diff suppressed because it is too large
Load Diff
1636
py/routes/handlers/model_handlers.py
Normal file
1636
py/routes/handlers/model_handlers.py
Normal file
File diff suppressed because it is too large
Load Diff
56
py/routes/handlers/preview_handlers.py
Normal file
56
py/routes/handlers/preview_handlers.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Handlers responsible for serving preview assets dynamically."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ...config import config as global_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreviewHandler:
|
||||
"""Serve preview assets for the active library at request time."""
|
||||
|
||||
def __init__(self, *, config=global_config) -> None:
|
||||
self._config = config
|
||||
|
||||
async def serve_preview(self, request: web.Request) -> web.StreamResponse:
|
||||
"""Return the preview file referenced by the encoded ``path`` query."""
|
||||
|
||||
raw_path = request.query.get("path", "")
|
||||
if not raw_path:
|
||||
raise web.HTTPBadRequest(text="Missing 'path' query parameter")
|
||||
|
||||
try:
|
||||
decoded_path = urllib.parse.unquote(raw_path)
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.debug("Failed to decode preview path %s: %s", raw_path, exc)
|
||||
raise web.HTTPBadRequest(text="Invalid preview path encoding") from exc
|
||||
|
||||
normalized = decoded_path.replace("\\", "/")
|
||||
candidate = Path(normalized)
|
||||
try:
|
||||
resolved = candidate.expanduser().resolve(strict=False)
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to resolve preview path %s: %s", normalized, exc)
|
||||
raise web.HTTPBadRequest(text="Unable to resolve preview path") from exc
|
||||
|
||||
resolved_str = str(resolved)
|
||||
if not self._config.is_preview_path_allowed(resolved_str):
|
||||
logger.debug("Rejected preview outside allowed roots: %s", resolved_str)
|
||||
raise web.HTTPForbidden(text="Preview path is not within an allowed directory")
|
||||
|
||||
if not resolved.is_file():
|
||||
logger.debug("Preview file not found at %s", resolved_str)
|
||||
raise web.HTTPNotFound(text="Preview file not found")
|
||||
|
||||
# aiohttp's FileResponse handles range requests and content headers for us.
|
||||
return web.FileResponse(path=resolved, chunk_size=256 * 1024)
|
||||
|
||||
|
||||
__all__ = ["PreviewHandler"]
|
||||
940
py/routes/handlers/recipe_handlers.py
Normal file
940
py/routes/handlers/recipe_handlers.py
Normal file
@@ -0,0 +1,940 @@
|
||||
"""Dedicated handler objects for recipe-related routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ...config import config
|
||||
from ...services.server_i18n import server_i18n as default_server_i18n
|
||||
from ...services.settings_manager import SettingsManager
|
||||
from ...services.recipes import (
|
||||
RecipeAnalysisService,
|
||||
RecipeDownloadError,
|
||||
RecipeNotFoundError,
|
||||
RecipePersistenceService,
|
||||
RecipeSharingService,
|
||||
RecipeValidationError,
|
||||
)
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
Logger = logging.Logger
|
||||
EnsureDependenciesCallable = Callable[[], Awaitable[None]]
|
||||
RecipeScannerGetter = Callable[[], Any]
|
||||
CivitaiClientGetter = Callable[[], Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RecipeHandlerSet:
|
||||
"""Group of handlers providing recipe route implementations."""
|
||||
|
||||
page_view: "RecipePageView"
|
||||
listing: "RecipeListingHandler"
|
||||
query: "RecipeQueryHandler"
|
||||
management: "RecipeManagementHandler"
|
||||
analysis: "RecipeAnalysisHandler"
|
||||
sharing: "RecipeSharingHandler"
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
|
||||
"""Expose handler coroutines keyed by registrar handler names."""
|
||||
|
||||
return {
|
||||
"render_page": self.page_view.render_page,
|
||||
"list_recipes": self.listing.list_recipes,
|
||||
"get_recipe": self.listing.get_recipe,
|
||||
"import_remote_recipe": self.management.import_remote_recipe,
|
||||
"analyze_uploaded_image": self.analysis.analyze_uploaded_image,
|
||||
"analyze_local_image": self.analysis.analyze_local_image,
|
||||
"save_recipe": self.management.save_recipe,
|
||||
"delete_recipe": self.management.delete_recipe,
|
||||
"get_top_tags": self.query.get_top_tags,
|
||||
"get_base_models": self.query.get_base_models,
|
||||
"share_recipe": self.sharing.share_recipe,
|
||||
"download_shared_recipe": self.sharing.download_shared_recipe,
|
||||
"get_recipe_syntax": self.query.get_recipe_syntax,
|
||||
"update_recipe": self.management.update_recipe,
|
||||
"reconnect_lora": self.management.reconnect_lora,
|
||||
"find_duplicates": self.query.find_duplicates,
|
||||
"bulk_delete": self.management.bulk_delete,
|
||||
"save_recipe_from_widget": self.management.save_recipe_from_widget,
|
||||
"get_recipes_for_lora": self.query.get_recipes_for_lora,
|
||||
"scan_recipes": self.query.scan_recipes,
|
||||
}
|
||||
|
||||
|
||||
class RecipePageView:
|
||||
"""Render the recipe shell page."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
settings_service: SettingsManager,
|
||||
server_i18n=default_server_i18n,
|
||||
template_env,
|
||||
template_name: str,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
logger: Logger,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._settings = settings_service
|
||||
self._server_i18n = server_i18n
|
||||
self._template_env = template_env
|
||||
self._template_name = template_name
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
|
||||
async def render_page(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None: # pragma: no cover - defensive guard
|
||||
raise RuntimeError("Recipe scanner not available")
|
||||
|
||||
user_language = self._settings.get("language", "en")
|
||||
self._server_i18n.set_locale(user_language)
|
||||
|
||||
try:
|
||||
await recipe_scanner.get_cached_data(force_refresh=False)
|
||||
rendered = self._template_env.get_template(self._template_name).render(
|
||||
recipes=[],
|
||||
is_initializing=False,
|
||||
settings=self._settings,
|
||||
request=request,
|
||||
t=self._server_i18n.get_translation,
|
||||
)
|
||||
except Exception as cache_error: # pragma: no cover - logging path
|
||||
self._logger.error("Error loading recipe cache data: %s", cache_error)
|
||||
rendered = self._template_env.get_template(self._template_name).render(
|
||||
is_initializing=True,
|
||||
settings=self._settings,
|
||||
request=request,
|
||||
t=self._server_i18n.get_translation,
|
||||
)
|
||||
return web.Response(text=rendered, content_type="text/html")
|
||||
except Exception as exc: # pragma: no cover - logging path
|
||||
self._logger.error("Error handling recipes request: %s", exc, exc_info=True)
|
||||
return web.Response(text="Error loading recipes page", status=500)
|
||||
|
||||
|
||||
class RecipeListingHandler:
|
||||
"""Provide listing and detail APIs for recipes."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
logger: Logger,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
|
||||
async def list_recipes(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
page = int(request.query.get("page", "1"))
|
||||
page_size = int(request.query.get("page_size", "20"))
|
||||
sort_by = request.query.get("sort_by", "date")
|
||||
search = request.query.get("search")
|
||||
|
||||
search_options = {
|
||||
"title": request.query.get("search_title", "true").lower() == "true",
|
||||
"tags": request.query.get("search_tags", "true").lower() == "true",
|
||||
"lora_name": request.query.get("search_lora_name", "true").lower() == "true",
|
||||
"lora_model": request.query.get("search_lora_model", "true").lower() == "true",
|
||||
}
|
||||
|
||||
filters: Dict[str, Any] = {}
|
||||
base_models = request.query.get("base_models")
|
||||
if base_models:
|
||||
filters["base_model"] = base_models.split(",")
|
||||
|
||||
tag_filters: Dict[str, str] = {}
|
||||
legacy_tags = request.query.get("tags")
|
||||
if legacy_tags:
|
||||
for tag in legacy_tags.split(","):
|
||||
tag = tag.strip()
|
||||
if tag:
|
||||
tag_filters[tag] = "include"
|
||||
|
||||
include_tags = request.query.getall("tag_include", [])
|
||||
for tag in include_tags:
|
||||
if tag:
|
||||
tag_filters[tag] = "include"
|
||||
|
||||
exclude_tags = request.query.getall("tag_exclude", [])
|
||||
for tag in exclude_tags:
|
||||
if tag:
|
||||
tag_filters[tag] = "exclude"
|
||||
|
||||
if tag_filters:
|
||||
filters["tags"] = tag_filters
|
||||
|
||||
lora_hash = request.query.get("lora_hash")
|
||||
|
||||
result = await recipe_scanner.get_paginated_data(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
sort_by=sort_by,
|
||||
search=search,
|
||||
filters=filters,
|
||||
search_options=search_options,
|
||||
lora_hash=lora_hash,
|
||||
)
|
||||
|
||||
for item in result.get("items", []):
|
||||
file_path = item.get("file_path")
|
||||
if file_path:
|
||||
item["file_url"] = self.format_recipe_file_url(file_path)
|
||||
else:
|
||||
item.setdefault("file_url", "/loras_static/images/no-preview.png")
|
||||
item.setdefault("loras", [])
|
||||
item.setdefault("base_model", "")
|
||||
|
||||
return web.json_response(result)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving recipes: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def get_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
|
||||
|
||||
if not recipe:
|
||||
return web.json_response({"error": "Recipe not found"}, status=404)
|
||||
return web.json_response(recipe)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving recipe details: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
def format_recipe_file_url(self, file_path: str) -> str:
|
||||
try:
|
||||
normalized_path = os.path.normpath(file_path)
|
||||
static_url = config.get_preview_static_url(normalized_path)
|
||||
if static_url:
|
||||
return static_url
|
||||
except Exception as exc: # pragma: no cover - logging path
|
||||
self._logger.error("Error formatting recipe file URL: %s", exc, exc_info=True)
|
||||
return "/loras_static/images/no-preview.png"
|
||||
|
||||
return "/loras_static/images/no-preview.png"
|
||||
|
||||
|
||||
class RecipeQueryHandler:
|
||||
"""Provide read-only insights on recipe data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
format_recipe_file_url: Callable[[str], str],
|
||||
logger: Logger,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._format_recipe_file_url = format_recipe_file_url
|
||||
self._logger = logger
|
||||
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
limit = int(request.query.get("limit", "20"))
|
||||
cache = await recipe_scanner.get_cached_data()
|
||||
|
||||
tag_counts: Dict[str, int] = {}
|
||||
for recipe in getattr(cache, "raw_data", []):
|
||||
for tag in recipe.get("tags", []) or []:
|
||||
tag_counts[tag] = tag_counts.get(tag, 0) + 1
|
||||
|
||||
sorted_tags = [{"tag": tag, "count": count} for tag, count in tag_counts.items()]
|
||||
sorted_tags.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
return web.json_response({"success": True, "tags": sorted_tags[:limit]})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving top tags: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
cache = await recipe_scanner.get_cached_data()
|
||||
|
||||
base_model_counts: Dict[str, int] = {}
|
||||
for recipe in getattr(cache, "raw_data", []):
|
||||
base_model = recipe.get("base_model")
|
||||
if base_model:
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
|
||||
sorted_models = [{"name": model, "count": count} for model, count in base_model_counts.items()]
|
||||
sorted_models.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
return web.json_response({"success": True, "base_models": sorted_models})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving base models: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_recipes_for_lora(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
lora_hash = request.query.get("hash")
|
||||
if not lora_hash:
|
||||
return web.json_response({"success": False, "error": "Lora hash is required"}, status=400)
|
||||
|
||||
matching_recipes = await recipe_scanner.get_recipes_for_lora(lora_hash)
|
||||
return web.json_response({"success": True, "recipes": matching_recipes})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error getting recipes for Lora: %s", exc)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def scan_recipes(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
self._logger.info("Manually triggering recipe cache rebuild")
|
||||
await recipe_scanner.get_cached_data(force_refresh=True)
|
||||
return web.json_response({"success": True, "message": "Recipe cache refreshed successfully"})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error refreshing recipe cache: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def find_duplicates(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
duplicate_groups = await recipe_scanner.find_all_duplicate_recipes()
|
||||
response_data = []
|
||||
|
||||
for fingerprint, recipe_ids in duplicate_groups.items():
|
||||
if len(recipe_ids) <= 1:
|
||||
continue
|
||||
|
||||
recipes = []
|
||||
for recipe_id in recipe_ids:
|
||||
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
|
||||
if recipe:
|
||||
recipes.append(
|
||||
{
|
||||
"id": recipe.get("id"),
|
||||
"title": recipe.get("title"),
|
||||
"file_url": recipe.get("file_url")
|
||||
or self._format_recipe_file_url(recipe.get("file_path", "")),
|
||||
"modified": recipe.get("modified"),
|
||||
"created_date": recipe.get("created_date"),
|
||||
"lora_count": len(recipe.get("loras", [])),
|
||||
}
|
||||
)
|
||||
|
||||
if len(recipes) >= 2:
|
||||
recipes.sort(key=lambda entry: entry.get("modified", 0), reverse=True)
|
||||
response_data.append(
|
||||
{
|
||||
"fingerprint": fingerprint,
|
||||
"count": len(recipes),
|
||||
"recipes": recipes,
|
||||
}
|
||||
)
|
||||
|
||||
response_data.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
return web.json_response({"success": True, "duplicate_groups": response_data})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error finding duplicate recipes: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_recipe_syntax(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
try:
|
||||
syntax_parts = await recipe_scanner.get_recipe_syntax_tokens(recipe_id)
|
||||
except RecipeNotFoundError:
|
||||
return web.json_response({"error": "Recipe not found"}, status=404)
|
||||
|
||||
if not syntax_parts:
|
||||
return web.json_response({"error": "No LoRAs found in this recipe"}, status=400)
|
||||
|
||||
return web.json_response({"success": True, "syntax": " ".join(syntax_parts)})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error generating recipe syntax: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class RecipeManagementHandler:
|
||||
"""Handle create/update/delete style recipe operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
logger: Logger,
|
||||
persistence_service: RecipePersistenceService,
|
||||
analysis_service: RecipeAnalysisService,
|
||||
downloader_factory,
|
||||
civitai_client_getter: CivitaiClientGetter,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
self._persistence_service = persistence_service
|
||||
self._analysis_service = analysis_service
|
||||
self._downloader_factory = downloader_factory
|
||||
self._civitai_client_getter = civitai_client_getter
|
||||
|
||||
async def save_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
reader = await request.multipart()
|
||||
payload = await self._parse_save_payload(reader)
|
||||
|
||||
result = await self._persistence_service.save_recipe(
|
||||
recipe_scanner=recipe_scanner,
|
||||
image_bytes=payload["image_bytes"],
|
||||
image_base64=payload["image_base64"],
|
||||
name=payload["name"],
|
||||
tags=payload["tags"],
|
||||
metadata=payload["metadata"],
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error saving recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def import_remote_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
params = request.rel_url.query
|
||||
image_url = params.get("image_url")
|
||||
name = params.get("name")
|
||||
resources_raw = params.get("resources")
|
||||
if not image_url:
|
||||
raise RecipeValidationError("Missing required field: image_url")
|
||||
if not name:
|
||||
raise RecipeValidationError("Missing required field: name")
|
||||
if not resources_raw:
|
||||
raise RecipeValidationError("Missing required field: resources")
|
||||
|
||||
checkpoint_entry, lora_entries = self._parse_resources_payload(resources_raw)
|
||||
gen_params = self._parse_gen_params(params.get("gen_params"))
|
||||
metadata: Dict[str, Any] = {
|
||||
"base_model": params.get("base_model", "") or "",
|
||||
"loras": lora_entries,
|
||||
}
|
||||
source_path = params.get("source_path")
|
||||
if source_path:
|
||||
metadata["source_path"] = source_path
|
||||
if gen_params is not None:
|
||||
metadata["gen_params"] = gen_params
|
||||
if checkpoint_entry:
|
||||
metadata["checkpoint"] = checkpoint_entry
|
||||
gen_params_ref = metadata.setdefault("gen_params", {})
|
||||
if "checkpoint" not in gen_params_ref:
|
||||
gen_params_ref["checkpoint"] = checkpoint_entry
|
||||
base_model_from_metadata = await self._resolve_base_model_from_checkpoint(checkpoint_entry)
|
||||
if base_model_from_metadata:
|
||||
metadata["base_model"] = base_model_from_metadata
|
||||
|
||||
tags = self._parse_tags(params.get("tags"))
|
||||
image_bytes = await self._download_image_bytes(image_url)
|
||||
|
||||
result = await self._persistence_service.save_recipe(
|
||||
recipe_scanner=recipe_scanner,
|
||||
image_bytes=image_bytes,
|
||||
image_base64=None,
|
||||
name=name,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except RecipeDownloadError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error importing recipe from remote source: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def delete_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
result = await self._persistence_service.delete_recipe(
|
||||
recipe_scanner=recipe_scanner, recipe_id=recipe_id
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error deleting recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def update_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
data = await request.json()
|
||||
result = await self._persistence_service.update_recipe(
|
||||
recipe_scanner=recipe_scanner, recipe_id=recipe_id, updates=data
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error updating recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def reconnect_lora(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
data = await request.json()
|
||||
for field in ("recipe_id", "lora_index", "target_name"):
|
||||
if field not in data:
|
||||
raise RecipeValidationError(f"Missing required field: {field}")
|
||||
|
||||
result = await self._persistence_service.reconnect_lora(
|
||||
recipe_scanner=recipe_scanner,
|
||||
recipe_id=data["recipe_id"],
|
||||
lora_index=int(data["lora_index"]),
|
||||
target_name=data["target_name"],
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error reconnecting LoRA: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def bulk_delete(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
data = await request.json()
|
||||
recipe_ids = data.get("recipe_ids", [])
|
||||
result = await self._persistence_service.bulk_delete(
|
||||
recipe_scanner=recipe_scanner, recipe_ids=recipe_ids
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error performing bulk delete: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def save_recipe_from_widget(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
analysis = await self._analysis_service.analyze_widget_metadata(
|
||||
recipe_scanner=recipe_scanner
|
||||
)
|
||||
metadata = analysis.payload.get("metadata")
|
||||
image_bytes = analysis.payload.get("image_bytes")
|
||||
if not metadata or image_bytes is None:
|
||||
raise RecipeValidationError("Unable to extract metadata from widget")
|
||||
|
||||
result = await self._persistence_service.save_recipe_from_widget(
|
||||
recipe_scanner=recipe_scanner,
|
||||
metadata=metadata,
|
||||
image_bytes=image_bytes,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error saving recipe from widget: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def _parse_save_payload(self, reader) -> dict[str, Any]:
|
||||
image_bytes: Optional[bytes] = None
|
||||
image_base64: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tags: list[str] = []
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
if field.name == "image":
|
||||
image_chunks = bytearray()
|
||||
while True:
|
||||
chunk = await field.read_chunk()
|
||||
if not chunk:
|
||||
break
|
||||
image_chunks.extend(chunk)
|
||||
image_bytes = bytes(image_chunks)
|
||||
elif field.name == "image_base64":
|
||||
image_base64 = await field.text()
|
||||
elif field.name == "name":
|
||||
name = await field.text()
|
||||
elif field.name == "tags":
|
||||
tags_text = await field.text()
|
||||
try:
|
||||
parsed_tags = json.loads(tags_text)
|
||||
tags = parsed_tags if isinstance(parsed_tags, list) else []
|
||||
except Exception:
|
||||
tags = []
|
||||
elif field.name == "metadata":
|
||||
metadata_text = await field.text()
|
||||
try:
|
||||
metadata = json.loads(metadata_text)
|
||||
except Exception:
|
||||
metadata = {}
|
||||
|
||||
return {
|
||||
"image_bytes": image_bytes,
|
||||
"image_base64": image_base64,
|
||||
"name": name,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
def _parse_tags(self, tag_text: Optional[str]) -> list[str]:
|
||||
if not tag_text:
|
||||
return []
|
||||
return [tag.strip() for tag in tag_text.split(",") if tag.strip()]
|
||||
|
||||
def _parse_gen_params(self, payload: Optional[str]) -> Optional[Dict[str, Any]]:
|
||||
if payload is None:
|
||||
return None
|
||||
if payload == "":
|
||||
return {}
|
||||
try:
|
||||
parsed = json.loads(payload)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise RecipeValidationError(f"Invalid gen_params payload: {exc}") from exc
|
||||
if parsed is None:
|
||||
return {}
|
||||
if not isinstance(parsed, dict):
|
||||
raise RecipeValidationError("gen_params payload must be an object")
|
||||
return parsed
|
||||
|
||||
def _parse_resources_payload(self, payload_raw: str) -> tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
try:
|
||||
payload = json.loads(payload_raw)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise RecipeValidationError(f"Invalid resources payload: {exc}") from exc
|
||||
|
||||
if not isinstance(payload, list):
|
||||
raise RecipeValidationError("Resources payload must be a list")
|
||||
|
||||
checkpoint_entry: Optional[Dict[str, Any]] = None
|
||||
lora_entries: List[Dict[str, Any]] = []
|
||||
|
||||
for resource in payload:
|
||||
if not isinstance(resource, dict):
|
||||
continue
|
||||
resource_type = str(resource.get("type") or "").lower()
|
||||
if resource_type == "checkpoint":
|
||||
checkpoint_entry = self._build_checkpoint_entry(resource)
|
||||
elif resource_type in {"lora", "lycoris"}:
|
||||
lora_entries.append(self._build_lora_entry(resource))
|
||||
|
||||
return checkpoint_entry, lora_entries
|
||||
|
||||
def _build_checkpoint_entry(self, resource: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": resource.get("type", "checkpoint"),
|
||||
"modelId": self._safe_int(resource.get("modelId")),
|
||||
"modelVersionId": self._safe_int(resource.get("modelVersionId")),
|
||||
"modelName": resource.get("modelName", ""),
|
||||
"modelVersionName": resource.get("modelVersionName", ""),
|
||||
}
|
||||
|
||||
def _build_lora_entry(self, resource: Dict[str, Any]) -> Dict[str, Any]:
|
||||
weight_raw = resource.get("weight", 1.0)
|
||||
try:
|
||||
weight = float(weight_raw)
|
||||
except (TypeError, ValueError):
|
||||
weight = 1.0
|
||||
return {
|
||||
"file_name": resource.get("modelName", ""),
|
||||
"weight": weight,
|
||||
"id": self._safe_int(resource.get("modelVersionId")),
|
||||
"name": resource.get("modelName", ""),
|
||||
"version": resource.get("modelVersionName", ""),
|
||||
"isDeleted": False,
|
||||
"exclude": False,
|
||||
}
|
||||
|
||||
async def _download_image_bytes(self, image_url: str) -> bytes:
|
||||
civitai_client = self._civitai_client_getter()
|
||||
downloader = await self._downloader_factory()
|
||||
temp_path = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
download_url = image_url
|
||||
civitai_match = re.match(r"https://civitai\.com/images/(\d+)", image_url)
|
||||
if civitai_match:
|
||||
if civitai_client is None:
|
||||
raise RecipeDownloadError("Civitai client unavailable for image download")
|
||||
image_info = await civitai_client.get_image_info(civitai_match.group(1))
|
||||
if not image_info:
|
||||
raise RecipeDownloadError("Failed to fetch image information from Civitai")
|
||||
download_url = image_info.get("url")
|
||||
if not download_url:
|
||||
raise RecipeDownloadError("No image URL found in Civitai response")
|
||||
|
||||
success, result = await downloader.download_file(download_url, temp_path, use_auth=False)
|
||||
if not success:
|
||||
raise RecipeDownloadError(f"Failed to download image: {result}")
|
||||
with open(temp_path, "rb") as file_obj:
|
||||
return file_obj.read()
|
||||
except RecipeDownloadError:
|
||||
raise
|
||||
except RecipeValidationError:
|
||||
raise
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
raise RecipeValidationError(f"Unable to download image: {exc}") from exc
|
||||
finally:
|
||||
if temp_path:
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
def _safe_int(self, value: Any) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
async def _resolve_base_model_from_checkpoint(self, checkpoint_entry: Dict[str, Any]) -> str:
|
||||
version_id = self._safe_int(checkpoint_entry.get("modelVersionId"))
|
||||
|
||||
if not version_id:
|
||||
return ""
|
||||
|
||||
try:
|
||||
provider = await get_default_metadata_provider()
|
||||
if not provider:
|
||||
return ""
|
||||
|
||||
version_info = await provider.get_model_version_info(version_id)
|
||||
if isinstance(version_info, tuple):
|
||||
version_info = version_info[0]
|
||||
|
||||
if isinstance(version_info, dict):
|
||||
base_model = version_info.get("baseModel") or ""
|
||||
return str(base_model) if base_model is not None else ""
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._logger.warning("Failed to resolve base model from checkpoint metadata: %s", exc)
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
class RecipeAnalysisHandler:
|
||||
"""Analyze images to extract recipe metadata."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
civitai_client_getter: CivitaiClientGetter,
|
||||
logger: Logger,
|
||||
analysis_service: RecipeAnalysisService,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._civitai_client_getter = civitai_client_getter
|
||||
self._logger = logger
|
||||
self._analysis_service = analysis_service
|
||||
|
||||
async def analyze_uploaded_image(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
civitai_client = self._civitai_client_getter()
|
||||
if recipe_scanner is None or civitai_client is None:
|
||||
raise RuntimeError("Required services unavailable")
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
if "multipart/form-data" in content_type:
|
||||
reader = await request.multipart()
|
||||
field = await reader.next()
|
||||
if field is None or field.name != "image":
|
||||
raise RecipeValidationError("No image field found")
|
||||
image_chunks = bytearray()
|
||||
while True:
|
||||
chunk = await field.read_chunk()
|
||||
if not chunk:
|
||||
break
|
||||
image_chunks.extend(chunk)
|
||||
result = await self._analysis_service.analyze_uploaded_image(
|
||||
image_bytes=bytes(image_chunks),
|
||||
recipe_scanner=recipe_scanner,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
|
||||
if "application/json" in content_type:
|
||||
data = await request.json()
|
||||
result = await self._analysis_service.analyze_remote_image(
|
||||
url=data.get("url"),
|
||||
recipe_scanner=recipe_scanner,
|
||||
civitai_client=civitai_client,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
|
||||
raise RecipeValidationError("Unsupported content type")
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=400)
|
||||
except RecipeDownloadError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error analyzing recipe image: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=500)
|
||||
|
||||
async def analyze_local_image(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
data = await request.json()
|
||||
result = await self._analysis_service.analyze_local_image(
|
||||
file_path=data.get("path"),
|
||||
recipe_scanner=recipe_scanner,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error analyzing local image: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=500)
|
||||
|
||||
|
||||
class RecipeSharingHandler:
|
||||
"""Serve endpoints related to recipe sharing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
logger: Logger,
|
||||
sharing_service: RecipeSharingService,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
self._sharing_service = sharing_service
|
||||
|
||||
async def share_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
result = await self._sharing_service.share_recipe(
|
||||
recipe_scanner=recipe_scanner, recipe_id=recipe_id
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error sharing recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def download_shared_recipe(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
download_info = await self._sharing_service.prepare_download(
|
||||
recipe_scanner=recipe_scanner, recipe_id=recipe_id
|
||||
)
|
||||
return web.FileResponse(
|
||||
download_info.file_path,
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{download_info.download_filename}"'
|
||||
},
|
||||
)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error downloading shared recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
@@ -1,189 +1,263 @@
|
||||
import os
|
||||
from aiohttp import web
|
||||
import jinja2
|
||||
from typing import Dict
|
||||
import asyncio
|
||||
import logging
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
|
||||
from aiohttp import web
|
||||
from typing import Dict
|
||||
from server import PromptServer # type: ignore
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from .model_route_registrar import ModelRouteRegistrar
|
||||
from ..services.lora_service import LoraService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.utils import get_lora_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
|
||||
|
||||
class LoraRoutes:
|
||||
"""Route handlers for LoRA management endpoints"""
|
||||
class LoraRoutes(BaseModelRoutes):
|
||||
"""LoRA-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
# Initialize service references as None, will be set during async init
|
||||
self.scanner = None
|
||||
self.recipe_scanner = None
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
|
||||
async def init_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
"""Initialize LoRA routes with LoRA service"""
|
||||
super().__init__()
|
||||
self.template_name = "loras.html"
|
||||
|
||||
def format_lora_data(self, lora: Dict) -> Dict:
|
||||
"""Format LoRA data for template rendering"""
|
||||
return {
|
||||
"model_name": lora["model_name"],
|
||||
"file_name": lora["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora["preview_url"]),
|
||||
"preview_nsfw_level": lora.get("preview_nsfw_level", 0),
|
||||
"base_model": lora["base_model"],
|
||||
"folder": lora["folder"],
|
||||
"sha256": lora["sha256"],
|
||||
"file_path": lora["file_path"].replace(os.sep, "/"),
|
||||
"size": lora["size"],
|
||||
"tags": lora["tags"],
|
||||
"modelDescription": lora["modelDescription"],
|
||||
"usage_tips": lora["usage_tips"],
|
||||
"notes": lora["notes"],
|
||||
"modified": lora["modified"],
|
||||
"from_civitai": lora.get("from_civitai", True),
|
||||
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
|
||||
}
|
||||
|
||||
def _filter_civitai_data(self, data: Dict) -> Dict:
|
||||
"""Filter relevant fields from CivitAI data"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
fields = [
|
||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||
"publishedAt", "trainedWords", "baseModel", "description",
|
||||
"model", "images"
|
||||
]
|
||||
return {k: data[k] for k in fields if k in data}
|
||||
|
||||
async def handle_loras_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras request"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Check if the LoraScanner is initializing
|
||||
# It's initializing if the cache object doesn't exist yet,
|
||||
# OR if the scanner explicitly says it's initializing (background task running).
|
||||
is_initializing = (
|
||||
self.scanner._cache is None or
|
||||
(hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing)
|
||||
)
|
||||
|
||||
if is_initializing:
|
||||
# If still initializing, return loading page
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
|
||||
logger.info("Loras page is initializing, returning loading page")
|
||||
else:
|
||||
# Normal flow - get data from initialized cache
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=cache.folders,
|
||||
is_initializing=False,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading cache data: {cache_error}")
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling loras request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading loras page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def handle_recipes_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras/recipes request"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Skip initialization check and directly try to get cached data
|
||||
try:
|
||||
# Recipe scanner will initialize cache if needed
|
||||
await self.recipe_scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('recipes.html')
|
||||
rendered = template.render(
|
||||
recipes=[], # Frontend will load recipes via API
|
||||
is_initializing=False,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading recipe cache data: {cache_error}")
|
||||
# Still keep error handling - show initializing page on error
|
||||
template = self.template_env.get_template('recipes.html')
|
||||
rendered = template.render(
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Recipe cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling recipes request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading recipes page",
|
||||
status=500
|
||||
)
|
||||
|
||||
def _format_recipe_file_url(self, file_path: str) -> str:
|
||||
"""Format file path for recipe image as a URL - same as in recipe_routes"""
|
||||
try:
|
||||
# Return the file URL directly for the first lora root's preview
|
||||
recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/')
|
||||
if file_path.replace(os.sep, '/').startswith(recipes_dir):
|
||||
relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/')
|
||||
return f"/loras_static/root1/preview/{relative_path}"
|
||||
|
||||
# If not in recipes dir, try to create a valid URL from the file path
|
||||
file_name = os.path.basename(file_path)
|
||||
return f"/loras_static/root1/preview/recipes/{file_name}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting recipe file URL: {e}", exc_info=True)
|
||||
return '/loras_static/images/no-preview.png' # Return default image on error
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
update_service = await ServiceRegistry.get_model_update_service()
|
||||
self.service = LoraService(lora_scanner, update_service=update_service)
|
||||
self.set_model_update_service(update_service)
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Register routes with the application"""
|
||||
# Add an app startup handler to initialize services
|
||||
app.on_startup.append(self._on_startup)
|
||||
"""Setup LoRA routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Setup common routes with 'loras' prefix (includes page route)
|
||||
super().setup_routes(app, 'loras')
|
||||
|
||||
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
|
||||
"""Setup LoRA-specific routes"""
|
||||
# LoRA-specific query routes
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/letter-counts', prefix, self.get_letter_counts)
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/get-trigger-words', prefix, self.get_lora_trigger_words)
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/usage-tips-by-path', prefix, self.get_lora_usage_tips_by_path)
|
||||
|
||||
# ComfyUI integration
|
||||
registrar.add_prefixed_route('POST', '/api/lm/{prefix}/get_trigger_words', prefix, self.get_trigger_words)
|
||||
|
||||
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||
"""Parse LoRA-specific parameters"""
|
||||
params = {}
|
||||
|
||||
# Register routes
|
||||
app.router.add_get('/loras', self.handle_loras_page)
|
||||
app.router.add_get('/loras/recipes', self.handle_recipes_page)
|
||||
# LoRA-specific parameters
|
||||
if 'first_letter' in request.query:
|
||||
params['first_letter'] = request.query.get('first_letter')
|
||||
|
||||
async def _on_startup(self, app):
|
||||
"""Initialize services when the app starts"""
|
||||
await self.init_services()
|
||||
# Handle fuzzy search parameter name variation
|
||||
if request.query.get('fuzzy') == 'true':
|
||||
params['fuzzy_search'] = True
|
||||
|
||||
# Handle additional filter parameters for LoRAs
|
||||
if 'lora_hash' in request.query:
|
||||
if not params.get('hash_filters'):
|
||||
params['hash_filters'] = {}
|
||||
params['hash_filters']['single_hash'] = request.query['lora_hash'].lower()
|
||||
elif 'lora_hashes' in request.query:
|
||||
if not params.get('hash_filters'):
|
||||
params['hash_filters'] = {}
|
||||
params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')]
|
||||
|
||||
return params
|
||||
|
||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||
"""Validate CivitAI model type for LoRA"""
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
return model_type.lower() in VALID_LORA_TYPES
|
||||
|
||||
def _get_expected_model_types(self) -> str:
|
||||
"""Get expected model types string for error messages"""
|
||||
return "LORA, LoCon, or DORA"
|
||||
|
||||
# LoRA-specific route handlers
|
||||
async def get_letter_counts(self, request: web.Request) -> web.Response:
|
||||
"""Get count of LoRAs for each letter of the alphabet"""
|
||||
try:
|
||||
letter_counts = await self.service.get_letter_counts()
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'letter_counts': letter_counts
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting letter counts: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_notes(self, request: web.Request) -> web.Response:
|
||||
"""Get notes for a specific LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
notes = await self.service.get_lora_notes(lora_name)
|
||||
if notes is not None:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'notes': notes
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'LoRA not found in cache'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora notes: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for a specific LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
trigger_words = await self.service.get_lora_trigger_words(lora_name)
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'trigger_words': trigger_words
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora trigger words: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_usage_tips_by_path(self, request: web.Request) -> web.Response:
|
||||
"""Get usage tips for a LoRA by its relative path"""
|
||||
try:
|
||||
relative_path = request.query.get('relative_path')
|
||||
if not relative_path:
|
||||
return web.Response(text='Relative path is required', status=400)
|
||||
|
||||
usage_tips = await self.service.get_lora_usage_tips_by_relative_path(relative_path)
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'usage_tips': usage_tips or ''
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora usage tips by path: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the static preview URL for a LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
preview_url = await self.service.get_lora_preview_url(lora_name)
|
||||
if preview_url:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'preview_url': preview_url
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No preview URL found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the Civitai URL for a LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
result = await self.service.get_lora_civitai_url(lora_name)
|
||||
if result['civitai_url']:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
**result
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No Civitai data found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for specified LoRA models"""
|
||||
try:
|
||||
json_data = await request.json()
|
||||
lora_names = json_data.get("lora_names", [])
|
||||
node_ids = json_data.get("node_ids", [])
|
||||
|
||||
all_trigger_words = []
|
||||
for lora_name in lora_names:
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Format the trigger words
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Send update to all connected trigger word toggle nodes
|
||||
for entry in node_ids:
|
||||
node_identifier = entry
|
||||
graph_identifier = None
|
||||
if isinstance(entry, dict):
|
||||
node_identifier = entry.get("node_id")
|
||||
graph_identifier = entry.get("graph_id")
|
||||
|
||||
try:
|
||||
parsed_node_id = int(node_identifier)
|
||||
except (TypeError, ValueError):
|
||||
parsed_node_id = node_identifier
|
||||
|
||||
payload = {
|
||||
"id": parsed_node_id,
|
||||
"message": trigger_words_text
|
||||
}
|
||||
|
||||
if graph_identifier is not None:
|
||||
payload["graph_id"] = str(graph_identifier)
|
||||
|
||||
PromptServer.instance.send_sync("trigger_word_update", payload)
|
||||
|
||||
return web.json_response({"success": True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting trigger words: {e}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
72
py/routes/misc_route_registrar.py
Normal file
72
py/routes/misc_route_registrar.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Route registrar for miscellaneous endpoints.
|
||||
|
||||
This module mirrors the model route registrar architecture so that
|
||||
miscellaneous endpoints share a consistent registration flow.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RouteDefinition:
|
||||
"""Declarative definition for a HTTP route."""
|
||||
|
||||
method: str
|
||||
path: str
|
||||
handler_name: str
|
||||
|
||||
|
||||
MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/settings", "get_settings"),
|
||||
RouteDefinition("POST", "/api/lm/settings", "update_settings"),
|
||||
RouteDefinition("GET", "/api/lm/priority-tags", "get_priority_tags"),
|
||||
RouteDefinition("GET", "/api/lm/settings/libraries", "get_settings_libraries"),
|
||||
RouteDefinition("POST", "/api/lm/settings/libraries/activate", "activate_library"),
|
||||
RouteDefinition("GET", "/api/lm/health-check", "health_check"),
|
||||
RouteDefinition("POST", "/api/lm/open-file-location", "open_file_location"),
|
||||
RouteDefinition("POST", "/api/lm/update-usage-stats", "update_usage_stats"),
|
||||
RouteDefinition("GET", "/api/lm/get-usage-stats", "get_usage_stats"),
|
||||
RouteDefinition("POST", "/api/lm/update-lora-code", "update_lora_code"),
|
||||
RouteDefinition("GET", "/api/lm/trained-words", "get_trained_words"),
|
||||
RouteDefinition("GET", "/api/lm/model-example-files", "get_model_example_files"),
|
||||
RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"),
|
||||
RouteDefinition("POST", "/api/lm/update-node-widget", "update_node_widget"),
|
||||
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
|
||||
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
|
||||
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),
|
||||
RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"),
|
||||
RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"),
|
||||
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),
|
||||
RouteDefinition("GET", "/api/lm/model-versions-status", "get_model_versions_status"),
|
||||
)
|
||||
|
||||
|
||||
class MiscRouteRegistrar:
|
||||
"""Bind miscellaneous route definitions to an aiohttp router."""
|
||||
|
||||
_METHOD_MAP = {
|
||||
"GET": "add_get",
|
||||
"POST": "add_post",
|
||||
"PUT": "add_put",
|
||||
"DELETE": "add_delete",
|
||||
}
|
||||
|
||||
def __init__(self, app: web.Application) -> None:
|
||||
self._app = app
|
||||
|
||||
def register_routes(
|
||||
self,
|
||||
handler_lookup: Mapping[str, Callable[[web.Request], object]],
|
||||
*,
|
||||
definitions: Iterable[RouteDefinition] = MISC_ROUTE_DEFINITIONS,
|
||||
) -> None:
|
||||
for definition in definitions:
|
||||
self._bind(definition.method, definition.path, handler_lookup[definition.handler_name])
|
||||
|
||||
def _bind(self, method: str, path: str, handler: Callable) -> None:
|
||||
add_method_name = self._METHOD_MAP[method.upper()]
|
||||
add_method = getattr(self._app.router, add_method_name)
|
||||
add_method(path, handler)
|
||||
135
py/routes/misc_routes.py
Normal file
135
py/routes/misc_routes.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Route controller for miscellaneous endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Awaitable, Callable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
from server import PromptServer # type: ignore
|
||||
|
||||
from ..services.metadata_service import (
|
||||
get_metadata_archive_manager,
|
||||
get_metadata_provider,
|
||||
update_metadata_providers,
|
||||
)
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from ..services.downloader import get_downloader
|
||||
from ..utils.usage_stats import UsageStats
|
||||
from .handlers.misc_handlers import (
|
||||
FileSystemHandler,
|
||||
HealthCheckHandler,
|
||||
LoraCodeHandler,
|
||||
MetadataArchiveHandler,
|
||||
MiscHandlerSet,
|
||||
ModelExampleFilesHandler,
|
||||
ModelLibraryHandler,
|
||||
NodeRegistry,
|
||||
NodeRegistryHandler,
|
||||
SettingsHandler,
|
||||
TrainedWordsHandler,
|
||||
UsageStatsHandler,
|
||||
build_service_registry_adapter,
|
||||
)
|
||||
from .misc_route_registrar import MiscRouteRegistrar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get(
|
||||
"HF_HUB_DISABLE_TELEMETRY", "0"
|
||||
) == "0"
|
||||
|
||||
|
||||
class MiscRoutes:
|
||||
"""Route controller that mirrors the model route architecture."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
settings_service=None,
|
||||
usage_stats_factory: Callable[[], UsageStats] = UsageStats,
|
||||
prompt_server: type[PromptServer] = PromptServer,
|
||||
service_registry_adapter=build_service_registry_adapter(),
|
||||
metadata_provider_factory=get_metadata_provider,
|
||||
metadata_archive_manager_factory=get_metadata_archive_manager,
|
||||
metadata_provider_updater=update_metadata_providers,
|
||||
downloader_factory=get_downloader,
|
||||
registrar_factory=MiscRouteRegistrar,
|
||||
handler_set_factory=MiscHandlerSet,
|
||||
node_registry: NodeRegistry | None = None,
|
||||
standalone_mode_flag: bool = standalone_mode,
|
||||
) -> None:
|
||||
self._settings = settings_service or get_settings_manager()
|
||||
self._usage_stats_factory = usage_stats_factory
|
||||
self._prompt_server = prompt_server
|
||||
self._service_registry_adapter = service_registry_adapter
|
||||
self._metadata_provider_factory = metadata_provider_factory
|
||||
self._metadata_archive_manager_factory = metadata_archive_manager_factory
|
||||
self._metadata_provider_updater = metadata_provider_updater
|
||||
self._downloader_factory = downloader_factory
|
||||
self._registrar_factory = registrar_factory
|
||||
self._handler_set_factory = handler_set_factory
|
||||
self._node_registry = node_registry or NodeRegistry()
|
||||
self._standalone_mode = standalone_mode_flag
|
||||
|
||||
self._handler_mapping: Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]] | None = None
|
||||
|
||||
@staticmethod
|
||||
def setup_routes(app: web.Application) -> None:
|
||||
"""Entry point used by the application bootstrap."""
|
||||
controller = MiscRoutes()
|
||||
controller.bind(app)
|
||||
|
||||
def bind(self, app: web.Application) -> None:
|
||||
registrar = self._registrar_factory(app)
|
||||
registrar.register_routes(self._ensure_handler_mapping())
|
||||
|
||||
def _ensure_handler_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
|
||||
if self._handler_mapping is None:
|
||||
handler_set = self._create_handler_set()
|
||||
self._handler_mapping = handler_set.to_route_mapping()
|
||||
return self._handler_mapping
|
||||
|
||||
def _create_handler_set(self) -> MiscHandlerSet:
|
||||
health = HealthCheckHandler()
|
||||
settings_handler = SettingsHandler(
|
||||
settings_service=self._settings,
|
||||
metadata_provider_updater=self._metadata_provider_updater,
|
||||
downloader_factory=self._downloader_factory,
|
||||
)
|
||||
usage_stats = UsageStatsHandler(usage_stats_factory=self._usage_stats_factory)
|
||||
lora_code = LoraCodeHandler(prompt_server=self._prompt_server)
|
||||
trained_words = TrainedWordsHandler()
|
||||
model_examples = ModelExampleFilesHandler()
|
||||
metadata_archive = MetadataArchiveHandler(
|
||||
metadata_archive_manager_factory=self._metadata_archive_manager_factory,
|
||||
settings_service=self._settings,
|
||||
metadata_provider_updater=self._metadata_provider_updater,
|
||||
)
|
||||
filesystem = FileSystemHandler()
|
||||
node_registry_handler = NodeRegistryHandler(
|
||||
node_registry=self._node_registry,
|
||||
prompt_server=self._prompt_server,
|
||||
standalone_mode=self._standalone_mode,
|
||||
)
|
||||
model_library = ModelLibraryHandler(
|
||||
service_registry=self._service_registry_adapter,
|
||||
metadata_provider_factory=self._metadata_provider_factory,
|
||||
)
|
||||
|
||||
return self._handler_set_factory(
|
||||
health=health,
|
||||
settings=settings_handler,
|
||||
usage_stats=usage_stats,
|
||||
lora_code=lora_code,
|
||||
trained_words=trained_words,
|
||||
model_examples=model_examples,
|
||||
node_registry=node_registry_handler,
|
||||
model_library=model_library,
|
||||
metadata_archive=metadata_archive,
|
||||
filesystem=filesystem,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["MiscRoutes"]
|
||||
107
py/routes/model_route_registrar.py
Normal file
107
py/routes/model_route_registrar.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Route registrar for model endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RouteDefinition:
|
||||
"""Declarative definition for a HTTP route."""
|
||||
|
||||
method: str
|
||||
path_template: str
|
||||
handler_name: str
|
||||
|
||||
def build_path(self, prefix: str) -> str:
|
||||
return self.path_template.replace("{prefix}", prefix)
|
||||
|
||||
|
||||
COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/list", "get_models"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/delete", "delete_model"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/exclude", "exclude_model"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/fetch-civitai", "fetch_civitai"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/fetch-all-civitai", "fetch_all_civitai"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/relink-civitai", "relink_civitai"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/replace-preview", "replace_preview"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/save-metadata", "save_metadata"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/add-tags", "add_tags"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/rename", "rename_model"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/bulk-delete", "bulk_delete_models"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/verify-duplicates", "verify_duplicates"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/move_model", "move_model"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/move_models_bulk", "move_models_bulk"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize", "auto_organize_models"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/auto-organize", "auto_organize_models"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/model-types", "get_model_types"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/scan", "scan_models"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/folder-tree", "get_folder_tree"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/unified-folder-tree", "get_unified_folder_tree"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/find-duplicates", "find_duplicate_models"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/find-filename-conflicts", "find_filename_conflicts"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/get-notes", "get_model_notes"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/preview-url", "get_model_preview_url"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai-url", "get_model_civitai_url"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/metadata", "get_model_metadata"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/model-description", "get_model_description"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/relative-paths", "get_relative_paths"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/fetch-missing-license", "fetch_missing_civitai_license_data"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore-version", "set_version_update_ignore"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/updates/versions/{model_id}", "get_model_versions"),
|
||||
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
|
||||
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
|
||||
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),
|
||||
RouteDefinition("GET", "/api/lm/pause-download", "pause_download_get"),
|
||||
RouteDefinition("GET", "/api/lm/resume-download", "resume_download_get"),
|
||||
RouteDefinition("GET", "/api/lm/download-progress/{download_id}", "get_download_progress"),
|
||||
RouteDefinition("GET", "/{prefix}", "handle_models_page"),
|
||||
)
|
||||
|
||||
|
||||
class ModelRouteRegistrar:
|
||||
"""Bind declarative definitions to an aiohttp router."""
|
||||
|
||||
_METHOD_MAP = {
|
||||
"GET": "add_get",
|
||||
"POST": "add_post",
|
||||
"PUT": "add_put",
|
||||
"DELETE": "add_delete",
|
||||
}
|
||||
|
||||
def __init__(self, app: web.Application) -> None:
|
||||
self._app = app
|
||||
|
||||
def register_common_routes(
|
||||
self,
|
||||
prefix: str,
|
||||
handler_lookup: Mapping[str, Callable[[web.Request], object]],
|
||||
*,
|
||||
definitions: Iterable[RouteDefinition] = COMMON_ROUTE_DEFINITIONS,
|
||||
) -> None:
|
||||
for definition in definitions:
|
||||
self._bind_route(definition.method, definition.build_path(prefix), handler_lookup[definition.handler_name])
|
||||
|
||||
def add_route(self, method: str, path: str, handler: Callable) -> None:
|
||||
self._bind_route(method, path, handler)
|
||||
|
||||
def add_prefixed_route(self, method: str, path_template: str, prefix: str, handler: Callable) -> None:
|
||||
self._bind_route(method, path_template.replace("{prefix}", prefix), handler)
|
||||
|
||||
def _bind_route(self, method: str, path: str, handler: Callable) -> None:
|
||||
add_method_name = self._METHOD_MAP[method.upper()]
|
||||
add_method = getattr(self._app.router, add_method_name)
|
||||
add_method(path, handler)
|
||||
25
py/routes/preview_routes.py
Normal file
25
py/routes/preview_routes.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Route controller for preview asset delivery."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from .handlers.preview_handlers import PreviewHandler
|
||||
|
||||
|
||||
class PreviewRoutes:
|
||||
"""Register routes that expose preview assets."""
|
||||
|
||||
def __init__(self, *, handler: PreviewHandler | None = None) -> None:
|
||||
self._handler = handler or PreviewHandler()
|
||||
|
||||
@classmethod
|
||||
def setup_routes(cls, app: web.Application) -> None:
|
||||
controller = cls()
|
||||
controller.register(app)
|
||||
|
||||
def register(self, app: web.Application) -> None:
|
||||
app.router.add_get('/api/lm/previews', self._handler.serve_preview)
|
||||
|
||||
|
||||
__all__ = ["PreviewRoutes"]
|
||||
64
py/routes/recipe_route_registrar.py
Normal file
64
py/routes/recipe_route_registrar.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Route registrar for recipe endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RouteDefinition:
|
||||
"""Declarative definition for a recipe HTTP route."""
|
||||
|
||||
method: str
|
||||
path: str
|
||||
handler_name: str
|
||||
|
||||
|
||||
ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/loras/recipes", "render_page"),
|
||||
RouteDefinition("GET", "/api/lm/recipes", "list_recipes"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/import-remote", "import_remote_recipe"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"),
|
||||
RouteDefinition("DELETE", "/api/lm/recipe/{recipe_id}", "delete_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/top-tags", "get_top_tags"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/base-models", "get_base_models"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share", "share_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/syntax", "get_recipe_syntax"),
|
||||
RouteDefinition("PUT", "/api/lm/recipe/{recipe_id}/update", "update_recipe"),
|
||||
RouteDefinition("POST", "/api/lm/recipe/lora/reconnect", "reconnect_lora"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/find-duplicates", "find_duplicates"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/bulk-delete", "bulk_delete"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"),
|
||||
)
|
||||
|
||||
|
||||
class RecipeRouteRegistrar:
|
||||
"""Bind declarative recipe definitions to an aiohttp router."""
|
||||
|
||||
_METHOD_MAP = {
|
||||
"GET": "add_get",
|
||||
"POST": "add_post",
|
||||
"PUT": "add_put",
|
||||
"DELETE": "add_delete",
|
||||
}
|
||||
|
||||
def __init__(self, app: web.Application) -> None:
|
||||
self._app = app
|
||||
|
||||
def register_routes(self, handler_lookup: Mapping[str, Callable[[web.Request], object]]) -> None:
|
||||
for definition in ROUTE_DEFINITIONS:
|
||||
handler = handler_lookup[definition.handler_name]
|
||||
self._bind_route(definition.method, definition.path, handler)
|
||||
|
||||
def _bind_route(self, method: str, path: str, handler: Callable) -> None:
|
||||
add_method_name = self._METHOD_MAP[method.upper()]
|
||||
add_method = getattr(self._app.router, add_method_name)
|
||||
add_method(path, handler)
|
||||
File diff suppressed because it is too large
Load Diff
540
py/routes/stats_routes.py
Normal file
540
py/routes/stats_routes.py
Normal file
@@ -0,0 +1,540 @@
|
||||
import os
|
||||
import json
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict, Counter
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from ..config import config
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from ..services.server_i18n import server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.usage_stats import UsageStats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _SettingsProxy:
|
||||
def __init__(self):
|
||||
self._manager = None
|
||||
|
||||
def _resolve(self):
|
||||
if self._manager is None:
|
||||
self._manager = get_settings_manager()
|
||||
return self._manager
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
return self._resolve().get(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self._resolve(), item)
|
||||
|
||||
|
||||
settings = _SettingsProxy()
|
||||
|
||||
class StatsRoutes:
|
||||
"""Route handlers for Statistics page and API endpoints"""
|
||||
|
||||
def __init__(self):
|
||||
self.lora_scanner = None
|
||||
self.checkpoint_scanner = None
|
||||
self.embedding_scanner = None
|
||||
self.usage_stats = None
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
|
||||
async def init_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
# Only initialize usage stats if we have valid paths configured
|
||||
try:
|
||||
self.usage_stats = UsageStats()
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Could not initialize usage statistics: {e}")
|
||||
self.usage_stats = None
|
||||
|
||||
async def handle_stats_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /statistics request"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Check if scanners are initializing
|
||||
lora_initializing = (
|
||||
self.lora_scanner._cache is None or
|
||||
(hasattr(self.lora_scanner, 'is_initializing') and self.lora_scanner.is_initializing())
|
||||
)
|
||||
|
||||
checkpoint_initializing = (
|
||||
self.checkpoint_scanner._cache is None or
|
||||
(hasattr(self.checkpoint_scanner, '_is_initializing') and self.checkpoint_scanner._is_initializing)
|
||||
)
|
||||
|
||||
embedding_initializing = (
|
||||
self.embedding_scanner._cache is None or
|
||||
(hasattr(self.embedding_scanner, 'is_initializing') and self.embedding_scanner.is_initializing())
|
||||
)
|
||||
|
||||
is_initializing = lora_initializing or checkpoint_initializing or embedding_initializing
|
||||
|
||||
# 获取用户语言设置
|
||||
settings_object = settings
|
||||
user_language = settings_object.get('language', 'en')
|
||||
settings_manager = settings_object if not isinstance(settings_object, _SettingsProxy) else settings_object._resolve()
|
||||
|
||||
# 设置服务端i18n语言
|
||||
server_i18n.set_locale(user_language)
|
||||
|
||||
# 为模板环境添加i18n过滤器
|
||||
if not hasattr(self.template_env, '_i18n_filter_added'):
|
||||
self.template_env.filters['t'] = server_i18n.create_template_filter()
|
||||
self.template_env._i18n_filter_added = True
|
||||
|
||||
template = self.template_env.get_template('statistics.html')
|
||||
rendered = template.render(
|
||||
is_initializing=is_initializing,
|
||||
settings=settings_manager,
|
||||
request=request,
|
||||
t=server_i18n.get_translation,
|
||||
)
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling statistics request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading statistics page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def get_collection_overview(self, request: web.Request) -> web.Response:
|
||||
"""Get collection overview statistics"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get LoRA statistics
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
lora_count = len(lora_cache.raw_data)
|
||||
lora_size = sum(lora.get('size', 0) for lora in lora_cache.raw_data)
|
||||
|
||||
# Get Checkpoint statistics
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
checkpoint_count = len(checkpoint_cache.raw_data)
|
||||
checkpoint_size = sum(cp.get('size', 0) for cp in checkpoint_cache.raw_data)
|
||||
|
||||
# Get Embedding statistics
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
embedding_count = len(embedding_cache.raw_data)
|
||||
embedding_size = sum(emb.get('size', 0) for emb in embedding_cache.raw_data)
|
||||
|
||||
# Get usage statistics
|
||||
usage_data = await self.usage_stats.get_stats()
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'total_models': lora_count + checkpoint_count + embedding_count,
|
||||
'lora_count': lora_count,
|
||||
'checkpoint_count': checkpoint_count,
|
||||
'embedding_count': embedding_count,
|
||||
'total_size': lora_size + checkpoint_size + embedding_size,
|
||||
'lora_size': lora_size,
|
||||
'checkpoint_size': checkpoint_size,
|
||||
'embedding_size': embedding_size,
|
||||
'total_generations': usage_data.get('total_executions', 0),
|
||||
'unused_loras': self._count_unused_models(lora_cache.raw_data, usage_data.get('loras', {})),
|
||||
'unused_checkpoints': self._count_unused_models(checkpoint_cache.raw_data, usage_data.get('checkpoints', {})),
|
||||
'unused_embeddings': self._count_unused_models(embedding_cache.raw_data, usage_data.get('embeddings', {}))
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting collection overview: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_usage_analytics(self, request: web.Request) -> web.Response:
|
||||
"""Get usage analytics data"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get usage statistics
|
||||
usage_data = await self.usage_stats.get_stats()
|
||||
|
||||
# Get model data for enrichment
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
# Create hash to model mapping
|
||||
lora_map = {lora['sha256']: lora for lora in lora_cache.raw_data}
|
||||
checkpoint_map = {cp['sha256']: cp for cp in checkpoint_cache.raw_data}
|
||||
embedding_map = {emb['sha256']: emb for emb in embedding_cache.raw_data}
|
||||
|
||||
# Prepare top used models
|
||||
top_loras = self._get_top_used_models(usage_data.get('loras', {}), lora_map, 10)
|
||||
top_checkpoints = self._get_top_used_models(usage_data.get('checkpoints', {}), checkpoint_map, 10)
|
||||
top_embeddings = self._get_top_used_models(usage_data.get('embeddings', {}), embedding_map, 10)
|
||||
|
||||
# Prepare usage timeline (last 30 days)
|
||||
timeline = self._get_usage_timeline(usage_data, 30)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'top_loras': top_loras,
|
||||
'top_checkpoints': top_checkpoints,
|
||||
'top_embeddings': top_embeddings,
|
||||
'usage_timeline': timeline,
|
||||
'total_executions': usage_data.get('total_executions', 0)
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage analytics: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_base_model_distribution(self, request: web.Request) -> web.Response:
|
||||
"""Get base model distribution statistics"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get model data
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
# Count by base model
|
||||
lora_base_models = Counter(lora.get('base_model', 'Unknown') for lora in lora_cache.raw_data)
|
||||
checkpoint_base_models = Counter(cp.get('base_model', 'Unknown') for cp in checkpoint_cache.raw_data)
|
||||
embedding_base_models = Counter(emb.get('base_model', 'Unknown') for emb in embedding_cache.raw_data)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'loras': dict(lora_base_models),
|
||||
'checkpoints': dict(checkpoint_base_models),
|
||||
'embeddings': dict(embedding_base_models)
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting base model distribution: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_tag_analytics(self, request: web.Request) -> web.Response:
|
||||
"""Get tag usage analytics"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get model data
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
# Count tag frequencies
|
||||
all_tags = []
|
||||
for lora in lora_cache.raw_data:
|
||||
all_tags.extend(lora.get('tags', []))
|
||||
for cp in checkpoint_cache.raw_data:
|
||||
all_tags.extend(cp.get('tags', []))
|
||||
for emb in embedding_cache.raw_data:
|
||||
all_tags.extend(emb.get('tags', []))
|
||||
|
||||
tag_counts = Counter(all_tags)
|
||||
|
||||
# Get top 50 tags
|
||||
top_tags = [{'tag': tag, 'count': count} for tag, count in tag_counts.most_common(50)]
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'top_tags': top_tags,
|
||||
'total_unique_tags': len(tag_counts)
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting tag analytics: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_storage_analytics(self, request: web.Request) -> web.Response:
|
||||
"""Get storage usage analytics"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get usage statistics
|
||||
usage_data = await self.usage_stats.get_stats()
|
||||
|
||||
# Get model data
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
# Create models with usage data
|
||||
lora_storage = []
|
||||
for lora in lora_cache.raw_data:
|
||||
usage_count = 0
|
||||
if lora['sha256'] in usage_data.get('loras', {}):
|
||||
usage_count = usage_data['loras'][lora['sha256']].get('total', 0)
|
||||
|
||||
lora_storage.append({
|
||||
'name': lora['model_name'],
|
||||
'size': lora.get('size', 0),
|
||||
'usage_count': usage_count,
|
||||
'folder': lora.get('folder', ''),
|
||||
'base_model': lora.get('base_model', 'Unknown')
|
||||
})
|
||||
|
||||
checkpoint_storage = []
|
||||
for cp in checkpoint_cache.raw_data:
|
||||
usage_count = 0
|
||||
if cp['sha256'] in usage_data.get('checkpoints', {}):
|
||||
usage_count = usage_data['checkpoints'][cp['sha256']].get('total', 0)
|
||||
|
||||
checkpoint_storage.append({
|
||||
'name': cp['model_name'],
|
||||
'size': cp.get('size', 0),
|
||||
'usage_count': usage_count,
|
||||
'folder': cp.get('folder', ''),
|
||||
'base_model': cp.get('base_model', 'Unknown')
|
||||
})
|
||||
|
||||
embedding_storage = []
|
||||
for emb in embedding_cache.raw_data:
|
||||
usage_count = 0
|
||||
if emb['sha256'] in usage_data.get('embeddings', {}):
|
||||
usage_count = usage_data['embeddings'][emb['sha256']].get('total', 0)
|
||||
|
||||
embedding_storage.append({
|
||||
'name': emb['model_name'],
|
||||
'size': emb.get('size', 0),
|
||||
'usage_count': usage_count,
|
||||
'folder': emb.get('folder', ''),
|
||||
'base_model': emb.get('base_model', 'Unknown')
|
||||
})
|
||||
|
||||
# Sort by size
|
||||
lora_storage.sort(key=lambda x: x['size'], reverse=True)
|
||||
checkpoint_storage.sort(key=lambda x: x['size'], reverse=True)
|
||||
embedding_storage.sort(key=lambda x: x['size'], reverse=True)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'loras': lora_storage[:20], # Top 20 by size
|
||||
'checkpoints': checkpoint_storage[:20],
|
||||
'embeddings': embedding_storage[:20]
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting storage analytics: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_insights(self, request: web.Request) -> web.Response:
|
||||
"""Get smart insights about the collection"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get usage statistics
|
||||
usage_data = await self.usage_stats.get_stats()
|
||||
|
||||
# Get model data
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
insights = []
|
||||
|
||||
# Calculate unused models
|
||||
unused_loras = self._count_unused_models(lora_cache.raw_data, usage_data.get('loras', {}))
|
||||
unused_checkpoints = self._count_unused_models(checkpoint_cache.raw_data, usage_data.get('checkpoints', {}))
|
||||
unused_embeddings = self._count_unused_models(embedding_cache.raw_data, usage_data.get('embeddings', {}))
|
||||
|
||||
total_loras = len(lora_cache.raw_data)
|
||||
total_checkpoints = len(checkpoint_cache.raw_data)
|
||||
total_embeddings = len(embedding_cache.raw_data)
|
||||
|
||||
if total_loras > 0:
|
||||
unused_lora_percent = (unused_loras / total_loras) * 100
|
||||
if unused_lora_percent > 50:
|
||||
insights.append({
|
||||
'type': 'warning',
|
||||
'title': 'High Number of Unused LoRAs',
|
||||
'description': f'{unused_lora_percent:.1f}% of your LoRAs ({unused_loras}/{total_loras}) have never been used.',
|
||||
'suggestion': 'Consider organizing or archiving unused models to free up storage space.'
|
||||
})
|
||||
|
||||
if total_checkpoints > 0:
|
||||
unused_checkpoint_percent = (unused_checkpoints / total_checkpoints) * 100
|
||||
if unused_checkpoint_percent > 30:
|
||||
insights.append({
|
||||
'type': 'warning',
|
||||
'title': 'Unused Checkpoints Detected',
|
||||
'description': f'{unused_checkpoint_percent:.1f}% of your checkpoints ({unused_checkpoints}/{total_checkpoints}) have never been used.',
|
||||
'suggestion': 'Review and consider removing checkpoints you no longer need.'
|
||||
})
|
||||
|
||||
if total_embeddings > 0:
|
||||
unused_embedding_percent = (unused_embeddings / total_embeddings) * 100
|
||||
if unused_embedding_percent > 50:
|
||||
insights.append({
|
||||
'type': 'warning',
|
||||
'title': 'High Number of Unused Embeddings',
|
||||
'description': f'{unused_embedding_percent:.1f}% of your embeddings ({unused_embeddings}/{total_embeddings}) have never been used.',
|
||||
'suggestion': 'Consider organizing or archiving unused embeddings to optimize your collection.'
|
||||
})
|
||||
|
||||
# Storage insights
|
||||
total_size = sum(lora.get('size', 0) for lora in lora_cache.raw_data) + \
|
||||
sum(cp.get('size', 0) for cp in checkpoint_cache.raw_data) + \
|
||||
sum(emb.get('size', 0) for emb in embedding_cache.raw_data)
|
||||
|
||||
if total_size > 100 * 1024 * 1024 * 1024: # 100GB
|
||||
insights.append({
|
||||
'type': 'info',
|
||||
'title': 'Large Collection Detected',
|
||||
'description': f'Your model collection is using {self._format_size(total_size)} of storage.',
|
||||
'suggestion': 'Consider using external storage or cloud solutions for better organization.'
|
||||
})
|
||||
|
||||
# Recent activity insight
|
||||
if usage_data.get('total_executions', 0) > 100:
|
||||
insights.append({
|
||||
'type': 'success',
|
||||
'title': 'Active User',
|
||||
'description': f'You\'ve completed {usage_data["total_executions"]} generations so far!',
|
||||
'suggestion': 'Keep exploring and creating amazing content with your models.'
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'insights': insights
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting insights: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
def _count_unused_models(self, models: List[Dict], usage_data: Dict) -> int:
|
||||
"""Count models that have never been used"""
|
||||
used_hashes = set(usage_data.keys())
|
||||
unused_count = 0
|
||||
|
||||
for model in models:
|
||||
if model.get('sha256') not in used_hashes:
|
||||
unused_count += 1
|
||||
|
||||
return unused_count
|
||||
|
||||
def _get_top_used_models(self, usage_data: Dict, model_map: Dict, limit: int) -> List[Dict]:
|
||||
"""Get top used models with their metadata"""
|
||||
sorted_usage = sorted(usage_data.items(), key=lambda x: x[1].get('total', 0), reverse=True)
|
||||
|
||||
top_models = []
|
||||
for sha256, usage_info in sorted_usage[:limit]:
|
||||
if sha256 in model_map:
|
||||
model = model_map[sha256]
|
||||
top_models.append({
|
||||
'name': model['model_name'],
|
||||
'usage_count': usage_info.get('total', 0),
|
||||
'base_model': model.get('base_model', 'Unknown'),
|
||||
'preview_url': config.get_preview_static_url(model.get('preview_url', '')),
|
||||
'folder': model.get('folder', '')
|
||||
})
|
||||
|
||||
return top_models
|
||||
|
||||
def _get_usage_timeline(self, usage_data: Dict, days: int) -> List[Dict]:
|
||||
"""Get usage timeline for the past N days"""
|
||||
timeline = []
|
||||
today = datetime.now()
|
||||
|
||||
for i in range(days):
|
||||
date = today - timedelta(days=i)
|
||||
date_str = date.strftime('%Y-%m-%d')
|
||||
|
||||
lora_usage = 0
|
||||
checkpoint_usage = 0
|
||||
embedding_usage = 0
|
||||
|
||||
# Count usage for this date
|
||||
for model_usage in usage_data.get('loras', {}).values():
|
||||
if isinstance(model_usage, dict) and 'history' in model_usage:
|
||||
lora_usage += model_usage['history'].get(date_str, 0)
|
||||
|
||||
for model_usage in usage_data.get('checkpoints', {}).values():
|
||||
if isinstance(model_usage, dict) and 'history' in model_usage:
|
||||
checkpoint_usage += model_usage['history'].get(date_str, 0)
|
||||
|
||||
for model_usage in usage_data.get('embeddings', {}).values():
|
||||
if isinstance(model_usage, dict) and 'history' in model_usage:
|
||||
embedding_usage += model_usage['history'].get(date_str, 0)
|
||||
|
||||
timeline.append({
|
||||
'date': date_str,
|
||||
'lora_usage': lora_usage,
|
||||
'checkpoint_usage': checkpoint_usage,
|
||||
'embedding_usage': embedding_usage,
|
||||
'total_usage': lora_usage + checkpoint_usage + embedding_usage
|
||||
})
|
||||
|
||||
return list(reversed(timeline)) # Oldest to newest
|
||||
|
||||
def _format_size(self, size_bytes: int) -> str:
|
||||
"""Format file size in human readable format"""
|
||||
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
|
||||
if size_bytes < 1024.0:
|
||||
return f"{size_bytes:.1f} {unit}"
|
||||
size_bytes /= 1024.0
|
||||
return f"{size_bytes:.1f} PB"
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Register routes with the application"""
|
||||
# Add an app startup handler to initialize services
|
||||
app.on_startup.append(self._on_startup)
|
||||
|
||||
# Register page route
|
||||
app.router.add_get('/statistics', self.handle_stats_page)
|
||||
|
||||
# Register API routes
|
||||
app.router.add_get('/api/lm/stats/collection-overview', self.get_collection_overview)
|
||||
app.router.add_get('/api/lm/stats/usage-analytics', self.get_usage_analytics)
|
||||
app.router.add_get('/api/lm/stats/base-model-distribution', self.get_base_model_distribution)
|
||||
app.router.add_get('/api/lm/stats/tag-analytics', self.get_tag_analytics)
|
||||
app.router.add_get('/api/lm/stats/storage-analytics', self.get_storage_analytics)
|
||||
app.router.add_get('/api/lm/stats/insights', self.get_insights)
|
||||
|
||||
async def _on_startup(self, app):
|
||||
"""Initialize services when the app starts"""
|
||||
await self.init_services()
|
||||
@@ -1,19 +1,31 @@
|
||||
import os
|
||||
import aiohttp
|
||||
import logging
|
||||
import toml
|
||||
from aiohttp import web
|
||||
from typing import Dict, Any, List
|
||||
import git
|
||||
import zipfile
|
||||
import shutil
|
||||
import tempfile
|
||||
import asyncio
|
||||
from aiohttp import web, ClientError
|
||||
from typing import Dict, List
|
||||
|
||||
from ..utils.settings_paths import ensure_settings_file
|
||||
from ..services.downloader import get_downloader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NETWORK_EXCEPTIONS = (ClientError, OSError, asyncio.TimeoutError)
|
||||
|
||||
|
||||
class UpdateRoutes:
|
||||
"""Routes for handling plugin update checks"""
|
||||
|
||||
@staticmethod
|
||||
def setup_routes(app):
|
||||
"""Register update check routes"""
|
||||
app.router.add_get('/loras/api/check-updates', UpdateRoutes.check_updates)
|
||||
app.router.add_get('/api/lm/check-updates', UpdateRoutes.check_updates)
|
||||
app.router.add_get('/api/lm/version-info', UpdateRoutes.get_version_info)
|
||||
app.router.add_post('/api/lm/perform-update', UpdateRoutes.perform_update)
|
||||
|
||||
@staticmethod
|
||||
async def check_updates(request):
|
||||
@@ -22,32 +34,361 @@ class UpdateRoutes:
|
||||
Returns update status and version information
|
||||
"""
|
||||
try:
|
||||
nightly = request.query.get('nightly', 'false').lower() == 'true'
|
||||
|
||||
# Read local version from pyproject.toml
|
||||
local_version = UpdateRoutes._get_local_version()
|
||||
|
||||
# Get git info (commit hash, branch)
|
||||
git_info = UpdateRoutes._get_git_info()
|
||||
|
||||
# Fetch remote version from GitHub
|
||||
remote_version, changelog = await UpdateRoutes._get_remote_version()
|
||||
if nightly:
|
||||
remote_version, changelog = await UpdateRoutes._get_nightly_version()
|
||||
else:
|
||||
remote_version, changelog = await UpdateRoutes._get_remote_version()
|
||||
|
||||
# Compare versions
|
||||
update_available = UpdateRoutes._compare_versions(
|
||||
local_version.replace('v', ''),
|
||||
remote_version.replace('v', '')
|
||||
)
|
||||
if nightly:
|
||||
# For nightly, compare commit hashes
|
||||
update_available = UpdateRoutes._compare_nightly_versions(git_info, remote_version)
|
||||
else:
|
||||
# For stable, compare semantic versions
|
||||
update_available = UpdateRoutes._compare_versions(
|
||||
local_version.replace('v', ''),
|
||||
remote_version.replace('v', '')
|
||||
)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'current_version': local_version,
|
||||
'latest_version': remote_version,
|
||||
'update_available': update_available,
|
||||
'changelog': changelog
|
||||
'changelog': changelog,
|
||||
'git_info': git_info,
|
||||
'nightly': nightly
|
||||
})
|
||||
|
||||
except NETWORK_EXCEPTIONS as e:
|
||||
logger.warning("Network unavailable during update check: %s", e)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Network unavailable for update check'
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check for updates: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def get_version_info(request):
|
||||
"""
|
||||
Returns the current version in the format 'version-short_hash'
|
||||
"""
|
||||
try:
|
||||
# Read local version from pyproject.toml
|
||||
local_version = UpdateRoutes._get_local_version().replace('v', '')
|
||||
|
||||
# Get git info (commit hash, branch)
|
||||
git_info = UpdateRoutes._get_git_info()
|
||||
short_hash = git_info['short_hash']
|
||||
|
||||
# Format: version-short_hash
|
||||
version_string = f"{local_version}-{short_hash}"
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'version': version_string
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get version info: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def perform_update(request):
|
||||
"""
|
||||
Perform Git-based update to latest release tag or main branch.
|
||||
If .git is missing, fallback to ZIP download.
|
||||
"""
|
||||
try:
|
||||
body = await request.json() if request.has_body else {}
|
||||
nightly = body.get('nightly', False)
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
plugin_root = os.path.dirname(os.path.dirname(current_dir))
|
||||
|
||||
settings_path = ensure_settings_file(logger)
|
||||
settings_backup = None
|
||||
if os.path.exists(settings_path):
|
||||
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||
settings_backup = f.read()
|
||||
logger.info("Backed up settings.json")
|
||||
|
||||
git_folder = os.path.join(plugin_root, '.git')
|
||||
if os.path.exists(git_folder):
|
||||
# Git update
|
||||
success, new_version = await UpdateRoutes._perform_git_update(plugin_root, nightly)
|
||||
else:
|
||||
# Fallback: Download ZIP and replace files
|
||||
success, new_version = await UpdateRoutes._download_and_replace_zip(plugin_root)
|
||||
|
||||
if settings_backup and success:
|
||||
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||
f.write(settings_backup)
|
||||
logger.info("Restored settings.json")
|
||||
|
||||
if success:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f'Successfully updated to {new_version}',
|
||||
'new_version': new_version
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Failed to complete update'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to perform update: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def _download_and_replace_zip(plugin_root: str) -> tuple[bool, str]:
|
||||
"""
|
||||
Download latest release ZIP from GitHub and replace plugin files.
|
||||
Skips settings.json and civitai folder. Writes extracted file list to .tracking.
|
||||
"""
|
||||
repo_owner = "willmiao"
|
||||
repo_name = "ComfyUI-Lora-Manager"
|
||||
github_api = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
|
||||
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
|
||||
# Get release info
|
||||
success, data = await downloader.make_request(
|
||||
'GET',
|
||||
github_api,
|
||||
use_auth=False
|
||||
)
|
||||
if not success:
|
||||
logger.error(f"Failed to fetch release info: {data}")
|
||||
return False, ""
|
||||
|
||||
zip_url = data.get("zipball_url")
|
||||
version = data.get("tag_name", "unknown")
|
||||
|
||||
# Download ZIP to temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
|
||||
tmp_zip_path = tmp_zip.name
|
||||
|
||||
success, result = await downloader.download_file(
|
||||
url=zip_url,
|
||||
save_path=tmp_zip_path,
|
||||
use_auth=False,
|
||||
allow_resume=False
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error(f"Failed to download ZIP: {result}")
|
||||
return False, ""
|
||||
|
||||
zip_path = tmp_zip_path
|
||||
|
||||
# Skip both settings.json, civitai and model cache folder
|
||||
UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json', 'civitai', 'model_cache'])
|
||||
|
||||
# Extract ZIP to temp dir
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(tmp_dir)
|
||||
# Find extracted folder (GitHub ZIP contains a root folder)
|
||||
extracted_root = next(os.scandir(tmp_dir)).path
|
||||
|
||||
# Copy files, skipping settings.json and civitai folder
|
||||
for item in os.listdir(extracted_root):
|
||||
if item == 'settings.json' or item == 'civitai':
|
||||
continue
|
||||
src = os.path.join(extracted_root, item)
|
||||
dst = os.path.join(plugin_root, item)
|
||||
if os.path.isdir(src):
|
||||
if os.path.exists(dst):
|
||||
shutil.rmtree(dst)
|
||||
shutil.copytree(src, dst, ignore=shutil.ignore_patterns('settings.json', 'civitai'))
|
||||
else:
|
||||
shutil.copy2(src, dst)
|
||||
|
||||
# Write .tracking file: list all files under extracted_root, relative to extracted_root
|
||||
# for ComfyUI Manager to work properly
|
||||
tracking_info_file = os.path.join(plugin_root, '.tracking')
|
||||
tracking_files = []
|
||||
for root, dirs, files in os.walk(extracted_root):
|
||||
# Skip civitai folder and its contents
|
||||
rel_root = os.path.relpath(root, extracted_root)
|
||||
if rel_root == 'civitai' or rel_root.startswith('civitai' + os.sep):
|
||||
continue
|
||||
for file in files:
|
||||
rel_path = os.path.relpath(os.path.join(root, file), extracted_root)
|
||||
# Skip settings.json and any file under civitai
|
||||
if rel_path == 'settings.json' or rel_path.startswith('civitai' + os.sep):
|
||||
continue
|
||||
tracking_files.append(rel_path.replace("\\", "/"))
|
||||
with open(tracking_info_file, "w", encoding='utf-8') as file:
|
||||
file.write('\n'.join(tracking_files))
|
||||
|
||||
os.remove(zip_path)
|
||||
logger.info(f"Updated plugin via ZIP to {version}")
|
||||
return True, version
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ZIP update failed: {e}", exc_info=True)
|
||||
return False, ""
|
||||
|
||||
def _clean_plugin_folder(plugin_root, skip_files=None):
|
||||
skip_files = skip_files or []
|
||||
for item in os.listdir(plugin_root):
|
||||
if item in skip_files:
|
||||
continue
|
||||
path = os.path.join(plugin_root, item)
|
||||
if os.path.isdir(path):
|
||||
shutil.rmtree(path)
|
||||
else:
|
||||
os.remove(path)
|
||||
|
||||
@staticmethod
|
||||
async def _get_nightly_version() -> tuple[str, List[str]]:
|
||||
"""
|
||||
Fetch latest commit from main branch
|
||||
"""
|
||||
repo_owner = "willmiao"
|
||||
repo_name = "ComfyUI-Lora-Manager"
|
||||
|
||||
# Use GitHub API to fetch the latest commit from main branch
|
||||
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
|
||||
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
|
||||
|
||||
if not success:
|
||||
logger.warning(f"Failed to fetch GitHub commit: {data}")
|
||||
return "main", []
|
||||
|
||||
commit_sha = data.get('sha', '')[:7] # Short hash
|
||||
commit_message = data.get('commit', {}).get('message', '')
|
||||
|
||||
# Format as "main-{short_hash}"
|
||||
version = f"main-{commit_sha}"
|
||||
|
||||
# Use commit message as changelog
|
||||
changelog = [commit_message] if commit_message else []
|
||||
|
||||
return version, changelog
|
||||
|
||||
except NETWORK_EXCEPTIONS as e:
|
||||
logger.warning("Unable to reach GitHub for nightly version: %s", e)
|
||||
return "main", []
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching nightly version: {e}", exc_info=True)
|
||||
return "main", []
|
||||
|
||||
@staticmethod
|
||||
def _compare_nightly_versions(local_git_info: Dict[str, str], remote_version: str) -> bool:
|
||||
"""
|
||||
Compare local commit hash with remote main branch
|
||||
"""
|
||||
try:
|
||||
local_hash = local_git_info.get('short_hash', 'unknown')
|
||||
if local_hash == 'unknown':
|
||||
return True # Assume update available if we can't get local hash
|
||||
|
||||
# Extract remote hash from version string (format: "main-{hash}")
|
||||
if '-' in remote_version:
|
||||
remote_hash = remote_version.split('-')[-1]
|
||||
return local_hash != remote_hash
|
||||
|
||||
return True # Default to update available
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error comparing nightly versions: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _perform_git_update(plugin_root: str, nightly: bool = False) -> tuple[bool, str]:
|
||||
"""
|
||||
Perform Git-based update using GitPython
|
||||
|
||||
Args:
|
||||
plugin_root: Path to the plugin root directory
|
||||
nightly: Whether to update to main branch or latest release
|
||||
|
||||
Returns:
|
||||
tuple: (success, new_version)
|
||||
"""
|
||||
try:
|
||||
# Open the Git repository
|
||||
repo = git.Repo(plugin_root)
|
||||
|
||||
# Fetch latest changes
|
||||
origin = repo.remotes.origin
|
||||
origin.fetch()
|
||||
|
||||
if nightly:
|
||||
# Reset to discard any local changes
|
||||
repo.git.reset('--hard')
|
||||
# Clean untracked files
|
||||
repo.git.clean('-fd')
|
||||
|
||||
# Switch to main branch and pull latest
|
||||
main_branch = 'main'
|
||||
if main_branch not in [branch.name for branch in repo.branches]:
|
||||
# Create local main branch if it doesn't exist
|
||||
repo.create_head(main_branch, origin.refs.main)
|
||||
|
||||
repo.heads[main_branch].checkout()
|
||||
origin.pull(main_branch)
|
||||
|
||||
# Get new commit hash
|
||||
new_version = f"main-{repo.head.commit.hexsha[:7]}"
|
||||
|
||||
else:
|
||||
# Reset to discard any local changes
|
||||
repo.git.reset('--hard')
|
||||
# Clean untracked files
|
||||
repo.git.clean('-fd')
|
||||
|
||||
# Get latest release tag
|
||||
tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime, reverse=True)
|
||||
if not tags:
|
||||
logger.error("No tags found in repository")
|
||||
return False, ""
|
||||
|
||||
latest_tag = tags[0]
|
||||
|
||||
# Checkout to latest tag
|
||||
repo.git.checkout(latest_tag.name)
|
||||
|
||||
new_version = latest_tag.name
|
||||
|
||||
logger.info(f"Successfully updated to {new_version}")
|
||||
return True, new_version
|
||||
|
||||
except git.exc.GitError as e:
|
||||
logger.error(f"Git error during update: {e}")
|
||||
return False, ""
|
||||
except Exception as e:
|
||||
logger.error(f"Error during Git update: {e}")
|
||||
return False, ""
|
||||
|
||||
@staticmethod
|
||||
def _get_local_version() -> str:
|
||||
@@ -72,6 +413,35 @@ class UpdateRoutes:
|
||||
logger.error(f"Failed to get local version: {e}", exc_info=True)
|
||||
return "v0.0.0"
|
||||
|
||||
@staticmethod
|
||||
def _get_git_info() -> Dict[str, str]:
|
||||
"""Get Git repository information"""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
plugin_root = os.path.dirname(os.path.dirname(current_dir))
|
||||
|
||||
git_info = {
|
||||
'commit_hash': 'unknown',
|
||||
'short_hash': 'stable',
|
||||
'branch': 'unknown',
|
||||
'commit_date': 'unknown'
|
||||
}
|
||||
|
||||
try:
|
||||
# Check if we're in a git repository
|
||||
if not os.path.exists(os.path.join(plugin_root, '.git')):
|
||||
return git_info
|
||||
|
||||
repo = git.Repo(plugin_root)
|
||||
commit = repo.head.commit
|
||||
git_info['commit_hash'] = commit.hexsha
|
||||
git_info['short_hash'] = commit.hexsha[:7]
|
||||
git_info['branch'] = repo.active_branch.name if not repo.head.is_detached else 'detached'
|
||||
git_info['commit_date'] = commit.committed_datetime.strftime('%Y-%m-%d')
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting git info: {e}")
|
||||
|
||||
return git_info
|
||||
|
||||
@staticmethod
|
||||
async def _get_remote_version() -> tuple[str, List[str]]:
|
||||
"""
|
||||
@@ -86,23 +456,26 @@ class UpdateRoutes:
|
||||
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response:
|
||||
if response.status != 200:
|
||||
logger.warning(f"Failed to fetch GitHub release: {response.status}")
|
||||
return "v0.0.0", []
|
||||
|
||||
data = await response.json()
|
||||
version = data.get('tag_name', '')
|
||||
if not version.startswith('v'):
|
||||
version = f"v{version}"
|
||||
|
||||
# Extract changelog from release notes
|
||||
body = data.get('body', '')
|
||||
changelog = UpdateRoutes._parse_changelog(body)
|
||||
|
||||
return version, changelog
|
||||
downloader = await get_downloader()
|
||||
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
|
||||
|
||||
if not success:
|
||||
logger.warning(f"Failed to fetch GitHub release: {data}")
|
||||
return "v0.0.0", []
|
||||
|
||||
version = data.get('tag_name', '')
|
||||
if not version.startswith('v'):
|
||||
version = f"v{version}"
|
||||
|
||||
# Extract changelog from release notes
|
||||
body = data.get('body', '')
|
||||
changelog = UpdateRoutes._parse_changelog(body)
|
||||
|
||||
return version, changelog
|
||||
|
||||
except NETWORK_EXCEPTIONS as e:
|
||||
logger.warning("Unable to reach GitHub for release info: %s", e)
|
||||
return "v0.0.0", []
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching remote version: {e}", exc_info=True)
|
||||
return "v0.0.0", []
|
||||
@@ -150,11 +523,16 @@ class UpdateRoutes:
|
||||
"""
|
||||
Compare two semantic version strings
|
||||
Returns True if version2 is newer than version1
|
||||
Ignores any suffixes after '-' (e.g., -bugfix, -alpha)
|
||||
"""
|
||||
try:
|
||||
# Clean version strings - remove any suffix after '-'
|
||||
v1_clean = version1.split('-')[0]
|
||||
v2_clean = version2.split('-')[0]
|
||||
|
||||
# Split versions into components
|
||||
v1_parts = [int(x) for x in version1.split('.')]
|
||||
v2_parts = [int(x) for x in version2.split('.')]
|
||||
v1_parts = [int(x) for x in v1_clean.split('.')]
|
||||
v2_parts = [int(x) for x in v2_clean.split('.')]
|
||||
|
||||
# Ensure both have 3 components (major.minor.patch)
|
||||
while len(v1_parts) < 3:
|
||||
|
||||
736
py/services/base_model_service.py
Normal file
736
py/services/base_model_service.py
Normal file
@@ -0,0 +1,736 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
|
||||
import logging
|
||||
import os
|
||||
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
from ..utils.models import BaseModelMetadata
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .model_query import (
|
||||
FilterCriteria,
|
||||
ModelCacheRepository,
|
||||
ModelFilterSet,
|
||||
SearchStrategy,
|
||||
SettingsProvider,
|
||||
normalize_civitai_model_type,
|
||||
resolve_civitai_model_type,
|
||||
)
|
||||
from .settings_manager import get_settings_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model_update_service import ModelUpdateService
|
||||
|
||||
class BaseModelService(ABC):
|
||||
"""Base service class for all model types"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type: str,
|
||||
scanner,
|
||||
metadata_class: Type[BaseModelMetadata],
|
||||
*,
|
||||
cache_repository: Optional[ModelCacheRepository] = None,
|
||||
filter_set: Optional[ModelFilterSet] = None,
|
||||
search_strategy: Optional[SearchStrategy] = None,
|
||||
settings_provider: Optional[SettingsProvider] = None,
|
||||
update_service: Optional["ModelUpdateService"] = None,
|
||||
):
|
||||
"""Initialize the service.
|
||||
|
||||
Args:
|
||||
model_type: Type of model (lora, checkpoint, etc.).
|
||||
scanner: Model scanner instance.
|
||||
metadata_class: Metadata class for this model type.
|
||||
cache_repository: Custom repository for cache access (primarily for tests).
|
||||
filter_set: Filter component controlling folder/tag/favorites logic.
|
||||
search_strategy: Search component for fuzzy/text matching.
|
||||
settings_provider: Settings object; defaults to the global settings manager.
|
||||
update_service: Service used to determine whether models have remote updates available.
|
||||
"""
|
||||
self.model_type = model_type
|
||||
self.scanner = scanner
|
||||
self.metadata_class = metadata_class
|
||||
self.settings = settings_provider or get_settings_manager()
|
||||
self.cache_repository = cache_repository or ModelCacheRepository(scanner)
|
||||
self.filter_set = filter_set or ModelFilterSet(self.settings)
|
||||
self.search_strategy = search_strategy or SearchStrategy()
|
||||
self.update_service = update_service
|
||||
|
||||
async def get_paginated_data(
|
||||
self,
|
||||
page: int,
|
||||
page_size: int,
|
||||
sort_by: str = 'name',
|
||||
folder: str = None,
|
||||
search: str = None,
|
||||
fuzzy_search: bool = False,
|
||||
base_models: list = None,
|
||||
model_types: list = None,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
search_options: dict = None,
|
||||
hash_filters: dict = None,
|
||||
favorites_only: bool = False,
|
||||
update_available_only: bool = False,
|
||||
credit_required: Optional[bool] = None,
|
||||
allow_selling_generated_content: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
"""Get paginated and filtered model data"""
|
||||
|
||||
sort_params = self.cache_repository.parse_sort(sort_by)
|
||||
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
|
||||
|
||||
if hash_filters:
|
||||
filtered_data = await self._apply_hash_filters(sorted_data, hash_filters)
|
||||
else:
|
||||
filtered_data = await self._apply_common_filters(
|
||||
sorted_data,
|
||||
folder=folder,
|
||||
base_models=base_models,
|
||||
model_types=model_types,
|
||||
tags=tags,
|
||||
favorites_only=favorites_only,
|
||||
search_options=search_options,
|
||||
)
|
||||
|
||||
if search:
|
||||
filtered_data = await self._apply_search_filters(
|
||||
filtered_data,
|
||||
search,
|
||||
fuzzy_search,
|
||||
search_options,
|
||||
)
|
||||
|
||||
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
||||
|
||||
# Apply license-based filters
|
||||
if credit_required is not None:
|
||||
filtered_data = await self._apply_credit_required_filter(filtered_data, credit_required)
|
||||
|
||||
if allow_selling_generated_content is not None:
|
||||
filtered_data = await self._apply_allow_selling_filter(filtered_data, allow_selling_generated_content)
|
||||
|
||||
annotated_for_filter: Optional[List[Dict]] = None
|
||||
if update_available_only:
|
||||
annotated_for_filter = await self._annotate_update_flags(filtered_data)
|
||||
filtered_data = [
|
||||
item for item in annotated_for_filter
|
||||
if item.get('update_available')
|
||||
]
|
||||
|
||||
paginated = self._paginate(filtered_data, page, page_size)
|
||||
|
||||
if update_available_only:
|
||||
# Items already include update flags thanks to the pre-filter annotation.
|
||||
paginated['items'] = list(paginated['items'])
|
||||
else:
|
||||
paginated['items'] = await self._annotate_update_flags(
|
||||
paginated['items'],
|
||||
)
|
||||
return paginated
|
||||
|
||||
|
||||
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
|
||||
"""Apply hash-based filtering"""
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower()
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('sha256', '').lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes)
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
return data
|
||||
|
||||
async def _apply_common_filters(
|
||||
self,
|
||||
data: List[Dict],
|
||||
folder: str = None,
|
||||
base_models: list = None,
|
||||
model_types: list = None,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
favorites_only: bool = False,
|
||||
search_options: dict = None,
|
||||
) -> List[Dict]:
|
||||
"""Apply common filters that work across all model types"""
|
||||
normalized_options = self.search_strategy.normalize_options(search_options)
|
||||
criteria = FilterCriteria(
|
||||
folder=folder,
|
||||
base_models=base_models,
|
||||
model_types=model_types,
|
||||
tags=tags,
|
||||
favorites_only=favorites_only,
|
||||
search_options=normalized_options,
|
||||
)
|
||||
return self.filter_set.apply(data, criteria)
|
||||
|
||||
async def _apply_search_filters(
|
||||
self,
|
||||
data: List[Dict],
|
||||
search: str,
|
||||
fuzzy_search: bool,
|
||||
search_options: dict,
|
||||
) -> List[Dict]:
|
||||
"""Apply search filtering"""
|
||||
normalized_options = self.search_strategy.normalize_options(search_options)
|
||||
return self.search_strategy.apply(data, search, normalized_options, fuzzy_search)
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
||||
return data
|
||||
|
||||
async def _apply_credit_required_filter(self, data: List[Dict], credit_required: bool) -> List[Dict]:
|
||||
"""Apply credit required filtering based on license_flags.
|
||||
|
||||
Args:
|
||||
data: List of model data items
|
||||
credit_required:
|
||||
- True: Return items where credit is required (allowNoCredit=False)
|
||||
- False: Return items where credit is not required (allowNoCredit=True)
|
||||
"""
|
||||
filtered_data = []
|
||||
for item in data:
|
||||
license_flags = item.get("license_flags", 127) # Default to all permissions enabled
|
||||
|
||||
# Bit 0 represents allowNoCredit (1 = no credit required, 0 = credit required)
|
||||
allow_no_credit = bool(license_flags & (1 << 0))
|
||||
|
||||
# If credit_required is True, we want items where allowNoCredit is False (credit required)
|
||||
# If credit_required is False, we want items where allowNoCredit is True (no credit required)
|
||||
if credit_required:
|
||||
if not allow_no_credit: # Credit is required
|
||||
filtered_data.append(item)
|
||||
else:
|
||||
if allow_no_credit: # Credit is not required
|
||||
filtered_data.append(item)
|
||||
|
||||
return filtered_data
|
||||
|
||||
async def _apply_allow_selling_filter(self, data: List[Dict], allow_selling: bool) -> List[Dict]:
|
||||
"""Apply allow selling generated content filtering based on license_flags.
|
||||
|
||||
Args:
|
||||
data: List of model data items
|
||||
allow_selling:
|
||||
- True: Return items where selling generated content is allowed (allowCommercialUse contains Image)
|
||||
- False: Return items where selling generated content is not allowed (allowCommercialUse does not contain Image)
|
||||
"""
|
||||
filtered_data = []
|
||||
for item in data:
|
||||
license_flags = item.get("license_flags", 127) # Default to all permissions enabled
|
||||
|
||||
# Bits 1-4 represent commercial use permissions
|
||||
# Bit 1 specifically represents Image permission (allowCommercialUse contains Image)
|
||||
has_image_permission = bool(license_flags & (1 << 1))
|
||||
|
||||
# If allow_selling is True, we want items where Image permission is granted
|
||||
# If allow_selling is False, we want items where Image permission is not granted
|
||||
if allow_selling:
|
||||
if has_image_permission: # Selling generated content is allowed
|
||||
filtered_data.append(item)
|
||||
else:
|
||||
if not has_image_permission: # Selling generated content is not allowed
|
||||
filtered_data.append(item)
|
||||
|
||||
return filtered_data
|
||||
|
||||
async def _annotate_update_flags(
|
||||
self,
|
||||
items: List[Dict],
|
||||
) -> List[Dict]:
|
||||
"""Attach an update_available flag to each response item.
|
||||
|
||||
Items without a civitai model id default to False.
|
||||
"""
|
||||
if not items:
|
||||
return []
|
||||
|
||||
annotated = [dict(item) for item in items]
|
||||
|
||||
if self.update_service is None:
|
||||
for item in annotated:
|
||||
item['update_available'] = False
|
||||
return annotated
|
||||
|
||||
id_to_items: Dict[int, List[Dict]] = {}
|
||||
ordered_ids: List[int] = []
|
||||
for item in annotated:
|
||||
model_id = self._extract_model_id(item)
|
||||
if model_id is None:
|
||||
item['update_available'] = False
|
||||
continue
|
||||
if model_id not in id_to_items:
|
||||
id_to_items[model_id] = []
|
||||
ordered_ids.append(model_id)
|
||||
id_to_items[model_id].append(item)
|
||||
|
||||
if not ordered_ids:
|
||||
return annotated
|
||||
|
||||
strategy_value = self.settings.get("update_flag_strategy")
|
||||
if isinstance(strategy_value, str) and strategy_value.strip():
|
||||
strategy = strategy_value.strip().lower()
|
||||
else:
|
||||
strategy = "same_base"
|
||||
same_base_mode = strategy == "same_base"
|
||||
|
||||
records = None
|
||||
resolved: Optional[Dict[int, bool]] = None
|
||||
if same_base_mode:
|
||||
record_method = getattr(self.update_service, "get_records_bulk", None)
|
||||
if callable(record_method):
|
||||
try:
|
||||
records = await record_method(self.model_type, ordered_ids)
|
||||
resolved = {
|
||||
model_id: record.has_update()
|
||||
for model_id, record in records.items()
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to resolve update records in bulk for %s models (%s): %s",
|
||||
self.model_type,
|
||||
ordered_ids,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
records = None
|
||||
resolved = None
|
||||
|
||||
if resolved is None:
|
||||
bulk_method = getattr(self.update_service, "has_updates_bulk", None)
|
||||
if callable(bulk_method):
|
||||
try:
|
||||
resolved = await bulk_method(self.model_type, ordered_ids)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to resolve update status in bulk for %s models (%s): %s",
|
||||
self.model_type,
|
||||
ordered_ids,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
resolved = None
|
||||
|
||||
if resolved is None:
|
||||
tasks = [
|
||||
self.update_service.has_update(self.model_type, model_id)
|
||||
for model_id in ordered_ids
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
resolved = {}
|
||||
for model_id, result in zip(ordered_ids, results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(
|
||||
"Failed to resolve update status for model %s (%s): %s",
|
||||
model_id,
|
||||
self.model_type,
|
||||
result,
|
||||
)
|
||||
continue
|
||||
resolved[model_id] = bool(result)
|
||||
|
||||
for model_id, items_for_id in id_to_items.items():
|
||||
default_flag = bool(resolved.get(model_id, False)) if resolved else False
|
||||
record = records.get(model_id) if records else None
|
||||
base_highest_versions = (
|
||||
self._build_highest_local_versions_by_base(record) if same_base_mode and record else {}
|
||||
)
|
||||
for item in items_for_id:
|
||||
if same_base_mode and record is not None:
|
||||
base_model = self._extract_base_model(item)
|
||||
normalized_base = self._normalize_base_model_name(base_model)
|
||||
threshold_version = base_highest_versions.get(normalized_base) if normalized_base else None
|
||||
if threshold_version is None:
|
||||
threshold_version = self._extract_version_id(item)
|
||||
flag = record.has_update_for_base(
|
||||
threshold_version,
|
||||
base_model,
|
||||
)
|
||||
else:
|
||||
flag = default_flag
|
||||
item['update_available'] = flag
|
||||
|
||||
return annotated
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_id(item: Dict) -> Optional[int]:
|
||||
civitai = item.get('civitai') if isinstance(item, dict) else None
|
||||
if not isinstance(civitai, dict):
|
||||
return None
|
||||
try:
|
||||
value = civitai.get('modelId')
|
||||
if value is None:
|
||||
return None
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_version_id(item: Dict) -> Optional[int]:
|
||||
civitai = item.get('civitai') if isinstance(item, dict) else None
|
||||
if not isinstance(civitai, dict):
|
||||
return None
|
||||
value = civitai.get('id')
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_base_model(item: Dict) -> Optional[str]:
|
||||
value = item.get('base_model')
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
candidate = value.strip()
|
||||
else:
|
||||
try:
|
||||
candidate = str(value).strip()
|
||||
except Exception:
|
||||
return None
|
||||
return candidate if candidate else None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_base_model_name(value: Optional[str]) -> Optional[str]:
|
||||
"""Return a lowercased, trimmed base model name for comparison."""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
candidate = value.strip()
|
||||
else:
|
||||
try:
|
||||
candidate = str(value).strip()
|
||||
except Exception:
|
||||
return None
|
||||
return candidate.lower() if candidate else None
|
||||
|
||||
def _build_highest_local_versions_by_base(self, record) -> Dict[str, int]:
|
||||
"""Return the highest local version id known for each normalized base model."""
|
||||
|
||||
if record is None:
|
||||
return {}
|
||||
|
||||
highest_by_base: Dict[str, int] = {}
|
||||
for version in getattr(record, "versions", []):
|
||||
if not getattr(version, "is_in_library", False):
|
||||
continue
|
||||
normalized_base = self._normalize_base_model_name(getattr(version, "base_model", None))
|
||||
if normalized_base is None:
|
||||
continue
|
||||
version_id = getattr(version, "version_id", None)
|
||||
if version_id is None:
|
||||
continue
|
||||
current_max = highest_by_base.get(normalized_base)
|
||||
if current_max is None or version_id > current_max:
|
||||
highest_by_base[normalized_base] = version_id
|
||||
|
||||
return highest_by_base
|
||||
|
||||
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
|
||||
"""Apply pagination to filtered data"""
|
||||
total_items = len(data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
return {
|
||||
'items': data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def format_response(self, model_data: Dict) -> Dict:
|
||||
"""Format model data for API response - must be implemented by subclasses"""
|
||||
pass
|
||||
|
||||
# Common service methods that delegate to scanner
|
||||
async def get_top_tags(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get top tags sorted by frequency"""
|
||||
return await self.scanner.get_top_tags(limit)
|
||||
|
||||
async def get_base_models(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get base models sorted by frequency"""
|
||||
return await self.scanner.get_base_models(limit)
|
||||
|
||||
async def get_model_types(self, limit: int = 20) -> List[Dict[str, Any]]:
|
||||
"""Get counts of normalized CivitAI model types present in the cache."""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
type_counts: Dict[str, int] = {}
|
||||
for entry in cache.raw_data:
|
||||
normalized_type = normalize_civitai_model_type(resolve_civitai_model_type(entry))
|
||||
if not normalized_type or normalized_type not in VALID_LORA_TYPES:
|
||||
continue
|
||||
type_counts[normalized_type] = type_counts.get(normalized_type, 0) + 1
|
||||
|
||||
sorted_types = sorted(
|
||||
[{"type": model_type, "count": count} for model_type, count in type_counts.items()],
|
||||
key=lambda value: value["count"],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
return sorted_types[:limit]
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
"""Check if a model with given hash exists"""
|
||||
return self.scanner.has_hash(sha256)
|
||||
|
||||
def get_path_by_hash(self, sha256: str) -> Optional[str]:
|
||||
"""Get file path for a model by its hash"""
|
||||
return self.scanner.get_path_by_hash(sha256)
|
||||
|
||||
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a model by its file path"""
|
||||
return self.scanner.get_hash_by_path(file_path)
|
||||
|
||||
async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False):
|
||||
"""Trigger model scanning"""
|
||||
return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache)
|
||||
|
||||
async def get_model_info_by_name(self, name: str):
|
||||
"""Get model information by name"""
|
||||
return await self.scanner.get_model_info_by_name(name)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get model root directories"""
|
||||
return self.scanner.get_model_roots()
|
||||
|
||||
def filter_civitai_data(self, data: Dict, minimal: bool = False) -> Dict:
|
||||
"""Filter relevant fields from CivitAI data"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
fields = ["id", "modelId", "name", "trainedWords"] if minimal else [
|
||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||
"publishedAt", "trainedWords", "baseModel", "description",
|
||||
"model", "images", "customImages", "creator"
|
||||
]
|
||||
return {k: data[k] for k in fields if k in data}
|
||||
|
||||
async def get_folder_tree(self, model_root: str) -> Dict:
|
||||
"""Get hierarchical folder tree for a specific model root"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Build tree structure from folders
|
||||
tree = {}
|
||||
|
||||
for folder in cache.folders:
|
||||
# Check if this folder belongs to the specified model root
|
||||
folder_belongs_to_root = False
|
||||
for root in self.scanner.get_model_roots():
|
||||
if root == model_root:
|
||||
folder_belongs_to_root = True
|
||||
break
|
||||
|
||||
if not folder_belongs_to_root:
|
||||
continue
|
||||
|
||||
# Split folder path into components
|
||||
parts = folder.split('/') if folder else []
|
||||
current_level = tree
|
||||
|
||||
for part in parts:
|
||||
if part not in current_level:
|
||||
current_level[part] = {}
|
||||
current_level = current_level[part]
|
||||
|
||||
return tree
|
||||
|
||||
async def get_unified_folder_tree(self) -> Dict:
|
||||
"""Get unified folder tree across all model roots"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Build unified tree structure by analyzing all relative paths
|
||||
unified_tree = {}
|
||||
|
||||
# Get all model roots for path normalization
|
||||
model_roots = self.scanner.get_model_roots()
|
||||
|
||||
for folder in cache.folders:
|
||||
if not folder: # Skip empty folders
|
||||
continue
|
||||
|
||||
# Find which root this folder belongs to by checking the actual file paths
|
||||
# This is a simplified approach - we'll use the folder as-is since it should already be relative
|
||||
relative_path = folder
|
||||
|
||||
# Split folder path into components
|
||||
parts = relative_path.split('/')
|
||||
current_level = unified_tree
|
||||
|
||||
for part in parts:
|
||||
if part not in current_level:
|
||||
current_level[part] = {}
|
||||
current_level = current_level[part]
|
||||
|
||||
return unified_tree
|
||||
|
||||
async def get_model_notes(self, model_name: str) -> Optional[str]:
|
||||
"""Get notes for a specific model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
return model.get('notes', '')
|
||||
|
||||
return None
|
||||
|
||||
async def get_model_preview_url(self, model_name: str) -> Optional[str]:
|
||||
"""Get the static preview URL for a model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
preview_url = model.get('preview_url')
|
||||
if preview_url:
|
||||
from ..config import config
|
||||
return config.get_preview_static_url(preview_url)
|
||||
|
||||
return '/loras_static/images/no-preview.png'
|
||||
|
||||
async def get_model_civitai_url(self, model_name: str) -> Dict[str, Optional[str]]:
|
||||
"""Get the Civitai URL for a model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
civitai_data = model.get('civitai', {})
|
||||
model_id = civitai_data.get('modelId')
|
||||
version_id = civitai_data.get('id')
|
||||
|
||||
if model_id:
|
||||
civitai_url = f"https://civitai.com/models/{model_id}"
|
||||
if version_id:
|
||||
civitai_url += f"?modelVersionId={version_id}"
|
||||
|
||||
return {
|
||||
'civitai_url': civitai_url,
|
||||
'model_id': str(model_id),
|
||||
'version_id': str(version_id) if version_id else None
|
||||
}
|
||||
|
||||
return {'civitai_url': None, 'model_id': None, 'version_id': None}
|
||||
|
||||
async def get_model_metadata(self, file_path: str) -> Optional[Dict]:
|
||||
"""Load full metadata for a single model.
|
||||
|
||||
Listing/search endpoints return lightweight cache entries; this method performs
|
||||
a lazy read of the on-disk metadata snapshot when callers need full detail.
|
||||
"""
|
||||
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.metadata_class)
|
||||
if should_skip or metadata is None:
|
||||
return None
|
||||
return self.filter_civitai_data(metadata.to_dict().get("civitai", {}))
|
||||
|
||||
|
||||
async def get_model_description(self, file_path: str) -> Optional[str]:
|
||||
"""Return the stored modelDescription field for a model."""
|
||||
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.metadata_class)
|
||||
if should_skip or metadata is None:
|
||||
return None
|
||||
return metadata.modelDescription or ''
|
||||
|
||||
@staticmethod
|
||||
def _parse_search_tokens(search_term: str) -> tuple[List[str], List[str]]:
|
||||
"""Split a search string into include and exclude tokens."""
|
||||
include_terms: List[str] = []
|
||||
exclude_terms: List[str] = []
|
||||
|
||||
for raw_term in search_term.split():
|
||||
term = raw_term.strip()
|
||||
if not term:
|
||||
continue
|
||||
|
||||
if term.startswith("-") and len(term) > 1:
|
||||
exclude_terms.append(term[1:].lower())
|
||||
else:
|
||||
include_terms.append(term.lower())
|
||||
|
||||
return include_terms, exclude_terms
|
||||
|
||||
@staticmethod
|
||||
def _relative_path_matches_tokens(
|
||||
path_lower: str, include_terms: List[str], exclude_terms: List[str]
|
||||
) -> bool:
|
||||
"""Determine whether a relative path string satisfies include/exclude tokens."""
|
||||
if any(term and term in path_lower for term in exclude_terms):
|
||||
return False
|
||||
|
||||
for term in include_terms:
|
||||
if term and term not in path_lower:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _relative_path_sort_key(relative_path: str, include_terms: List[str]) -> tuple:
|
||||
"""Sort paths by how well they satisfy the include tokens."""
|
||||
path_lower = relative_path.lower()
|
||||
prefix_hits = sum(1 for term in include_terms if term and path_lower.startswith(term))
|
||||
match_positions = [path_lower.find(term) for term in include_terms if term and term in path_lower]
|
||||
first_match_index = min(match_positions) if match_positions else 0
|
||||
|
||||
return (-prefix_hits, first_match_index, len(relative_path), path_lower)
|
||||
|
||||
|
||||
async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]:
|
||||
"""Search model relative file paths for autocomplete functionality"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
include_terms, exclude_terms = self._parse_search_tokens(search_term)
|
||||
|
||||
matching_paths = []
|
||||
|
||||
# Get model roots for path calculation
|
||||
model_roots = self.scanner.get_model_roots()
|
||||
|
||||
for model in cache.raw_data:
|
||||
file_path = model.get('file_path', '')
|
||||
if not file_path:
|
||||
continue
|
||||
|
||||
# Calculate relative path from model root
|
||||
relative_path = None
|
||||
for root in model_roots:
|
||||
# Normalize paths for comparison
|
||||
normalized_root = os.path.normpath(root)
|
||||
normalized_file = os.path.normpath(file_path)
|
||||
|
||||
if normalized_file.startswith(normalized_root):
|
||||
# Remove root and leading separator to get relative path
|
||||
relative_path = normalized_file[len(normalized_root):].lstrip(os.sep)
|
||||
break
|
||||
|
||||
if not relative_path:
|
||||
continue
|
||||
|
||||
relative_lower = relative_path.lower()
|
||||
if self._relative_path_matches_tokens(relative_lower, include_terms, exclude_terms):
|
||||
matching_paths.append(relative_path)
|
||||
|
||||
if len(matching_paths) >= limit * 2: # Get more for better sorting
|
||||
break
|
||||
|
||||
# Sort by relevance (prefix and earliest hits first, then by length and alphabetically)
|
||||
matching_paths.sort(
|
||||
key=lambda relative: self._relative_path_sort_key(relative, include_terms)
|
||||
)
|
||||
|
||||
return matching_paths[:limit]
|
||||
@@ -1,131 +1,53 @@
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional, Set
|
||||
import folder_paths # type: ignore
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex
|
||||
from .service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointScanner(ModelScanner):
|
||||
"""Service for scanning and managing checkpoint files"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, '_initialized'):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
|
||||
super().__init__(
|
||||
model_type="checkpoint",
|
||||
model_class=CheckpointMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
self._checkpoint_roots = self._init_checkpoint_roots()
|
||||
self._initialized = True
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'}
|
||||
super().__init__(
|
||||
model_type="checkpoint",
|
||||
model_class=CheckpointMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
|
||||
def _resolve_model_type(self, root_path: Optional[str]) -> Optional[str]:
|
||||
if not root_path:
|
||||
return None
|
||||
|
||||
if config.checkpoints_roots and root_path in config.checkpoints_roots:
|
||||
return "checkpoint"
|
||||
|
||||
if config.unet_roots and root_path in config.unet_roots:
|
||||
return "diffusion_model"
|
||||
|
||||
return None
|
||||
|
||||
def adjust_metadata(self, metadata, file_path, root_path):
|
||||
if hasattr(metadata, "model_type"):
|
||||
model_type = self._resolve_model_type(root_path)
|
||||
if model_type:
|
||||
metadata.model_type = model_type
|
||||
return metadata
|
||||
|
||||
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
model_type = self._resolve_model_type(
|
||||
self._find_root_for_file(entry.get("file_path"))
|
||||
)
|
||||
if model_type:
|
||||
entry["model_type"] = model_type
|
||||
return entry
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def _init_checkpoint_roots(self) -> List[str]:
|
||||
"""Initialize checkpoint roots from ComfyUI settings"""
|
||||
# Get both checkpoint and diffusion_models paths
|
||||
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
diffusion_paths = folder_paths.get_folder_paths("diffusion_models")
|
||||
|
||||
# Combine, normalize and deduplicate paths
|
||||
all_paths = set()
|
||||
for path in checkpoint_paths + diffusion_paths:
|
||||
if os.path.exists(path):
|
||||
norm_path = path.replace(os.sep, "/")
|
||||
all_paths.add(norm_path)
|
||||
|
||||
# Sort for consistent order
|
||||
sorted_paths = sorted(all_paths, key=lambda p: p.lower())
|
||||
|
||||
return sorted_paths
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get checkpoint root directories"""
|
||||
return self._checkpoint_roots
|
||||
|
||||
async def scan_all_models(self) -> List[Dict]:
|
||||
"""Scan all checkpoint directories and return metadata"""
|
||||
all_checkpoints = []
|
||||
|
||||
# Create scan tasks for each directory
|
||||
scan_tasks = []
|
||||
for root in self._checkpoint_roots:
|
||||
task = asyncio.create_task(self._scan_directory(root))
|
||||
scan_tasks.append(task)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for task in scan_tasks:
|
||||
try:
|
||||
checkpoints = await task
|
||||
all_checkpoints.extend(checkpoints)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning checkpoint directory: {e}")
|
||||
|
||||
return all_checkpoints
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Scan a directory for checkpoint files"""
|
||||
checkpoints = []
|
||||
original_root = root_path
|
||||
|
||||
async def scan_recursive(path: str, visited_paths: set):
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
logger.debug(f"Skipping already visited path: {path}")
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True):
|
||||
# Check if file has supported extension
|
||||
ext = os.path.splitext(entry.name)[1].lower()
|
||||
if ext in self.file_extensions:
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
await self._process_single_file(file_path, original_root, checkpoints)
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
# For directories, continue scanning with original path
|
||||
await scan_recursive(entry.path, visited_paths)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
await scan_recursive(root_path, set())
|
||||
return checkpoints
|
||||
|
||||
async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list):
|
||||
"""Process a single checkpoint file and add to results"""
|
||||
try:
|
||||
result = await self._process_model_file(file_path, root_path)
|
||||
if result:
|
||||
checkpoints.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
return config.base_models_roots
|
||||
|
||||
51
py/services/checkpoint_service.py
Normal file
51
py/services/checkpoint_service.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointService(BaseModelService):
|
||||
"""Checkpoint-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner, update_service=None):
|
||||
"""Initialize Checkpoint service
|
||||
|
||||
Args:
|
||||
scanner: Checkpoint scanner instance
|
||||
update_service: Optional service for remote update tracking.
|
||||
"""
|
||||
super().__init__("checkpoint", scanner, CheckpointMetadata, update_service=update_service)
|
||||
|
||||
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
||||
"""Format Checkpoint data for API response"""
|
||||
return {
|
||||
"model_name": checkpoint_data["model_name"],
|
||||
"file_name": checkpoint_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
|
||||
"base_model": checkpoint_data.get("base_model", ""),
|
||||
"folder": checkpoint_data["folder"],
|
||||
"sha256": checkpoint_data.get("sha256", ""),
|
||||
"file_path": checkpoint_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": checkpoint_data.get("size", 0),
|
||||
"modified": checkpoint_data.get("modified", ""),
|
||||
"tags": checkpoint_data.get("tags", []),
|
||||
"from_civitai": checkpoint_data.get("from_civitai", True),
|
||||
"notes": checkpoint_data.get("notes", ""),
|
||||
"model_type": checkpoint_data.get("model_type", "checkpoint"),
|
||||
"favorite": checkpoint_data.get("favorite", False),
|
||||
"update_available": bool(checkpoint_data.get("update_available", False)),
|
||||
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find Checkpoints with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find Checkpoints with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
431
py/services/civarchive_client.py
Normal file
431
py/services/civarchive_client.py
Normal file
@@ -0,0 +1,431 @@
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Dict, Tuple, List
|
||||
from .model_metadata_provider import CivArchiveModelMetadataProvider, ModelMetadataProviderManager
|
||||
from .downloader import get_downloader
|
||||
from .errors import RateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CivArchiveClient:
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance of CivArchiveClient"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
|
||||
# Register this client as a metadata provider
|
||||
provider_manager = await ModelMetadataProviderManager.get_instance()
|
||||
provider_manager.register_provider('civarchive', CivArchiveModelMetadataProvider(cls._instance), False)
|
||||
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
# Check if already initialized for singleton pattern
|
||||
if hasattr(self, '_initialized'):
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
self.base_url = "https://civarchive.com/api"
|
||||
|
||||
async def _request_json(
|
||||
self,
|
||||
path: str,
|
||||
params: Optional[Dict[str, str]] = None
|
||||
) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Call CivArchive API and return JSON payload"""
|
||||
success, payload = await self._make_request(path, params=params)
|
||||
if not success:
|
||||
error = payload if isinstance(payload, str) else "Request failed"
|
||||
return None, error
|
||||
if not isinstance(payload, dict):
|
||||
return None, "Invalid response structure"
|
||||
return payload, None
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
path: str,
|
||||
*,
|
||||
params: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[bool, Dict | str]:
|
||||
"""Wrapper around downloader.make_request that surfaces rate limits."""
|
||||
|
||||
downloader = await get_downloader()
|
||||
kwargs: Dict[str, Dict[str, str]] = {}
|
||||
if params:
|
||||
safe_params = {str(key): str(value) for key, value in params.items() if value is not None}
|
||||
if safe_params:
|
||||
kwargs["params"] = safe_params
|
||||
|
||||
success, payload = await downloader.make_request(
|
||||
"GET",
|
||||
f"{self.base_url}{path}",
|
||||
use_auth=False,
|
||||
**kwargs,
|
||||
)
|
||||
if not success and isinstance(payload, RateLimitError):
|
||||
if payload.provider is None:
|
||||
payload.provider = "civarchive_api"
|
||||
raise payload
|
||||
return success, payload
|
||||
|
||||
@staticmethod
|
||||
def _normalize_payload(payload: Dict) -> Dict:
|
||||
"""Unwrap CivArchive responses that wrap content under a data key"""
|
||||
if not isinstance(payload, dict):
|
||||
return {}
|
||||
data = payload.get("data")
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _split_context(payload: Dict) -> Tuple[Dict, Dict, List[Dict]]:
|
||||
"""Separate version payload from surrounding model context"""
|
||||
data = CivArchiveClient._normalize_payload(payload)
|
||||
context: Dict = {}
|
||||
fallback_files: List[Dict] = []
|
||||
version: Dict = {}
|
||||
|
||||
for key, value in data.items():
|
||||
if key in {"version", "model"}:
|
||||
continue
|
||||
context[key] = value
|
||||
|
||||
if isinstance(data.get("version"), dict):
|
||||
version = data["version"]
|
||||
|
||||
model_block = data.get("model")
|
||||
if isinstance(model_block, dict):
|
||||
for key, value in model_block.items():
|
||||
if key == "version":
|
||||
if not version and isinstance(value, dict):
|
||||
version = value
|
||||
continue
|
||||
context.setdefault(key, value)
|
||||
fallback_files = fallback_files or model_block.get("files") or []
|
||||
|
||||
fallback_files = fallback_files or data.get("files") or []
|
||||
return context, version, fallback_files
|
||||
|
||||
@staticmethod
|
||||
def _ensure_list(value) -> List:
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
if value is None:
|
||||
return []
|
||||
return [value]
|
||||
|
||||
@staticmethod
|
||||
def _build_model_info(context: Dict) -> Dict:
|
||||
tags = context.get("tags")
|
||||
if not isinstance(tags, list):
|
||||
tags = list(tags) if isinstance(tags, (set, tuple)) else ([] if tags is None else [tags])
|
||||
return {
|
||||
"name": context.get("name"),
|
||||
"type": context.get("type"),
|
||||
"nsfw": bool(context.get("is_nsfw", context.get("nsfw", False))),
|
||||
"description": context.get("description"),
|
||||
"tags": tags,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_creator_info(context: Dict) -> Dict:
|
||||
username = context.get("creator_username") or context.get("username") or ""
|
||||
image = context.get("creator_image") or context.get("creator_avatar") or ""
|
||||
creator: Dict[str, Optional[str]] = {
|
||||
"username": username,
|
||||
"image": image,
|
||||
}
|
||||
if context.get("creator_name"):
|
||||
creator["name"] = context["creator_name"]
|
||||
if context.get("creator_url"):
|
||||
creator["url"] = context["creator_url"]
|
||||
return creator
|
||||
|
||||
@staticmethod
|
||||
def _transform_file_entry(file_data: Dict) -> Dict:
|
||||
mirrors = file_data.get("mirrors") or []
|
||||
if not isinstance(mirrors, list):
|
||||
mirrors = [mirrors]
|
||||
available_mirror = next(
|
||||
(mirror for mirror in mirrors if isinstance(mirror, dict) and mirror.get("deletedAt") is None),
|
||||
None
|
||||
)
|
||||
download_url = file_data.get("downloadUrl")
|
||||
if not download_url and available_mirror:
|
||||
download_url = available_mirror.get("url")
|
||||
name = file_data.get("name")
|
||||
if not name and available_mirror:
|
||||
name = available_mirror.get("filename")
|
||||
|
||||
transformed: Dict = {
|
||||
"id": file_data.get("id"),
|
||||
"sizeKB": file_data.get("sizeKB"),
|
||||
"name": name,
|
||||
"type": file_data.get("type"),
|
||||
"downloadUrl": download_url,
|
||||
"primary": True,
|
||||
# TODO: for some reason is_primary is false in CivArchive response, need to figure this out,
|
||||
# "primary": bool(file_data.get("is_primary", file_data.get("primary", False))),
|
||||
"mirrors": mirrors,
|
||||
}
|
||||
|
||||
sha256 = file_data.get("sha256")
|
||||
if sha256:
|
||||
transformed["hashes"] = {"SHA256": str(sha256).upper()}
|
||||
elif isinstance(file_data.get("hashes"), dict):
|
||||
transformed["hashes"] = file_data["hashes"]
|
||||
|
||||
if "metadata" in file_data:
|
||||
transformed["metadata"] = file_data["metadata"]
|
||||
|
||||
if file_data.get("modelVersionId") is not None:
|
||||
transformed["modelVersionId"] = file_data.get("modelVersionId")
|
||||
elif file_data.get("model_version_id") is not None:
|
||||
transformed["modelVersionId"] = file_data.get("model_version_id")
|
||||
|
||||
if file_data.get("modelId") is not None:
|
||||
transformed["modelId"] = file_data.get("modelId")
|
||||
elif file_data.get("model_id") is not None:
|
||||
transformed["modelId"] = file_data.get("model_id")
|
||||
|
||||
return transformed
|
||||
|
||||
def _transform_files(
|
||||
self,
|
||||
files: Optional[List[Dict]],
|
||||
fallback_files: Optional[List[Dict]] = None
|
||||
) -> List[Dict]:
|
||||
candidates: List[Dict] = []
|
||||
if isinstance(files, list) and files:
|
||||
candidates = files
|
||||
elif isinstance(fallback_files, list):
|
||||
candidates = fallback_files
|
||||
|
||||
transformed_files: List[Dict] = []
|
||||
for file_data in candidates:
|
||||
if isinstance(file_data, dict):
|
||||
transformed_files.append(self._transform_file_entry(file_data))
|
||||
return transformed_files
|
||||
|
||||
def _transform_version(
|
||||
self,
|
||||
context: Dict,
|
||||
version: Dict,
|
||||
fallback_files: Optional[List[Dict]] = None
|
||||
) -> Optional[Dict]:
|
||||
if not version:
|
||||
return None
|
||||
|
||||
version_copy = deepcopy(version)
|
||||
version_copy.pop("model", None)
|
||||
version_copy.pop("creator", None)
|
||||
|
||||
if "trigger" in version_copy:
|
||||
triggers = version_copy.pop("trigger")
|
||||
if isinstance(triggers, list):
|
||||
version_copy["trainedWords"] = triggers
|
||||
elif triggers is None:
|
||||
version_copy["trainedWords"] = []
|
||||
else:
|
||||
version_copy["trainedWords"] = [triggers]
|
||||
|
||||
if "trainedWords" in version_copy and isinstance(version_copy["trainedWords"], str):
|
||||
version_copy["trainedWords"] = [version_copy["trainedWords"]]
|
||||
|
||||
if "nsfw_level" in version_copy:
|
||||
version_copy["nsfwLevel"] = version_copy.pop("nsfw_level")
|
||||
elif "nsfwLevel" not in version_copy and context.get("nsfw_level") is not None:
|
||||
version_copy["nsfwLevel"] = context.get("nsfw_level")
|
||||
|
||||
stats_keys = ["downloadCount", "ratingCount", "rating"]
|
||||
stats = {key: version_copy.pop(key) for key in stats_keys if key in version_copy}
|
||||
if stats:
|
||||
version_copy["stats"] = stats
|
||||
|
||||
version_copy["files"] = self._transform_files(version_copy.get("files"), fallback_files)
|
||||
version_copy["images"] = self._ensure_list(version_copy.get("images"))
|
||||
|
||||
version_copy["model"] = self._build_model_info(context)
|
||||
version_copy["creator"] = self._build_creator_info(context)
|
||||
|
||||
version_copy["source"] = "civarchive"
|
||||
version_copy["is_deleted"] = bool(context.get("deletedAt")) or bool(version.get("deletedAt"))
|
||||
|
||||
return version_copy
|
||||
|
||||
async def _resolve_version_from_files(self, payload: Dict) -> Optional[Dict]:
|
||||
"""Fallback to fetch version data when only file metadata is available"""
|
||||
data = self._normalize_payload(payload)
|
||||
files = data.get("files") or payload.get("files") or []
|
||||
if not isinstance(files, list):
|
||||
files = [files]
|
||||
for file_data in files:
|
||||
if not isinstance(file_data, dict):
|
||||
continue
|
||||
model_id = file_data.get("model_id") or file_data.get("modelId")
|
||||
version_id = file_data.get("model_version_id") or file_data.get("modelVersionId")
|
||||
if model_id is None or version_id is None:
|
||||
continue
|
||||
resolved = await self.get_model_version(model_id, version_id)
|
||||
if resolved:
|
||||
return resolved
|
||||
return None
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Find model by SHA256 hash value using CivArchive API"""
|
||||
try:
|
||||
payload, error = await self._request_json(f"/sha256/{model_hash.lower()}")
|
||||
if error:
|
||||
if "not found" in error.lower():
|
||||
return None, "Model not found"
|
||||
return None, error
|
||||
|
||||
context, version_data, fallback_files = self._split_context(payload)
|
||||
transformed = self._transform_version(context, version_data, fallback_files)
|
||||
if transformed:
|
||||
return transformed, None
|
||||
|
||||
resolved = await self._resolve_version_from_files(payload)
|
||||
if resolved:
|
||||
return resolved, None
|
||||
|
||||
logger.error("Error fetching version of CivArchive model by hash %s", model_hash[:10])
|
||||
return None, "No version data found"
|
||||
|
||||
except RateLimitError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivArchive model by hash {model_hash[:10]}: {e}")
|
||||
return None, str(e)
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Get all versions of a model using CivArchive API"""
|
||||
try:
|
||||
payload, error = await self._request_json(f"/models/{model_id}")
|
||||
if error or payload is None:
|
||||
if error and "not found" in error.lower():
|
||||
return None
|
||||
logger.error(f"Error fetching CivArchive model versions for {model_id}: {error}")
|
||||
return None
|
||||
|
||||
data = self._normalize_payload(payload)
|
||||
context, version_data, fallback_files = self._split_context(payload)
|
||||
|
||||
versions_meta = data.get("versions") or []
|
||||
transformed_versions: List[Dict] = []
|
||||
for meta in versions_meta:
|
||||
if not isinstance(meta, dict):
|
||||
continue
|
||||
version_id = meta.get("id")
|
||||
if version_id is None:
|
||||
continue
|
||||
target_model_id = meta.get("modelId") or model_id
|
||||
version = await self.get_model_version(target_model_id, version_id)
|
||||
if version:
|
||||
transformed_versions.append(version)
|
||||
|
||||
# Ensure the primary version is included even if versions list was empty
|
||||
primary_version = self._transform_version(context, version_data, fallback_files)
|
||||
if primary_version:
|
||||
transformed_versions.insert(0, primary_version)
|
||||
|
||||
ordered_versions: List[Dict] = []
|
||||
seen_ids = set()
|
||||
for version in transformed_versions:
|
||||
version_id = version.get("id")
|
||||
if version_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(version_id)
|
||||
ordered_versions.append(version)
|
||||
|
||||
return {
|
||||
"modelVersions": ordered_versions,
|
||||
"type": context.get("type", ""),
|
||||
"name": context.get("name", ""),
|
||||
}
|
||||
|
||||
except RateLimitError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivArchive model versions for {model_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
"""Get specific model version using CivArchive API
|
||||
|
||||
Args:
|
||||
model_id: The model ID (required)
|
||||
version_id: Optional specific version ID to filter to
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: The model version data or None if not found
|
||||
"""
|
||||
if model_id is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
params = {"modelVersionId": version_id} if version_id is not None else None
|
||||
payload, error = await self._request_json(f"/models/{model_id}", params=params)
|
||||
if error or payload is None:
|
||||
if error and "not found" in error.lower():
|
||||
return None
|
||||
logger.error(f"Error fetching CivArchive model version via API {model_id}/{version_id}: {error}")
|
||||
return None
|
||||
|
||||
context, version_data, fallback_files = self._split_context(payload)
|
||||
|
||||
if not version_data:
|
||||
return await self._resolve_version_from_files(payload)
|
||||
|
||||
if version_id is not None:
|
||||
raw_id = version_data.get("id")
|
||||
if raw_id != version_id:
|
||||
logger.warning(
|
||||
"Requested version %s doesn't match default version %s for model %s",
|
||||
version_id,
|
||||
raw_id,
|
||||
model_id,
|
||||
)
|
||||
return None
|
||||
actual_model_id = version_data.get("modelId")
|
||||
context_model_id = context.get("id")
|
||||
# CivArchive can respond with data for a different model id while already
|
||||
# returning the fully resolved model context. Only follow the redirect when
|
||||
# the context itself still points to the original (wrong) model.
|
||||
if (
|
||||
actual_model_id is not None
|
||||
and str(actual_model_id) != str(model_id)
|
||||
and (context_model_id is None or str(context_model_id) != str(actual_model_id))
|
||||
):
|
||||
return await self.get_model_version(actual_model_id, version_id)
|
||||
|
||||
return self._transform_version(context, version_data, fallback_files)
|
||||
|
||||
except RateLimitError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivArchive model version via API {model_id}/{version_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
""" Fetch model version metadata using a known bogus model lookup
|
||||
CivArchive lacks a direct version lookup API, this uses a workaround (which we handle in the main model request now)
|
||||
|
||||
Args:
|
||||
version_id: The model version ID
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[Dict], Optional[str]]: (version_data, error_message)
|
||||
"""
|
||||
version = await self.get_model_version(1, version_id)
|
||||
if version is None:
|
||||
return None, "Model not found"
|
||||
return version, None
|
||||
@@ -1,13 +1,12 @@
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from email.parser import Parser
|
||||
from typing import Optional, Dict, Tuple, List
|
||||
from urllib.parse import unquote
|
||||
from ..utils.models import LoraMetadata
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional, Dict, Tuple, List, Sequence
|
||||
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
||||
from .downloader import get_downloader
|
||||
from .errors import RateLimitError, ResourceNotFoundError
|
||||
from ..utils.civitai_utils import resolve_license_payload
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,6 +20,11 @@ class CivitaiClient:
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
|
||||
# Register this client as a metadata provider
|
||||
provider_manager = await ModelMetadataProviderManager.get_instance()
|
||||
provider_manager.register_provider('civitai', CivitaiModelMetadataProvider(cls._instance), True)
|
||||
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
@@ -30,69 +34,50 @@ class CivitaiClient:
|
||||
self._initialized = True
|
||||
|
||||
self.base_url = "https://civitai.com/api/v1"
|
||||
self.headers = {
|
||||
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
|
||||
}
|
||||
self._session = None
|
||||
# Set default buffer size to 1MB for higher throughput
|
||||
self.chunk_size = 1024 * 1024
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
*,
|
||||
use_auth: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[bool, Dict | str]:
|
||||
"""Wrapper around downloader.make_request that surfaces rate limits."""
|
||||
|
||||
downloader = await get_downloader()
|
||||
success, result = await downloader.make_request(
|
||||
method,
|
||||
url,
|
||||
use_auth=use_auth,
|
||||
**kwargs,
|
||||
)
|
||||
if not success and isinstance(result, RateLimitError):
|
||||
if result.provider is None:
|
||||
result.provider = "civitai_api"
|
||||
raise result
|
||||
return success, result
|
||||
|
||||
@staticmethod
|
||||
def _remove_comfy_metadata(model_version: Optional[Dict]) -> None:
|
||||
"""Remove Comfy-specific metadata from model version images."""
|
||||
if not isinstance(model_version, dict):
|
||||
return
|
||||
|
||||
images = model_version.get("images")
|
||||
if not isinstance(images, list):
|
||||
return
|
||||
|
||||
for image in images:
|
||||
if not isinstance(image, dict):
|
||||
continue
|
||||
|
||||
meta = image.get("meta")
|
||||
if isinstance(meta, dict) and "comfy" in meta:
|
||||
meta.pop("comfy", None)
|
||||
|
||||
@property
|
||||
async def session(self) -> aiohttp.ClientSession:
|
||||
"""Lazy initialize the session"""
|
||||
if self._session is None:
|
||||
# Optimize TCP connection parameters
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=True,
|
||||
limit=10, # Increase parallel connections
|
||||
ttl_dns_cache=300, # DNS cache time
|
||||
force_close=False, # Keep connections for reuse
|
||||
enable_cleanup_closed=True
|
||||
)
|
||||
trust_env = True # Allow using system environment proxy settings
|
||||
# Configure timeout parameters
|
||||
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
|
||||
self._session = aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
trust_env=trust_env,
|
||||
timeout=timeout
|
||||
)
|
||||
return self._session
|
||||
|
||||
def _parse_content_disposition(self, header: str) -> str:
|
||||
"""Parse filename from content-disposition header"""
|
||||
if not header:
|
||||
return None
|
||||
|
||||
# Handle quoted filenames
|
||||
if 'filename="' in header:
|
||||
start = header.index('filename="') + 10
|
||||
end = header.index('"', start)
|
||||
return unquote(header[start:end])
|
||||
|
||||
# Fallback to original parsing
|
||||
disposition = Parser().parsestr(f'Content-Disposition: {header}')
|
||||
filename = disposition.get_param('filename')
|
||||
if filename:
|
||||
return unquote(filename)
|
||||
return None
|
||||
|
||||
def _get_request_headers(self) -> dict:
|
||||
"""Get request headers with optional API key"""
|
||||
headers = {
|
||||
'User-Agent': 'ComfyUI-LoRA-Manager/1.0',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
from .settings_manager import settings
|
||||
api_key = settings.get('civitai_api_key')
|
||||
if (api_key):
|
||||
headers['Authorization'] = f'Bearer {api_key}'
|
||||
|
||||
return headers
|
||||
|
||||
async def _download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
|
||||
"""Download file with content-disposition support and progress tracking
|
||||
async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
|
||||
"""Download file with resumable downloads and retry mechanism
|
||||
|
||||
Args:
|
||||
url: Download URL
|
||||
@@ -103,200 +88,443 @@ class CivitaiClient:
|
||||
Returns:
|
||||
Tuple[bool, str]: (success, save_path or error message)
|
||||
"""
|
||||
session = await self.session
|
||||
downloader = await get_downloader()
|
||||
save_path = os.path.join(save_dir, default_filename)
|
||||
|
||||
# Use unified downloader with CivitAI authentication
|
||||
success, result = await downloader.download_file(
|
||||
url=url,
|
||||
save_path=save_path,
|
||||
progress_callback=progress_callback,
|
||||
use_auth=True, # Enable CivitAI authentication
|
||||
allow_resume=True
|
||||
)
|
||||
|
||||
return success, result
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
try:
|
||||
headers = self._get_request_headers()
|
||||
|
||||
# Add Range header to allow resumable downloads
|
||||
headers['Accept-Encoding'] = 'identity' # Disable compression for better chunked downloads
|
||||
|
||||
async with session.get(url, headers=headers, allow_redirects=True) as response:
|
||||
if response.status != 200:
|
||||
# Handle 401 unauthorized responses
|
||||
if response.status == 401:
|
||||
logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
|
||||
|
||||
return False, "Invalid or missing CivitAI API key, or early access restriction."
|
||||
|
||||
# Handle other client errors that might be permission-related
|
||||
if response.status == 403:
|
||||
logger.warning(f"Forbidden access to resource: {url} (Status 403)")
|
||||
return False, "Access forbidden: You don't have permission to download this file."
|
||||
|
||||
# Generic error response for other status codes
|
||||
return False, f"Download failed with status {response.status}"
|
||||
success, version = await self._make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||
use_auth=True
|
||||
)
|
||||
if not success:
|
||||
message = str(version)
|
||||
if "not found" in message.lower():
|
||||
return None, "Model not found"
|
||||
|
||||
# Get filename from content-disposition header
|
||||
content_disposition = response.headers.get('Content-Disposition')
|
||||
filename = self._parse_content_disposition(content_disposition)
|
||||
if not filename:
|
||||
filename = default_filename
|
||||
|
||||
save_path = os.path.join(save_dir, filename)
|
||||
|
||||
# Get total file size for progress calculation
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
current_size = 0
|
||||
last_progress_report_time = datetime.now()
|
||||
logger.error("Failed to fetch model info for %s: %s", model_hash[:10], message)
|
||||
return None, message
|
||||
|
||||
# Stream download to file with progress updates using larger buffer
|
||||
with open(save_path, 'wb') as f:
|
||||
async for chunk in response.content.iter_chunked(self.chunk_size):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
current_size += len(chunk)
|
||||
|
||||
# Limit progress update frequency to reduce overhead
|
||||
now = datetime.now()
|
||||
time_diff = (now - last_progress_report_time).total_seconds()
|
||||
|
||||
if progress_callback and total_size and time_diff >= 0.5:
|
||||
progress = (current_size / total_size) * 100
|
||||
await progress_callback(progress)
|
||||
last_progress_report_time = now
|
||||
|
||||
# Ensure 100% progress is reported
|
||||
if progress_callback:
|
||||
await progress_callback(100)
|
||||
|
||||
return True, save_path
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"Network error during download: {e}")
|
||||
return False, f"Network error: {str(e)}"
|
||||
except Exception as e:
|
||||
logger.error(f"Download error: {e}")
|
||||
return False, str(e)
|
||||
model_id = version.get('modelId')
|
||||
if model_id:
|
||||
model_data = await self._fetch_model_data(model_id)
|
||||
if model_data:
|
||||
self._enrich_version_with_model_data(version, model_data)
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||
try:
|
||||
session = await self.session
|
||||
async with session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"API Error: {str(e)}")
|
||||
return None
|
||||
self._remove_comfy_metadata(version)
|
||||
return version, None
|
||||
except RateLimitError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error("API Error: %s", exc)
|
||||
return None, str(exc)
|
||||
|
||||
async def download_preview_image(self, image_url: str, save_path: str):
|
||||
try:
|
||||
session = await self.session
|
||||
async with session.get(image_url) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
with open(save_path, 'wb') as f:
|
||||
f.write(content)
|
||||
return True
|
||||
return False
|
||||
downloader = await get_downloader()
|
||||
success, content, headers = await downloader.download_to_memory(
|
||||
image_url,
|
||||
use_auth=False # Preview images don't need auth
|
||||
)
|
||||
if success:
|
||||
# Ensure directory exists
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
with open(save_path, 'wb') as f:
|
||||
f.write(content)
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Download Error: {str(e)}")
|
||||
logger.error(f"Download Error: {str(e)}")
|
||||
return False
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> List[Dict]:
|
||||
@staticmethod
|
||||
def _extract_error_message(payload: Any) -> str:
|
||||
"""Return a human-readable error message from an API payload."""
|
||||
|
||||
def _from_value(value: Any) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
for key in ("message", "error", "detail", "details"):
|
||||
if key in value:
|
||||
candidate = _from_value(value[key])
|
||||
if candidate:
|
||||
return candidate
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
candidate = _from_value(item)
|
||||
if candidate:
|
||||
return candidate
|
||||
return ""
|
||||
|
||||
return _from_value(payload)
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Get all versions of a model with local availability info"""
|
||||
try:
|
||||
session = await self.session # 等待获取 session
|
||||
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
data = await response.json()
|
||||
success, result = await self._make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
# Also return model type along with versions
|
||||
return {
|
||||
'modelVersions': data.get('modelVersions', []),
|
||||
'type': data.get('type', '')
|
||||
'modelVersions': result.get('modelVersions', []),
|
||||
'type': result.get('type', ''),
|
||||
'name': result.get('name', '')
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model versions: {e}")
|
||||
message = self._extract_error_message(result)
|
||||
if message and 'not found' in message.lower():
|
||||
raise ResourceNotFoundError(f"Resource not found for model {model_id}")
|
||||
if message:
|
||||
raise RuntimeError(message)
|
||||
return None
|
||||
except RateLimitError:
|
||||
raise
|
||||
except ResourceNotFoundError as exc:
|
||||
logger.info("Model %s is no longer available on Civitai: %s", model_id, exc)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error fetching model versions: %s", e, exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_model_versions_bulk(
|
||||
self, model_ids: Sequence[int]
|
||||
) -> Optional[Dict[int, Dict]]:
|
||||
"""Fetch model metadata for multiple ids using the batch API."""
|
||||
|
||||
deduped: Dict[int, None] = {}
|
||||
for raw_id in model_ids:
|
||||
try:
|
||||
normalized = int(raw_id)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
deduped.setdefault(normalized, None)
|
||||
|
||||
normalized_ids = [str(model_id) for model_id in deduped.keys()]
|
||||
if not normalized_ids:
|
||||
return {}
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Optional[Dict]:
|
||||
"""Fetch model version metadata from Civitai"""
|
||||
try:
|
||||
session = await self.session
|
||||
url = f"{self.base_url}/model-versions/{version_id}"
|
||||
headers = self._get_request_headers()
|
||||
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
query = ",".join(normalized_ids)
|
||||
success, result = await self._make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models",
|
||||
use_auth=True,
|
||||
params={'ids': query},
|
||||
)
|
||||
if not success:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model version info: {e}")
|
||||
|
||||
items = result.get('items') if isinstance(result, dict) else None
|
||||
if not isinstance(items, list):
|
||||
return {}
|
||||
|
||||
payload: Dict[int, Dict] = {}
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
model_id = item.get('id')
|
||||
try:
|
||||
normalized_id = int(model_id)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
payload[normalized_id] = {
|
||||
'modelVersions': item.get('modelVersions', []),
|
||||
'type': item.get('type', ''),
|
||||
'name': item.get('name', ''),
|
||||
'allowNoCredit': item.get('allowNoCredit'),
|
||||
'allowCommercialUse': item.get('allowCommercialUse'),
|
||||
'allowDerivatives': item.get('allowDerivatives'),
|
||||
'allowDifferentLicense': item.get('allowDifferentLicense'),
|
||||
}
|
||||
return payload
|
||||
except RateLimitError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error fetching model versions in bulk: {exc}")
|
||||
return None
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
"""Get specific model version with additional metadata."""
|
||||
try:
|
||||
if model_id is None and version_id is not None:
|
||||
return await self._get_version_by_id_only(version_id)
|
||||
|
||||
if model_id is not None:
|
||||
return await self._get_version_with_model_id(model_id, version_id)
|
||||
|
||||
logger.error("Either model_id or version_id must be provided")
|
||||
return None
|
||||
|
||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||
"""Fetch model metadata (description and tags) from Civitai API
|
||||
except RateLimitError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model version: {e}")
|
||||
return None
|
||||
|
||||
async def _get_version_by_id_only(self, version_id: int) -> Optional[Dict]:
|
||||
version = await self._fetch_version_by_id(version_id)
|
||||
if version is None:
|
||||
return None
|
||||
|
||||
model_id = version.get('modelId')
|
||||
if not model_id:
|
||||
logger.error(f"No modelId found in version {version_id}")
|
||||
return None
|
||||
|
||||
model_data = await self._fetch_model_data(model_id)
|
||||
if model_data:
|
||||
self._enrich_version_with_model_data(version, model_data)
|
||||
|
||||
self._remove_comfy_metadata(version)
|
||||
return version
|
||||
|
||||
async def _get_version_with_model_id(self, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
|
||||
model_data = await self._fetch_model_data(model_id)
|
||||
if not model_data:
|
||||
return None
|
||||
|
||||
target_version = self._select_target_version(model_data, model_id, version_id)
|
||||
if target_version is None:
|
||||
return None
|
||||
|
||||
target_version_id = target_version.get('id')
|
||||
version = await self._fetch_version_by_id(target_version_id) if target_version_id else None
|
||||
|
||||
if version is None:
|
||||
model_hash = self._extract_primary_model_hash(target_version)
|
||||
if model_hash:
|
||||
version = await self._fetch_version_by_hash(model_hash)
|
||||
else:
|
||||
logger.warning(
|
||||
f"No primary model hash found for model {model_id} version {target_version_id}"
|
||||
)
|
||||
|
||||
if version is None:
|
||||
version = self._build_version_from_model_data(target_version, model_id, model_data)
|
||||
|
||||
self._enrich_version_with_model_data(version, model_data)
|
||||
self._remove_comfy_metadata(version)
|
||||
return version
|
||||
|
||||
async def _fetch_model_data(self, model_id: int) -> Optional[Dict]:
|
||||
success, data = await self._make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
return data
|
||||
logger.warning(f"Failed to fetch model data for model {model_id}")
|
||||
return None
|
||||
|
||||
async def _fetch_version_by_id(self, version_id: Optional[int]) -> Optional[Dict]:
|
||||
if version_id is None:
|
||||
return None
|
||||
|
||||
success, version = await self._make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/{version_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
return version
|
||||
|
||||
logger.warning(f"Failed to fetch version by id {version_id}")
|
||||
return None
|
||||
|
||||
async def _fetch_version_by_hash(self, model_hash: Optional[str]) -> Optional[Dict]:
|
||||
if not model_hash:
|
||||
return None
|
||||
|
||||
success, version = await self._make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
return version
|
||||
|
||||
logger.warning(f"Failed to fetch version by hash {model_hash}")
|
||||
return None
|
||||
|
||||
def _select_target_version(self, model_data: Dict, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
|
||||
model_versions = model_data.get('modelVersions', [])
|
||||
if not model_versions:
|
||||
logger.warning(f"No model versions found for model {model_id}")
|
||||
return None
|
||||
|
||||
if version_id is not None:
|
||||
target_version = next(
|
||||
(item for item in model_versions if item.get('id') == version_id),
|
||||
None
|
||||
)
|
||||
if target_version is None:
|
||||
logger.warning(
|
||||
f"Version {version_id} not found for model {model_id}, defaulting to first version"
|
||||
)
|
||||
return model_versions[0]
|
||||
return target_version
|
||||
|
||||
return model_versions[0]
|
||||
|
||||
def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]:
|
||||
for file_info in version_entry.get('files', []):
|
||||
if file_info.get('type') == 'Model' and file_info.get('primary'):
|
||||
hashes = file_info.get('hashes', {})
|
||||
model_hash = hashes.get('SHA256')
|
||||
if model_hash:
|
||||
return model_hash
|
||||
return None
|
||||
|
||||
def _build_version_from_model_data(self, version_entry: Dict, model_id: int, model_data: Dict) -> Dict:
|
||||
version = copy.deepcopy(version_entry)
|
||||
version.pop('index', None)
|
||||
version['modelId'] = model_id
|
||||
version['model'] = {
|
||||
'name': model_data.get('name'),
|
||||
'type': model_data.get('type'),
|
||||
'nsfw': model_data.get('nsfw'),
|
||||
'poi': model_data.get('poi')
|
||||
}
|
||||
return version
|
||||
|
||||
def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None:
|
||||
model_info = version.get('model')
|
||||
if not isinstance(model_info, dict):
|
||||
model_info = {}
|
||||
version['model'] = model_info
|
||||
|
||||
model_info['description'] = model_data.get("description")
|
||||
model_info['tags'] = model_data.get("tags", [])
|
||||
version['creator'] = model_data.get("creator")
|
||||
|
||||
license_payload = resolve_license_payload(model_data)
|
||||
for field, value in license_payload.items():
|
||||
model_info[field] = value
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Fetch model version metadata from Civitai
|
||||
|
||||
Args:
|
||||
model_id: The Civitai model ID
|
||||
version_id: The Civitai model version ID
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[Dict], int]: A tuple containing:
|
||||
- A dictionary with model metadata or None if not found
|
||||
- The HTTP status code from the request
|
||||
Tuple[Optional[Dict], Optional[str]]: A tuple containing:
|
||||
- The model version data or None if not found
|
||||
- An error message if there was an error, or None on success
|
||||
"""
|
||||
try:
|
||||
session = await self.session
|
||||
headers = self._get_request_headers()
|
||||
url = f"{self.base_url}/models/{model_id}"
|
||||
url = f"{self.base_url}/model-versions/{version_id}"
|
||||
|
||||
async with session.get(url, headers=headers) as response:
|
||||
status_code = response.status
|
||||
|
||||
if status_code != 200:
|
||||
logger.warning(f"Failed to fetch model metadata: Status {status_code}")
|
||||
return None, status_code
|
||||
|
||||
data = await response.json()
|
||||
|
||||
# Extract relevant metadata
|
||||
metadata = {
|
||||
"description": data.get("description") or "No model description available",
|
||||
"tags": data.get("tags", [])
|
||||
}
|
||||
|
||||
if metadata["description"] or metadata["tags"]:
|
||||
return metadata, status_code
|
||||
else:
|
||||
logger.warning(f"No metadata found for model {model_id}")
|
||||
return None, status_code
|
||||
|
||||
logger.debug(f"Resolving DNS for model version info: {url}")
|
||||
success, result = await self._make_request(
|
||||
'GET',
|
||||
url,
|
||||
use_auth=True
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.debug(f"Successfully fetched model version info for: {version_id}")
|
||||
self._remove_comfy_metadata(result)
|
||||
return result, None
|
||||
|
||||
# Handle specific error cases
|
||||
if "not found" in str(result):
|
||||
error_msg = f"Model not found"
|
||||
logger.warning(f"Model version not found: {version_id} - {error_msg}")
|
||||
return None, error_msg
|
||||
|
||||
# Other error cases
|
||||
logger.error(f"Failed to fetch model info for {version_id}: {result}")
|
||||
return None, str(result)
|
||||
except RateLimitError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model metadata: {e}", exc_info=True)
|
||||
return None, 0
|
||||
error_msg = f"Error fetching model version info: {e}"
|
||||
logger.error(error_msg)
|
||||
return None, error_msg
|
||||
|
||||
# Keep old method for backward compatibility, delegating to the new one
|
||||
async def get_model_description(self, model_id: str) -> Optional[str]:
|
||||
"""Fetch the model description from Civitai API (Legacy method)"""
|
||||
metadata, _ = await self.get_model_metadata(model_id)
|
||||
return metadata.get("description") if metadata else None
|
||||
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
||||
"""Fetch image information from Civitai API
|
||||
|
||||
async def close(self):
|
||||
"""Close the session if it exists"""
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]:
|
||||
"""Get hash from Civitai API"""
|
||||
Args:
|
||||
image_id: The Civitai image ID
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: The image data or None if not found
|
||||
"""
|
||||
try:
|
||||
if not self._session:
|
||||
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
|
||||
|
||||
logger.debug(f"Fetching image info for ID: {image_id}")
|
||||
success, result = await self._make_request(
|
||||
'GET',
|
||||
url,
|
||||
use_auth=True
|
||||
)
|
||||
|
||||
if success:
|
||||
if result and "items" in result and len(result["items"]) > 0:
|
||||
logger.debug(f"Successfully fetched image info for ID: {image_id}")
|
||||
return result["items"][0]
|
||||
logger.warning(f"No image found with ID: {image_id}")
|
||||
return None
|
||||
|
||||
version_info = await self._session.get(f"{self.base_url}/model-versions/{model_version_id}")
|
||||
|
||||
if not version_info or not version_info.json().get('files'):
|
||||
return None
|
||||
|
||||
# Get hash from the first file
|
||||
for file_info in version_info.json().get('files', []):
|
||||
if file_info.get('hashes', {}).get('SHA256'):
|
||||
# Convert hash to lowercase to standardize
|
||||
hash_value = file_info['hashes']['SHA256'].lower()
|
||||
return hash_value
|
||||
|
||||
logger.error(f"Failed to fetch image info for ID: {image_id}: {result}")
|
||||
return None
|
||||
except RateLimitError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting hash from Civitai: {e}")
|
||||
error_msg = f"Error fetching image info: {e}"
|
||||
logger.error(error_msg)
|
||||
return None
|
||||
|
||||
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||
"""Fetch all models for a specific Civitai user."""
|
||||
if not username:
|
||||
return None
|
||||
|
||||
try:
|
||||
url = f"{self.base_url}/models?username={username}"
|
||||
success, result = await self._make_request(
|
||||
'GET',
|
||||
url,
|
||||
use_auth=True
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error("Failed to fetch models for %s: %s", username, result)
|
||||
return None
|
||||
|
||||
items = result.get("items") if isinstance(result, dict) else None
|
||||
if not isinstance(items, list):
|
||||
return []
|
||||
|
||||
for model in items:
|
||||
versions = model.get("modelVersions")
|
||||
if not isinstance(versions, list):
|
||||
continue
|
||||
for version in versions:
|
||||
self._remove_comfy_metadata(version)
|
||||
|
||||
return items
|
||||
except RateLimitError:
|
||||
raise
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Error fetching models for %s: %s", username, exc)
|
||||
return None
|
||||
|
||||
178
py/services/download_coordinator.py
Normal file
178
py/services/download_coordinator.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Service wrapper for coordinating download lifecycle events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||
|
||||
from .downloader import DownloadProgress
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DownloadCoordinator:
|
||||
"""Manage download scheduling, cancellation and introspection."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ws_manager,
|
||||
download_manager_factory: Callable[[], Awaitable],
|
||||
) -> None:
|
||||
self._ws_manager = ws_manager
|
||||
self._download_manager_factory = download_manager_factory
|
||||
|
||||
async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Schedule a download using the provided payload."""
|
||||
|
||||
download_manager = await self._download_manager_factory()
|
||||
|
||||
download_id = payload.get("download_id") or self._ws_manager.generate_download_id()
|
||||
payload.setdefault("download_id", download_id)
|
||||
|
||||
async def progress_callback(progress: Any, snapshot: Optional[DownloadProgress] = None) -> None:
|
||||
percent = 0.0
|
||||
metrics: Optional[DownloadProgress] = None
|
||||
|
||||
if isinstance(progress, DownloadProgress):
|
||||
metrics = progress
|
||||
percent = progress.percent_complete
|
||||
elif isinstance(snapshot, DownloadProgress):
|
||||
metrics = snapshot
|
||||
percent = snapshot.percent_complete
|
||||
else:
|
||||
try:
|
||||
percent = float(progress)
|
||||
except (TypeError, ValueError):
|
||||
percent = 0.0
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"status": "progress",
|
||||
"progress": round(percent),
|
||||
"download_id": download_id,
|
||||
}
|
||||
|
||||
if metrics is not None:
|
||||
payload.update(
|
||||
{
|
||||
"bytes_downloaded": metrics.bytes_downloaded,
|
||||
"total_bytes": metrics.total_bytes,
|
||||
"bytes_per_second": metrics.bytes_per_second,
|
||||
}
|
||||
)
|
||||
|
||||
await self._ws_manager.broadcast_download_progress(
|
||||
download_id,
|
||||
payload,
|
||||
)
|
||||
|
||||
model_id = self._parse_optional_int(payload.get("model_id"), "model_id")
|
||||
model_version_id = self._parse_optional_int(
|
||||
payload.get("model_version_id"), "model_version_id"
|
||||
)
|
||||
|
||||
if model_id is None and model_version_id is None:
|
||||
raise ValueError(
|
||||
"Missing required parameter: Please provide either 'model_id' or 'model_version_id'"
|
||||
)
|
||||
|
||||
result = await download_manager.download_from_civitai(
|
||||
model_id=model_id,
|
||||
model_version_id=model_version_id,
|
||||
save_dir=payload.get("model_root"),
|
||||
relative_path=payload.get("relative_path", ""),
|
||||
use_default_paths=payload.get("use_default_paths", False),
|
||||
progress_callback=progress_callback,
|
||||
download_id=download_id,
|
||||
source=payload.get("source"),
|
||||
)
|
||||
|
||||
result["download_id"] = download_id
|
||||
return result
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||
"""Cancel an active download and emit a broadcast event."""
|
||||
|
||||
download_manager = await self._download_manager_factory()
|
||||
result = await download_manager.cancel_download(download_id)
|
||||
|
||||
await self._ws_manager.broadcast_download_progress(
|
||||
download_id,
|
||||
{
|
||||
"status": "cancelled",
|
||||
"progress": 0,
|
||||
"download_id": download_id,
|
||||
"message": "Download cancelled by user",
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def pause_download(self, download_id: str) -> Dict[str, Any]:
|
||||
"""Pause an active download and notify listeners."""
|
||||
|
||||
download_manager = await self._download_manager_factory()
|
||||
result = await download_manager.pause_download(download_id)
|
||||
|
||||
if result.get("success"):
|
||||
cached_progress = self._ws_manager.get_download_progress(download_id) or {}
|
||||
payload: Dict[str, Any] = {
|
||||
"status": "paused",
|
||||
"progress": cached_progress.get("progress", 0),
|
||||
"download_id": download_id,
|
||||
"message": "Download paused by user",
|
||||
}
|
||||
|
||||
for field in ("bytes_downloaded", "total_bytes", "bytes_per_second"):
|
||||
if field in cached_progress:
|
||||
payload[field] = cached_progress[field]
|
||||
|
||||
payload["bytes_per_second"] = 0.0
|
||||
|
||||
await self._ws_manager.broadcast_download_progress(download_id, payload)
|
||||
|
||||
return result
|
||||
|
||||
async def resume_download(self, download_id: str) -> Dict[str, Any]:
|
||||
"""Resume a paused download and notify listeners."""
|
||||
|
||||
download_manager = await self._download_manager_factory()
|
||||
result = await download_manager.resume_download(download_id)
|
||||
|
||||
if result.get("success"):
|
||||
cached_progress = self._ws_manager.get_download_progress(download_id) or {}
|
||||
payload: Dict[str, Any] = {
|
||||
"status": "downloading",
|
||||
"progress": cached_progress.get("progress", 0),
|
||||
"download_id": download_id,
|
||||
"message": "Download resumed by user",
|
||||
}
|
||||
|
||||
for field in ("bytes_downloaded", "total_bytes"):
|
||||
if field in cached_progress:
|
||||
payload[field] = cached_progress[field]
|
||||
|
||||
payload["bytes_per_second"] = cached_progress.get("bytes_per_second", 0.0)
|
||||
|
||||
await self._ws_manager.broadcast_download_progress(download_id, payload)
|
||||
|
||||
return result
|
||||
|
||||
async def list_active_downloads(self) -> Dict[str, Any]:
|
||||
"""Return the active download map from the underlying manager."""
|
||||
|
||||
download_manager = await self._download_manager_factory()
|
||||
return await download_manager.get_active_downloads()
|
||||
|
||||
def _parse_optional_int(self, value: Any, field: str) -> Optional[int]:
|
||||
"""Parse an optional integer from user input."""
|
||||
|
||||
if value is None or value == "":
|
||||
return None
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Invalid {field}: Must be an integer") from exc
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
850
py/services/downloader.py
Normal file
850
py/services/downloader.py
Normal file
@@ -0,0 +1,850 @@
|
||||
"""
|
||||
Unified download manager for all HTTP/HTTPS downloads in the application.
|
||||
|
||||
This module provides a centralized download service with:
|
||||
- Singleton pattern for global session management
|
||||
- Support for authenticated downloads (e.g., CivitAI API key)
|
||||
- Resumable downloads with automatic retry
|
||||
- Progress tracking and callbacks
|
||||
- Optimized connection pooling and timeouts
|
||||
- Unified error handling and logging
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import Optional, Dict, Tuple, Callable, Union, Awaitable
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from .errors import RateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DownloadProgress:
|
||||
"""Snapshot of a download transfer at a moment in time."""
|
||||
|
||||
percent_complete: float
|
||||
bytes_downloaded: int
|
||||
total_bytes: Optional[int]
|
||||
bytes_per_second: float
|
||||
timestamp: float
|
||||
|
||||
|
||||
class DownloadStreamControl:
|
||||
"""Synchronize pause/resume requests and reconnect hints for a download."""
|
||||
|
||||
def __init__(self, *, stall_timeout: Optional[float] = None) -> None:
|
||||
self._event = asyncio.Event()
|
||||
self._event.set()
|
||||
self._reconnect_requested = False
|
||||
self.last_progress_timestamp: Optional[float] = None
|
||||
self.stall_timeout: float = float(stall_timeout) if stall_timeout is not None else 120.0
|
||||
|
||||
def is_set(self) -> bool:
|
||||
return self._event.is_set()
|
||||
|
||||
def is_paused(self) -> bool:
|
||||
return not self._event.is_set()
|
||||
|
||||
def set(self) -> None:
|
||||
self._event.set()
|
||||
|
||||
def clear(self) -> None:
|
||||
self._event.clear()
|
||||
|
||||
async def wait(self) -> None:
|
||||
await self._event.wait()
|
||||
|
||||
def pause(self) -> None:
|
||||
self.clear()
|
||||
|
||||
def resume(self, *, force_reconnect: bool = False) -> None:
|
||||
if force_reconnect:
|
||||
self._reconnect_requested = True
|
||||
self.set()
|
||||
|
||||
def request_reconnect(self) -> None:
|
||||
self._reconnect_requested = True
|
||||
self.set()
|
||||
|
||||
def has_reconnect_request(self) -> bool:
|
||||
return self._reconnect_requested
|
||||
|
||||
def consume_reconnect_request(self) -> bool:
|
||||
reconnect = self._reconnect_requested
|
||||
self._reconnect_requested = False
|
||||
return reconnect
|
||||
|
||||
def mark_progress(self, timestamp: Optional[float] = None) -> None:
|
||||
self.last_progress_timestamp = timestamp or datetime.now().timestamp()
|
||||
self._reconnect_requested = False
|
||||
|
||||
def time_since_last_progress(self, *, now: Optional[float] = None) -> Optional[float]:
|
||||
if self.last_progress_timestamp is None:
|
||||
return None
|
||||
reference = now if now is not None else datetime.now().timestamp()
|
||||
return max(0.0, reference - self.last_progress_timestamp)
|
||||
|
||||
def update_stall_timeout(self, stall_timeout: float) -> None:
|
||||
self.stall_timeout = float(stall_timeout)
|
||||
|
||||
|
||||
class DownloadRestartRequested(Exception):
|
||||
"""Raised when a caller explicitly requests a fresh HTTP stream."""
|
||||
|
||||
|
||||
class DownloadStalledError(Exception):
|
||||
"""Raised when download progress stalls beyond the configured timeout."""
|
||||
|
||||
|
||||
class Downloader:
|
||||
"""Unified downloader for all HTTP/HTTPS downloads in the application."""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance of Downloader"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the downloader with optimal settings"""
|
||||
# Check if already initialized for singleton pattern
|
||||
if hasattr(self, '_initialized'):
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
# Session management
|
||||
self._session = None
|
||||
self._session_created_at = None
|
||||
self._proxy_url = None # Store proxy URL for current session
|
||||
|
||||
# Configuration
|
||||
self.chunk_size = 4 * 1024 * 1024 # 4MB chunks for better throughput
|
||||
self.max_retries = 5
|
||||
self.base_delay = 2.0 # Base delay for exponential backoff
|
||||
self.session_timeout = 300 # 5 minutes
|
||||
self.stall_timeout = self._resolve_stall_timeout()
|
||||
|
||||
# Default headers
|
||||
self.default_headers = {
|
||||
'User-Agent': 'ComfyUI-LoRA-Manager/1.0',
|
||||
# Explicitly request uncompressed payloads so aiohttp doesn't need optional
|
||||
# decoders (e.g. zstandard) that may be missing in runtime environments.
|
||||
'Accept-Encoding': 'identity',
|
||||
}
|
||||
|
||||
@property
|
||||
async def session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create the global aiohttp session with optimized settings"""
|
||||
if self._session is None or self._should_refresh_session():
|
||||
await self._create_session()
|
||||
return self._session
|
||||
|
||||
@property
|
||||
def proxy_url(self) -> Optional[str]:
|
||||
"""Get the current proxy URL (initialize if needed)"""
|
||||
if not hasattr(self, '_proxy_url'):
|
||||
self._proxy_url = None
|
||||
return self._proxy_url
|
||||
|
||||
def _resolve_stall_timeout(self) -> float:
|
||||
"""Determine the stall timeout from settings or environment."""
|
||||
default_timeout = 120.0
|
||||
settings_timeout = None
|
||||
|
||||
try:
|
||||
settings_manager = get_settings_manager()
|
||||
settings_timeout = settings_manager.get('download_stall_timeout_seconds')
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.debug("Failed to read stall timeout from settings: %s", exc)
|
||||
|
||||
raw_value = (
|
||||
settings_timeout
|
||||
if settings_timeout not in (None, "")
|
||||
else os.environ.get('COMFYUI_DOWNLOAD_STALL_TIMEOUT')
|
||||
)
|
||||
|
||||
try:
|
||||
timeout_value = float(raw_value)
|
||||
except (TypeError, ValueError):
|
||||
timeout_value = default_timeout
|
||||
|
||||
return max(30.0, timeout_value)
|
||||
|
||||
def _should_refresh_session(self) -> bool:
|
||||
"""Check if session should be refreshed"""
|
||||
if self._session is None:
|
||||
return True
|
||||
|
||||
if not hasattr(self, '_session_created_at') or self._session_created_at is None:
|
||||
return True
|
||||
|
||||
# Refresh if session is older than timeout
|
||||
if (datetime.now() - self._session_created_at).total_seconds() > self.session_timeout:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _create_session(self):
|
||||
"""Create a new aiohttp session with optimized settings"""
|
||||
# Close existing session if any
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
|
||||
# Check for app-level proxy settings
|
||||
proxy_url = None
|
||||
settings_manager = get_settings_manager()
|
||||
if settings_manager.get('proxy_enabled', False):
|
||||
proxy_host = settings_manager.get('proxy_host', '').strip()
|
||||
proxy_port = settings_manager.get('proxy_port', '').strip()
|
||||
proxy_type = settings_manager.get('proxy_type', 'http').lower()
|
||||
proxy_username = settings_manager.get('proxy_username', '').strip()
|
||||
proxy_password = settings_manager.get('proxy_password', '').strip()
|
||||
|
||||
if proxy_host and proxy_port:
|
||||
# Build proxy URL
|
||||
if proxy_username and proxy_password:
|
||||
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
|
||||
else:
|
||||
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
||||
|
||||
logger.debug(f"Using app-level proxy: {proxy_type}://{proxy_host}:{proxy_port}")
|
||||
logger.debug("Proxy mode: app-level proxy is active.")
|
||||
else:
|
||||
logger.debug("Proxy mode: system-level proxy (trust_env) will be used if configured in environment.")
|
||||
# Optimize TCP connection parameters
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=True,
|
||||
limit=8, # Concurrent connections
|
||||
ttl_dns_cache=300, # DNS cache timeout
|
||||
force_close=False, # Keep connections for reuse
|
||||
enable_cleanup_closed=True
|
||||
)
|
||||
|
||||
# Configure timeout parameters
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=None, # No total timeout for large downloads
|
||||
connect=60, # Connection timeout
|
||||
sock_read=300 # 5 minute socket read timeout
|
||||
)
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
trust_env=proxy_url is None, # Only use system proxy if no app-level proxy is set
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# Store proxy URL for use in requests
|
||||
self._proxy_url = proxy_url
|
||||
self._session_created_at = datetime.now()
|
||||
|
||||
logger.debug("Created new HTTP session with proxy settings. App-level proxy: %s, System-level proxy (trust_env): %s", bool(proxy_url), proxy_url is None)
|
||||
|
||||
def _get_auth_headers(self, use_auth: bool = False) -> Dict[str, str]:
|
||||
"""Get headers with optional authentication"""
|
||||
headers = self.default_headers.copy()
|
||||
|
||||
if use_auth:
|
||||
# Add CivitAI API key if available
|
||||
settings_manager = get_settings_manager()
|
||||
api_key = settings_manager.get('civitai_api_key')
|
||||
if api_key:
|
||||
headers['Authorization'] = f'Bearer {api_key}'
|
||||
headers['Content-Type'] = 'application/json'
|
||||
|
||||
return headers
|
||||
|
||||
async def download_file(
|
||||
self,
|
||||
url: str,
|
||||
save_path: str,
|
||||
progress_callback: Optional[Callable[..., Awaitable[None]]] = None,
|
||||
use_auth: bool = False,
|
||||
custom_headers: Optional[Dict[str, str]] = None,
|
||||
allow_resume: bool = True,
|
||||
pause_event: Optional[DownloadStreamControl] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Download a file with resumable downloads and retry mechanism
|
||||
|
||||
Args:
|
||||
url: Download URL
|
||||
save_path: Full path where the file should be saved
|
||||
progress_callback: Optional callback for progress updates (0-100)
|
||||
use_auth: Whether to include authentication headers (e.g., CivitAI API key)
|
||||
custom_headers: Additional headers to include in request
|
||||
allow_resume: Whether to support resumable downloads
|
||||
pause_event: Optional stream control used to pause/resume and request reconnects
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (success, save_path or error message)
|
||||
"""
|
||||
retry_count = 0
|
||||
part_path = save_path + '.part' if allow_resume else save_path
|
||||
|
||||
# Prepare headers
|
||||
headers = self._get_auth_headers(use_auth)
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
# Get existing file size for resume
|
||||
resume_offset = 0
|
||||
if allow_resume and os.path.exists(part_path):
|
||||
resume_offset = os.path.getsize(part_path)
|
||||
logger.info(f"Resuming download from offset {resume_offset} bytes")
|
||||
|
||||
total_size = 0
|
||||
|
||||
while retry_count <= self.max_retries:
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
if self.proxy_url:
|
||||
logger.debug(f"[download_file] Using app-level proxy: {self.proxy_url}")
|
||||
else:
|
||||
logger.debug("[download_file] Using system-level proxy (trust_env) if configured.")
|
||||
|
||||
# Add Range header for resume if we have partial data
|
||||
request_headers = headers.copy()
|
||||
if allow_resume and resume_offset > 0:
|
||||
request_headers['Range'] = f'bytes={resume_offset}-'
|
||||
|
||||
# Disable compression for better chunked downloads
|
||||
request_headers['Accept-Encoding'] = 'identity'
|
||||
|
||||
logger.debug(f"Download attempt {retry_count + 1}/{self.max_retries + 1} from: {url}")
|
||||
if resume_offset > 0:
|
||||
logger.debug(f"Requesting range from byte {resume_offset}")
|
||||
|
||||
async with session.get(url, headers=request_headers, allow_redirects=True, proxy=self.proxy_url) as response:
|
||||
# Handle different response codes
|
||||
if response.status == 200:
|
||||
# Full content response
|
||||
if resume_offset > 0:
|
||||
# Server doesn't support ranges, restart from beginning
|
||||
logger.warning("Server doesn't support range requests, restarting download")
|
||||
resume_offset = 0
|
||||
if os.path.exists(part_path):
|
||||
os.remove(part_path)
|
||||
elif response.status == 206:
|
||||
# Partial content response (resume successful)
|
||||
content_range = response.headers.get('Content-Range')
|
||||
if content_range:
|
||||
# Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048")
|
||||
range_parts = content_range.split('/')
|
||||
if len(range_parts) == 2:
|
||||
total_size = int(range_parts[1])
|
||||
logger.info(f"Successfully resumed download from byte {resume_offset}")
|
||||
elif response.status == 416:
|
||||
# Range not satisfiable - file might be complete or corrupted
|
||||
if allow_resume and os.path.exists(part_path):
|
||||
part_size = os.path.getsize(part_path)
|
||||
logger.warning(f"Range not satisfiable. Part file size: {part_size}")
|
||||
# Try to get actual file size
|
||||
head_response = await session.head(url, headers=headers, proxy=self.proxy_url)
|
||||
if head_response.status == 200:
|
||||
actual_size = int(head_response.headers.get('content-length', 0))
|
||||
if part_size == actual_size:
|
||||
# File is complete, just rename it
|
||||
if allow_resume:
|
||||
os.rename(part_path, save_path)
|
||||
if progress_callback:
|
||||
await self._dispatch_progress_callback(
|
||||
progress_callback,
|
||||
DownloadProgress(
|
||||
percent_complete=100.0,
|
||||
bytes_downloaded=part_size,
|
||||
total_bytes=actual_size,
|
||||
bytes_per_second=0.0,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
),
|
||||
)
|
||||
return True, save_path
|
||||
# Remove corrupted part file and restart
|
||||
os.remove(part_path)
|
||||
resume_offset = 0
|
||||
continue
|
||||
elif response.status == 401:
|
||||
logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
|
||||
return False, "Invalid or missing API key, or early access restriction."
|
||||
elif response.status == 403:
|
||||
logger.warning(f"Forbidden access to resource: {url} (Status 403)")
|
||||
return False, "Access forbidden: You don't have permission to download this file."
|
||||
elif response.status == 404:
|
||||
logger.warning(f"Resource not found: {url} (Status 404)")
|
||||
return False, "File not found - the download link may be invalid or expired."
|
||||
else:
|
||||
logger.error(f"Download failed for {url} with status {response.status}")
|
||||
return False, f"Download failed with status {response.status}"
|
||||
|
||||
# Get total file size for progress calculation (if not set from Content-Range)
|
||||
if total_size == 0:
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
if response.status == 206:
|
||||
# For partial content, add the offset to get total file size
|
||||
total_size += resume_offset
|
||||
|
||||
current_size = resume_offset
|
||||
last_progress_report_time = datetime.now()
|
||||
progress_samples: deque[tuple[datetime, int]] = deque()
|
||||
progress_samples.append((last_progress_report_time, current_size))
|
||||
|
||||
# Ensure directory exists
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
|
||||
# Stream download to file with progress updates
|
||||
loop = asyncio.get_running_loop()
|
||||
mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb'
|
||||
control = pause_event
|
||||
|
||||
if control is not None:
|
||||
control.update_stall_timeout(self.stall_timeout)
|
||||
|
||||
with open(part_path, mode) as f:
|
||||
while True:
|
||||
active_stall_timeout = control.stall_timeout if control else self.stall_timeout
|
||||
|
||||
if control is not None:
|
||||
if control.is_paused():
|
||||
await control.wait()
|
||||
resume_time = datetime.now()
|
||||
last_progress_report_time = resume_time
|
||||
if control.consume_reconnect_request():
|
||||
raise DownloadRestartRequested(
|
||||
"Reconnect requested after resume"
|
||||
)
|
||||
elif control.consume_reconnect_request():
|
||||
raise DownloadRestartRequested("Reconnect requested")
|
||||
|
||||
try:
|
||||
chunk = await asyncio.wait_for(
|
||||
response.content.read(self.chunk_size),
|
||||
timeout=active_stall_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError as exc:
|
||||
logger.warning(
|
||||
"Download stalled for %.1f seconds without progress from %s",
|
||||
active_stall_timeout,
|
||||
url,
|
||||
)
|
||||
raise DownloadStalledError(
|
||||
f"No data received for {active_stall_timeout:.1f} seconds"
|
||||
) from exc
|
||||
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
# Run blocking file write in executor
|
||||
await loop.run_in_executor(None, f.write, chunk)
|
||||
current_size += len(chunk)
|
||||
|
||||
now = datetime.now()
|
||||
if control is not None:
|
||||
control.mark_progress(timestamp=now.timestamp())
|
||||
|
||||
# Limit progress update frequency to reduce overhead
|
||||
time_diff = (now - last_progress_report_time).total_seconds()
|
||||
|
||||
if progress_callback and time_diff >= 1.0:
|
||||
progress_samples.append((now, current_size))
|
||||
cutoff = now - timedelta(seconds=5)
|
||||
while progress_samples and progress_samples[0][0] < cutoff:
|
||||
progress_samples.popleft()
|
||||
|
||||
percent = (current_size / total_size) * 100 if total_size else 0.0
|
||||
bytes_per_second = 0.0
|
||||
if len(progress_samples) >= 2:
|
||||
first_time, first_bytes = progress_samples[0]
|
||||
last_time, last_bytes = progress_samples[-1]
|
||||
elapsed = (last_time - first_time).total_seconds()
|
||||
if elapsed > 0:
|
||||
bytes_per_second = (last_bytes - first_bytes) / elapsed
|
||||
|
||||
progress_snapshot = DownloadProgress(
|
||||
percent_complete=percent,
|
||||
bytes_downloaded=current_size,
|
||||
total_bytes=total_size or None,
|
||||
bytes_per_second=bytes_per_second,
|
||||
timestamp=now.timestamp(),
|
||||
)
|
||||
|
||||
await self._dispatch_progress_callback(progress_callback, progress_snapshot)
|
||||
last_progress_report_time = now
|
||||
|
||||
# Download completed successfully
|
||||
# Verify file size integrity before finalizing
|
||||
final_size = os.path.getsize(part_path) if os.path.exists(part_path) else 0
|
||||
expected_size = total_size if total_size > 0 else None
|
||||
|
||||
integrity_error: Optional[str] = None
|
||||
if final_size <= 0:
|
||||
integrity_error = "Downloaded file is empty"
|
||||
elif expected_size is not None and final_size != expected_size:
|
||||
integrity_error = (
|
||||
f"File size mismatch. Expected: {expected_size}, Got: {final_size}"
|
||||
)
|
||||
|
||||
if integrity_error is not None:
|
||||
logger.error(
|
||||
"Download integrity check failed for %s: %s",
|
||||
save_path,
|
||||
integrity_error,
|
||||
)
|
||||
|
||||
# Remove the corrupted payload so future attempts start fresh
|
||||
if os.path.exists(part_path):
|
||||
try:
|
||||
os.remove(part_path)
|
||||
except OSError as remove_error:
|
||||
logger.warning(
|
||||
"Failed to delete corrupted download %s: %s",
|
||||
part_path,
|
||||
remove_error,
|
||||
)
|
||||
if part_path != save_path and os.path.exists(save_path):
|
||||
try:
|
||||
os.remove(save_path)
|
||||
except OSError as remove_error:
|
||||
logger.warning(
|
||||
"Failed to delete target file %s after integrity error: %s",
|
||||
save_path,
|
||||
remove_error,
|
||||
)
|
||||
|
||||
retry_count += 1
|
||||
if retry_count <= self.max_retries:
|
||||
delay = self.base_delay * (2 ** (retry_count - 1))
|
||||
logger.info(
|
||||
"Retrying download in %s seconds due to integrity check failure",
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
resume_offset = 0
|
||||
total_size = 0
|
||||
await self._create_session()
|
||||
continue
|
||||
|
||||
return False, integrity_error
|
||||
|
||||
# Atomically rename .part to final file (only if using resume)
|
||||
if allow_resume and part_path != save_path:
|
||||
max_rename_attempts = 5
|
||||
rename_attempt = 0
|
||||
rename_success = False
|
||||
|
||||
while rename_attempt < max_rename_attempts and not rename_success:
|
||||
try:
|
||||
# If the destination file exists, remove it first (Windows safe)
|
||||
if os.path.exists(save_path):
|
||||
os.remove(save_path)
|
||||
|
||||
os.rename(part_path, save_path)
|
||||
rename_success = True
|
||||
except PermissionError as e:
|
||||
rename_attempt += 1
|
||||
if rename_attempt < max_rename_attempts:
|
||||
logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})")
|
||||
await asyncio.sleep(2)
|
||||
else:
|
||||
logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}")
|
||||
return False, f"Failed to finalize download: {str(e)}"
|
||||
|
||||
final_size = os.path.getsize(save_path)
|
||||
|
||||
# Ensure 100% progress is reported
|
||||
if progress_callback:
|
||||
final_snapshot = DownloadProgress(
|
||||
percent_complete=100.0,
|
||||
bytes_downloaded=final_size,
|
||||
total_bytes=total_size or final_size,
|
||||
bytes_per_second=0.0,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
)
|
||||
await self._dispatch_progress_callback(progress_callback, final_snapshot)
|
||||
|
||||
|
||||
return True, save_path
|
||||
|
||||
except (
|
||||
aiohttp.ClientError,
|
||||
aiohttp.ClientPayloadError,
|
||||
aiohttp.ServerDisconnectedError,
|
||||
asyncio.TimeoutError,
|
||||
DownloadStalledError,
|
||||
DownloadRestartRequested,
|
||||
) as e:
|
||||
retry_count += 1
|
||||
logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}")
|
||||
|
||||
if retry_count <= self.max_retries:
|
||||
# Calculate delay with exponential backoff
|
||||
delay = self.base_delay * (2 ** (retry_count - 1))
|
||||
logger.info(f"Retrying in {delay} seconds...")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# Update resume offset for next attempt
|
||||
if allow_resume and os.path.exists(part_path):
|
||||
resume_offset = os.path.getsize(part_path)
|
||||
logger.info(f"Will resume from byte {resume_offset}")
|
||||
|
||||
# Refresh session to get new connection
|
||||
await self._create_session()
|
||||
continue
|
||||
else:
|
||||
logger.error(f"Max retries exceeded for download: {e}")
|
||||
return False, f"Network error after {self.max_retries + 1} attempts: {str(e)}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected download error: {e}")
|
||||
return False, str(e)
|
||||
|
||||
return False, f"Download failed after {self.max_retries + 1} attempts"
|
||||
|
||||
async def _dispatch_progress_callback(
|
||||
self,
|
||||
progress_callback: Callable[..., Awaitable[None]],
|
||||
snapshot: DownloadProgress,
|
||||
) -> None:
|
||||
"""Invoke a progress callback while preserving backward compatibility."""
|
||||
|
||||
try:
|
||||
result = progress_callback(snapshot, snapshot)
|
||||
except TypeError:
|
||||
result = progress_callback(snapshot.percent_complete)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
elif hasattr(result, "__await__"):
|
||||
await result
|
||||
|
||||
async def download_to_memory(
|
||||
self,
|
||||
url: str,
|
||||
use_auth: bool = False,
|
||||
custom_headers: Optional[Dict[str, str]] = None,
|
||||
return_headers: bool = False
|
||||
) -> Tuple[bool, Union[bytes, str], Optional[Dict]]:
|
||||
"""
|
||||
Download a file to memory (for small files like preview images)
|
||||
|
||||
Args:
|
||||
url: Download URL
|
||||
use_auth: Whether to include authentication headers
|
||||
custom_headers: Additional headers to include in request
|
||||
return_headers: Whether to return response headers along with content
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[bytes, str], Optional[Dict]]: (success, content or error message, response headers if requested)
|
||||
"""
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
if self.proxy_url:
|
||||
logger.debug(f"[download_to_memory] Using app-level proxy: {self.proxy_url}")
|
||||
else:
|
||||
logger.debug("[download_to_memory] Using system-level proxy (trust_env) if configured.")
|
||||
|
||||
# Prepare headers
|
||||
headers = self._get_auth_headers(use_auth)
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
async with session.get(url, headers=headers, proxy=self.proxy_url) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
if return_headers:
|
||||
return True, content, dict(response.headers)
|
||||
else:
|
||||
return True, content, None
|
||||
elif response.status == 401:
|
||||
error_msg = "Unauthorized access - invalid or missing API key"
|
||||
return False, error_msg, None
|
||||
elif response.status == 403:
|
||||
error_msg = "Access forbidden"
|
||||
return False, error_msg, None
|
||||
elif response.status == 404:
|
||||
error_msg = "File not found"
|
||||
return False, error_msg, None
|
||||
else:
|
||||
error_msg = f"Download failed with status {response.status}"
|
||||
return False, error_msg, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading to memory from {url}: {e}")
|
||||
return False, str(e), None
|
||||
|
||||
async def get_response_headers(
|
||||
self,
|
||||
url: str,
|
||||
use_auth: bool = False,
|
||||
custom_headers: Optional[Dict[str, str]] = None
|
||||
) -> Tuple[bool, Union[Dict, str]]:
|
||||
"""
|
||||
Get response headers without downloading the full content
|
||||
|
||||
Args:
|
||||
url: URL to check
|
||||
use_auth: Whether to include authentication headers
|
||||
custom_headers: Additional headers to include in request
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict, str]]: (success, headers dict or error message)
|
||||
"""
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
if self.proxy_url:
|
||||
logger.debug(f"[get_response_headers] Using app-level proxy: {self.proxy_url}")
|
||||
else:
|
||||
logger.debug("[get_response_headers] Using system-level proxy (trust_env) if configured.")
|
||||
|
||||
# Prepare headers
|
||||
headers = self._get_auth_headers(use_auth)
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
async with session.head(url, headers=headers, proxy=self.proxy_url) as response:
|
||||
if response.status == 200:
|
||||
return True, dict(response.headers)
|
||||
else:
|
||||
return False, f"Head request failed with status {response.status}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting headers from {url}: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def make_request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
use_auth: bool = False,
|
||||
custom_headers: Optional[Dict[str, str]] = None,
|
||||
**kwargs
|
||||
) -> Tuple[bool, Union[Dict, str]]:
|
||||
"""
|
||||
Make a generic HTTP request and return JSON response
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
url: Request URL
|
||||
use_auth: Whether to include authentication headers
|
||||
custom_headers: Additional headers to include in request
|
||||
**kwargs: Additional arguments for aiohttp request
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict, str]]: (success, response data or error message)
|
||||
"""
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
if self.proxy_url:
|
||||
logger.debug(f"[make_request] Using app-level proxy: {self.proxy_url}")
|
||||
else:
|
||||
logger.debug("[make_request] Using system-level proxy (trust_env) if configured.")
|
||||
|
||||
# Prepare headers
|
||||
headers = self._get_auth_headers(use_auth)
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
# Add proxy to kwargs if not already present
|
||||
if 'proxy' not in kwargs:
|
||||
kwargs['proxy'] = self.proxy_url
|
||||
|
||||
async with session.request(method, url, headers=headers, **kwargs) as response:
|
||||
if response.status == 200:
|
||||
# Try to parse as JSON, fall back to text
|
||||
try:
|
||||
data = await response.json()
|
||||
return True, data
|
||||
except:
|
||||
text = await response.text()
|
||||
return True, text
|
||||
elif response.status == 401:
|
||||
return False, "Unauthorized access - invalid or missing API key"
|
||||
elif response.status == 403:
|
||||
return False, "Access forbidden"
|
||||
elif response.status == 404:
|
||||
return False, "Resource not found"
|
||||
elif response.status == 429:
|
||||
retry_after = self._extract_retry_after(response.headers)
|
||||
error_msg = "Request rate limited"
|
||||
logger.warning(
|
||||
"Rate limit encountered for %s %s; retry_after=%s",
|
||||
method,
|
||||
url,
|
||||
retry_after,
|
||||
)
|
||||
return False, RateLimitError(
|
||||
error_msg,
|
||||
retry_after=retry_after,
|
||||
)
|
||||
else:
|
||||
return False, f"Request failed with status {response.status}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error making {method} request to {url}: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP session"""
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
self._session_created_at = None
|
||||
self._proxy_url = None
|
||||
logger.debug("Closed HTTP session")
|
||||
|
||||
async def refresh_session(self):
|
||||
"""Force refresh the HTTP session (useful when proxy settings change)"""
|
||||
await self._create_session()
|
||||
logger.info("HTTP session refreshed due to settings change")
|
||||
|
||||
@staticmethod
|
||||
def _extract_retry_after(headers) -> Optional[float]:
|
||||
"""Parse the Retry-After header into seconds."""
|
||||
if not headers:
|
||||
return None
|
||||
|
||||
header_value = headers.get("Retry-After")
|
||||
if not header_value:
|
||||
return None
|
||||
|
||||
header_value = header_value.strip()
|
||||
if not header_value:
|
||||
return None
|
||||
|
||||
if header_value.isdigit():
|
||||
try:
|
||||
seconds = float(header_value)
|
||||
except ValueError:
|
||||
return None
|
||||
return max(0.0, seconds)
|
||||
|
||||
try:
|
||||
retry_datetime = parsedate_to_datetime(header_value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
if retry_datetime.tzinfo is None:
|
||||
return None
|
||||
|
||||
delta = retry_datetime - datetime.now(tz=retry_datetime.tzinfo)
|
||||
return max(0.0, delta.total_seconds())
|
||||
|
||||
|
||||
# Global instance accessor
|
||||
async def get_downloader() -> Downloader:
|
||||
"""Get the global downloader instance"""
|
||||
return await Downloader.get_instance()
|
||||
26
py/services/embedding_scanner.py
Normal file
26
py/services/embedding_scanner.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from ..utils.models import EmbeddingMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EmbeddingScanner(ModelScanner):
|
||||
"""Service for scanning and managing embedding files"""
|
||||
|
||||
def __init__(self):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
||||
super().__init__(
|
||||
model_type="embedding",
|
||||
model_class=EmbeddingMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get embedding root directories"""
|
||||
return config.embeddings_roots
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user