mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 05:32:12 -03:00
Compare commits
1132 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
12c88835f2 | ||
|
|
6f4453aaf3 | ||
|
|
4b4b8fe3c1 | ||
|
|
49e7c2e9f5 | ||
|
|
4653c273e3 | ||
|
|
ae145de2f2 | ||
|
|
dde7cf71c6 | ||
|
|
219cd242db | ||
|
|
e5b712c082 | ||
|
|
4d2c60d59b | ||
|
|
1d2c1b114b | ||
|
|
2bde936d05 | ||
|
|
cd3e32bf4b | ||
|
|
454536d631 | ||
|
|
656f1755fd | ||
|
|
8aa76ce5c1 | ||
|
|
49fa37f00d | ||
|
|
9f83548cf3 | ||
|
|
6054d95e85 | ||
|
|
8c9bb35824 | ||
|
|
3eacf9558a | ||
|
|
fee37172b4 | ||
|
|
e128c80eb1 | ||
|
|
5cc735ed57 | ||
|
|
43fcce6361 | ||
|
|
49b7126278 | ||
|
|
679cfb5c69 | ||
|
|
50616bc680 | ||
|
|
aaad270822 | ||
|
|
bd10280736 | ||
|
|
d477050239 | ||
|
|
85f79cd8d1 | ||
|
|
613cd81152 | ||
|
|
e0aba6c49a | ||
|
|
d78bcf2494 | ||
|
|
f7cffd2eba | ||
|
|
0d0b91aa80 | ||
|
|
42872e6d2d | ||
|
|
b91f06405d | ||
|
|
dac4c688d6 | ||
|
|
097a68ad18 | ||
|
|
4a98710db0 | ||
|
|
d033a374dd | ||
|
|
6aa23fe36a | ||
|
|
3220cfb79c | ||
|
|
b92e7aa446 | ||
|
|
c3b9c73541 | ||
|
|
81c6672880 | ||
|
|
08baf884d3 | ||
|
|
1c4096f3d5 | ||
|
|
66a3f3f59a | ||
|
|
624df1328b | ||
|
|
c063854b51 | ||
|
|
8cf99dd928 | ||
|
|
c07e885725 | ||
|
|
21772feadd | ||
|
|
2d00cfdd31 | ||
|
|
49e03d658b | ||
|
|
fec85bcc08 | ||
|
|
0e93a6bcb0 | ||
|
|
7e20f738fb | ||
|
|
24090e6077 | ||
|
|
1022b07f64 | ||
|
|
4faf912c6f | ||
|
|
56e4b24b07 | ||
|
|
12295d2fdc | ||
|
|
6261f7d18d | ||
|
|
9e1a2e3bb7 | ||
|
|
40cbb2155c | ||
|
|
a8d7070832 | ||
|
|
ab7266f3a4 | ||
|
|
3053b13fcb | ||
|
|
f3544b3471 | ||
|
|
1610048974 | ||
|
|
fc6f1bf95b | ||
|
|
67b274c1b2 | ||
|
|
fb0d6b5641 | ||
|
|
d30fbeb286 | ||
|
|
46e430ebbb | ||
|
|
bc4cd45fcb | ||
|
|
bdc86ddf15 | ||
|
|
ded17c1479 | ||
|
|
933e2fc01d | ||
|
|
1cddeee264 | ||
|
|
183c000080 | ||
|
|
adf7b6d4b2 | ||
|
|
0566d50346 | ||
|
|
4275dc3003 | ||
|
|
30956aeefc | ||
|
|
64e1dd3dd6 | ||
|
|
0dc4b6f728 | ||
|
|
86074c87d7 | ||
|
|
6f9245df01 | ||
|
|
4540e47055 | ||
|
|
4bb8981e78 | ||
|
|
c49be91aa0 | ||
|
|
2b847039d4 | ||
|
|
1147725fd7 | ||
|
|
26891e12a4 | ||
|
|
2f7e44a76f | ||
|
|
9366d3d2d0 | ||
|
|
6b606a5cc8 | ||
|
|
e5339c178a | ||
|
|
1a76f74482 | ||
|
|
13f13eb095 | ||
|
|
125fdecd61 | ||
|
|
d05076d258 | ||
|
|
00b77581fc | ||
|
|
897787d17c | ||
|
|
d5a280cf2b | ||
|
|
a0c2d9b5ad | ||
|
|
e713bd1ca2 | ||
|
|
beb8ff1dd1 | ||
|
|
6a8f0867d9 | ||
|
|
51ad1c9a33 | ||
|
|
34872eb612 | ||
|
|
8b4e3128ff | ||
|
|
c66cbc800b | ||
|
|
21941521a0 | ||
|
|
0d33884052 | ||
|
|
415df49377 | ||
|
|
f5f45002c7 | ||
|
|
1edf7126bb | ||
|
|
a1a55a1002 | ||
|
|
45f5cb46bd | ||
|
|
1b5e608a27 | ||
|
|
a7df8ae15c | ||
|
|
47ce0d0fe2 | ||
|
|
b220e288d0 | ||
|
|
1fc8b45b68 | ||
|
|
62f06302f0 | ||
|
|
3e5cb223f3 | ||
|
|
4ee5b7481c | ||
|
|
e104b78c01 | ||
|
|
ba1ac58721 | ||
|
|
a4fbeb6295 | ||
|
|
68f8871403 | ||
|
|
6fd74952b7 | ||
|
|
1ea468cfc4 | ||
|
|
14721c265f | ||
|
|
821827a375 | ||
|
|
9ba3e2c204 | ||
|
|
d287883671 | ||
|
|
ead34818db | ||
|
|
a060010b96 | ||
|
|
76a92ac847 | ||
|
|
74bc490383 | ||
|
|
510d476323 | ||
|
|
1e7257fd53 | ||
|
|
4ff1f51b1c | ||
|
|
74507cef05 | ||
|
|
c23ab04d90 | ||
|
|
d50dde6cf6 | ||
|
|
fcb1fb39be | ||
|
|
b0ef74f802 | ||
|
|
f332aef41d | ||
|
|
1f91a3da8e | ||
|
|
16840c321d | ||
|
|
c109e392ad | ||
|
|
5e69671366 | ||
|
|
52d23d9b75 | ||
|
|
4c4e6d7a7b | ||
|
|
03b6e78705 | ||
|
|
24c01141d7 | ||
|
|
6dc2811af4 | ||
|
|
e6425dce32 | ||
|
|
95e2ff5f1e | ||
|
|
92ac487128 | ||
|
|
3250fa89cb | ||
|
|
7475de366b | ||
|
|
affb507b37 | ||
|
|
3320b80150 | ||
|
|
fb2b69b787 | ||
|
|
29a05f6533 | ||
|
|
9fa3fac973 | ||
|
|
904b0d104a | ||
|
|
1d31dae110 | ||
|
|
476ecb7423 | ||
|
|
4eb67cf6da | ||
|
|
a5a9f7ed83 | ||
|
|
c0b029e228 | ||
|
|
9bebcc9a4b | ||
|
|
ac7d23011c | ||
|
|
491e09b7b5 | ||
|
|
192bc237bf | ||
|
|
f041f4a114 | ||
|
|
2546580377 | ||
|
|
8fbf2ab56d | ||
|
|
ea727aad2e | ||
|
|
5520aecbba | ||
|
|
6b738a4769 | ||
|
|
903a8050b3 | ||
|
|
31b032429d | ||
|
|
2bcf341f04 | ||
|
|
ca6f45b359 | ||
|
|
2a67cec16b | ||
|
|
1800afe31b | ||
|
|
8c6311355d | ||
|
|
91801dff85 | ||
|
|
be594133f0 | ||
|
|
8a538d117e | ||
|
|
8d9118cbee | ||
|
|
b67464ea13 | ||
|
|
33334da0bb | ||
|
|
40ce2baa7b | ||
|
|
1134466cc0 | ||
|
|
92341111ad | ||
|
|
4956d6781f | ||
|
|
63562240c4 | ||
|
|
84d801cf14 | ||
|
|
b56fe4ca68 | ||
|
|
6c83c65e02 | ||
|
|
a83f020fcc | ||
|
|
7f9a3bf272 | ||
|
|
f80e266d02 | ||
|
|
7bef562541 | ||
|
|
b2428f607c | ||
|
|
8303196b57 | ||
|
|
987b8c8742 | ||
|
|
e60a579b85 | ||
|
|
be8edafed0 | ||
|
|
a258a18fa4 | ||
|
|
59010ca431 | ||
|
|
75f3764e6c | ||
|
|
867ffd1163 | ||
|
|
6acccbbb94 | ||
|
|
b2c4efab45 | ||
|
|
408a435b71 | ||
|
|
36d3cd93d5 | ||
|
|
b36fea002e | ||
|
|
52acbd954a | ||
|
|
f6709a55c3 | ||
|
|
7b374d747b | ||
|
|
fd480a9360 | ||
|
|
ec8b228867 | ||
|
|
401200050b | ||
|
|
29160bd6e5 | ||
|
|
3c9e402bc0 | ||
|
|
ff4d0f0208 | ||
|
|
f82908221c | ||
|
|
4246908f2e | ||
|
|
f64597afd2 | ||
|
|
975ff2672d | ||
|
|
e90ba31784 | ||
|
|
a4074c93bc | ||
|
|
7a8b7598c7 | ||
|
|
cd0d832f14 | ||
|
|
5b0becaaf2 | ||
|
|
9817bac2fe | ||
|
|
f6bd48cfcd | ||
|
|
01843b8f2b | ||
|
|
94ed81de5e | ||
|
|
0700b8f399 | ||
|
|
d62cff9841 | ||
|
|
083f4805b2 | ||
|
|
8e5bfd379e | ||
|
|
2366f143d8 | ||
|
|
e997f5bc1b | ||
|
|
842beec7cc | ||
|
|
d2268fc9e0 | ||
|
|
a98e26139f | ||
|
|
522a3ea88b | ||
|
|
d7949fbc30 | ||
|
|
6df083a1d5 | ||
|
|
4dc80e7f6e | ||
|
|
c2a8508513 | ||
|
|
159193ef43 | ||
|
|
1f37ffb105 | ||
|
|
919fed05c5 | ||
|
|
1814f83bee | ||
|
|
1823840456 | ||
|
|
623c28bfc3 | ||
|
|
3079131337 | ||
|
|
a34ade0120 | ||
|
|
e9ada70088 | ||
|
|
597cc48248 | ||
|
|
ec3f857ef1 | ||
|
|
383b4de539 | ||
|
|
1bf9326604 | ||
|
|
d9f5459d46 | ||
|
|
e45a1b1e19 | ||
|
|
331ad8f644 | ||
|
|
52fa88b04c | ||
|
|
8895a64d24 | ||
|
|
fdec535559 | ||
|
|
6c5559ae2d | ||
|
|
9f54622b17 | ||
|
|
03b6f4b378 | ||
|
|
af4cbe2332 | ||
|
|
141f72963a | ||
|
|
3d3c66e12f | ||
|
|
ee84571bdb | ||
|
|
6500936aad | ||
|
|
32d2b6c013 | ||
|
|
05df40977d | ||
|
|
5d7a1dcde5 | ||
|
|
9c45d9db6c | ||
|
|
ca692ed0f2 | ||
|
|
af499565d3 | ||
|
|
fe2d7e3a9e | ||
|
|
9f69822221 | ||
|
|
bb43f047c2 | ||
|
|
2356662492 | ||
|
|
1624a45093 | ||
|
|
dcb9983786 | ||
|
|
83d1828905 | ||
|
|
6a281cf3ee | ||
|
|
ed1cd39a6c | ||
|
|
dda19b3920 | ||
|
|
25139ca922 | ||
|
|
3cd57a582c | ||
|
|
d3903ac655 | ||
|
|
199e374318 | ||
|
|
8375c1413d | ||
|
|
9e268cf016 | ||
|
|
112b3abc26 | ||
|
|
a8331a2357 | ||
|
|
52e3ad08c1 | ||
|
|
8d01d04ef0 | ||
|
|
a141384907 | ||
|
|
b8aa7184bd | ||
|
|
e4195f874d | ||
|
|
d04deff5ca | ||
|
|
20ce0778a0 | ||
|
|
5a0b3470f1 | ||
|
|
a920921570 | ||
|
|
286f4ff384 | ||
|
|
71ddfafa98 | ||
|
|
b7e3e53697 | ||
|
|
16df548b77 | ||
|
|
425c33ae00 | ||
|
|
c9289ed2dc | ||
|
|
96517cbdef | ||
|
|
b03420faac | ||
|
|
65a1aa7ca2 | ||
|
|
3a92e8eaf9 | ||
|
|
a8dc50d64a | ||
|
|
3397cc7d8d | ||
|
|
c3e8131b24 | ||
|
|
f8ca8584ae | ||
|
|
3050bbe260 | ||
|
|
e1dda2795a | ||
|
|
6d8408e626 | ||
|
|
0906271aa9 | ||
|
|
4c33c9d256 | ||
|
|
fa9c78209f | ||
|
|
6678ec8a60 | ||
|
|
854e467c12 | ||
|
|
e6b94c7b21 | ||
|
|
2c6f9d8602 | ||
|
|
c74033b9c0 | ||
|
|
d2b21d27bb | ||
|
|
215272469f | ||
|
|
f7d05ab0f1 | ||
|
|
6f2ad2be77 | ||
|
|
66575c719a | ||
|
|
677a239d53 | ||
|
|
3b96bfe5af | ||
|
|
83be5cfa64 | ||
|
|
6b834c2362 | ||
|
|
7abfc49e08 | ||
|
|
65d5f50088 | ||
|
|
4f1f4ffe3d | ||
|
|
b0c2027a1c | ||
|
|
33c83358b0 | ||
|
|
31223f0526 | ||
|
|
92daadb92c | ||
|
|
fae2e274fd | ||
|
|
342a722991 | ||
|
|
65ec6aacb7 | ||
|
|
9387470c69 | ||
|
|
31f6edf8f0 | ||
|
|
487b062175 | ||
|
|
d8e13de096 | ||
|
|
e8a30088ef | ||
|
|
bf7b07ba74 | ||
|
|
28fe3e7b7a | ||
|
|
c0eff2bb5e | ||
|
|
848c1741fe | ||
|
|
1370b8e8c1 | ||
|
|
82a068e610 | ||
|
|
32f42bafaa | ||
|
|
4081b7f022 | ||
|
|
a5808193a6 | ||
|
|
854ca322c1 | ||
|
|
c1d9b5137a | ||
|
|
f33d5745b3 | ||
|
|
d89c2ca128 | ||
|
|
835584cc85 | ||
|
|
b2ffbe3a68 | ||
|
|
defcc79e6c | ||
|
|
c06d9f84f0 | ||
|
|
fe57a8e156 | ||
|
|
b77105795a | ||
|
|
e2df5fcf27 | ||
|
|
836a64e728 | ||
|
|
08ba0c9f42 | ||
|
|
6fcc6a5299 | ||
|
|
6dd58248c6 | ||
|
|
2786801b71 | ||
|
|
ea29cbeb7a | ||
|
|
3cf9121a8c | ||
|
|
381bd3938a | ||
|
|
e4ce384023 | ||
|
|
12d1857b13 | ||
|
|
0d9003dea4 | ||
|
|
1a3751acfa | ||
|
|
c5a3af2399 | ||
|
|
ea8a64fafc | ||
|
|
981e367bf1 | ||
|
|
a3d6e62035 | ||
|
|
7f205cdcc8 | ||
|
|
e587189880 | ||
|
|
206c1bd69f | ||
|
|
a7d9255c2c | ||
|
|
08265a85ec | ||
|
|
1ed5630464 | ||
|
|
c784615f11 | ||
|
|
26d51b1190 | ||
|
|
d83fad6abc | ||
|
|
692796db46 | ||
|
|
f15c6f33f9 | ||
|
|
dda9eb4d7c | ||
|
|
6f3aeb61e7 | ||
|
|
d6145e633f | ||
|
|
07014d98ce | ||
|
|
e8ccdabe6c | ||
|
|
cf9fd2d5c2 | ||
|
|
bf9aa9356b | ||
|
|
68d00ce289 | ||
|
|
5288021e4f | ||
|
|
4d38add291 | ||
|
|
804808da4a | ||
|
|
298a95432d | ||
|
|
a834fc4b30 | ||
|
|
2c6c9542dd | ||
|
|
a9a7f4c8ec | ||
|
|
ea9370443d | ||
|
|
c2e00b240e | ||
|
|
a2b81ea099 | ||
|
|
ee609e8eac | ||
|
|
e04ef671e9 | ||
|
|
0184dfd7eb | ||
|
|
eccfa0ca54 | ||
|
|
6d3feb4bef | ||
|
|
29d2b5ee4b | ||
|
|
c82fabb67f | ||
|
|
fcfc868e57 | ||
|
|
67b403f8ca | ||
|
|
de06c6b2f6 | ||
|
|
fa444dfb8a | ||
|
|
124002a472 | ||
|
|
0c883433c1 | ||
|
|
bcf3b2cf55 | ||
|
|
357c4e9c08 | ||
|
|
9edfc68e91 | ||
|
|
8c06cb3e80 | ||
|
|
144fa0a6d4 | ||
|
|
25d5a1541e | ||
|
|
a579d36389 | ||
|
|
d766dac341 | ||
|
|
b15ef1bbc6 | ||
|
|
3e52e00597 | ||
|
|
f749dd0d52 | ||
|
|
48a8a42108 | ||
|
|
db7f57a5a4 | ||
|
|
556381b983 | ||
|
|
158d7d5898 | ||
|
|
18844da95d | ||
|
|
7e0df4d718 | ||
|
|
0dbb76e8c8 | ||
|
|
f73b3422a6 | ||
|
|
bd95e802ec | ||
|
|
5de16a78c5 | ||
|
|
6f8e09fcde | ||
|
|
f54d480f03 | ||
|
|
e68b213fb3 | ||
|
|
132334d500 | ||
|
|
a6f04c6d7e | ||
|
|
854e8bf356 | ||
|
|
6ff883d2d3 | ||
|
|
849b97afba | ||
|
|
1bd2635864 | ||
|
|
79ab0f7b6c | ||
|
|
79011bd257 | ||
|
|
c692713ffb | ||
|
|
df9b554ce1 | ||
|
|
277a8e4682 | ||
|
|
acb52dba09 | ||
|
|
8f10765254 | ||
|
|
0653f59473 | ||
|
|
7a4b5a4667 | ||
|
|
49c4a4068b | ||
|
|
40ad590046 | ||
|
|
30374ae3e6 | ||
|
|
ab22d16bad | ||
|
|
971cd56a4a | ||
|
|
d7cb546c5f | ||
|
|
9d8b7344cd | ||
|
|
2d4f6ae7ce | ||
|
|
d9126807b0 | ||
|
|
cad5fb3fba | ||
|
|
afe23ad6b7 | ||
|
|
fc4327087b | ||
|
|
71762d788f | ||
|
|
6472e00fb0 | ||
|
|
4043846767 | ||
|
|
d3b2bc962c | ||
|
|
54f7b64821 | ||
|
|
82a2a6e669 | ||
|
|
6376d60af5 | ||
|
|
b1e2e3831f | ||
|
|
5de1c8aa82 | ||
|
|
63dc5c2bdb | ||
|
|
7f2d1670a0 | ||
|
|
53c8c337fc | ||
|
|
5b4ec1b2a2 | ||
|
|
64dd2ed141 | ||
|
|
eb57e04e95 | ||
|
|
ae905c8630 | ||
|
|
c157e794f0 | ||
|
|
ed9bae6f6a | ||
|
|
9fe1ce19ad | ||
|
|
6148236cbd | ||
|
|
2471eb518a | ||
|
|
8931b41c76 | ||
|
|
7f523f167d | ||
|
|
446b6d6158 | ||
|
|
2ee057e19b | ||
|
|
afc810f21f | ||
|
|
357052a903 | ||
|
|
39d6d8d04a | ||
|
|
888896c0c0 | ||
|
|
ceee482ecc | ||
|
|
d0ed1213d8 | ||
|
|
f6ef428008 | ||
|
|
e726c4f442 | ||
|
|
402318e586 | ||
|
|
b198cc2a6e | ||
|
|
c3dd4da11b | ||
|
|
ba2e42b06e | ||
|
|
fa0902dc74 | ||
|
|
8fcb6083dc | ||
|
|
1ef88140e3 | ||
|
|
aa34c4c84c | ||
|
|
32d12bb334 | ||
|
|
1b2a02cb1a | ||
|
|
2ff11a16c4 | ||
|
|
441af82dbd | ||
|
|
e09c09af6f | ||
|
|
3721fe226f | ||
|
|
8ace0e11cf | ||
|
|
5e249b0b59 | ||
|
|
4889955ecf | ||
|
|
d840fd53da | ||
|
|
a61819cdb3 | ||
|
|
e986fbb5fb | ||
|
|
8f4d575ec8 | ||
|
|
605a06317b | ||
|
|
a7304ccf47 | ||
|
|
374e2bd4b9 | ||
|
|
09a3246ddb | ||
|
|
a615603866 | ||
|
|
1ca05808e1 | ||
|
|
5febc2a805 | ||
|
|
3c047bee58 | ||
|
|
022c6c157a | ||
|
|
fa587d5678 | ||
|
|
afa5a42f5a | ||
|
|
71df8ba3e2 | ||
|
|
8764998e8c | ||
|
|
2cb4f3aac8 | ||
|
|
1ccaf33aac | ||
|
|
cb0a8e0413 | ||
|
|
8674168df4 | ||
|
|
2221653801 | ||
|
|
78bcdcef5d | ||
|
|
672fbe2ac0 | ||
|
|
56a5970b44 | ||
|
|
a66cef7cfe | ||
|
|
c0b1c2e099 | ||
|
|
9e553bb87b | ||
|
|
f966514bc7 | ||
|
|
dc0a49f96d | ||
|
|
65c783c024 | ||
|
|
6395836fbb | ||
|
|
a7207084ef | ||
|
|
27ef1f1e71 | ||
|
|
68fdb14cd6 | ||
|
|
c2af282a85 | ||
|
|
92d48335cb | ||
|
|
78cac2edc2 | ||
|
|
26d105c439 | ||
|
|
7fec107b98 | ||
|
|
eb01ad3af9 | ||
|
|
e0d9880b32 | ||
|
|
e81e96f0ab | ||
|
|
06d5bd259c | ||
|
|
14238b8d62 | ||
|
|
3b51886927 | ||
|
|
a295ff2e06 | ||
|
|
18cdaabf5e | ||
|
|
787e37b7c6 | ||
|
|
4e5c8b2dd0 | ||
|
|
d8ddacde38 | ||
|
|
bb1e42f0d3 | ||
|
|
923669c495 | ||
|
|
7a4139544c | ||
|
|
4d6ea0236b | ||
|
|
e872a06f22 | ||
|
|
647bda2160 | ||
|
|
c1e93d23f3 | ||
|
|
c96550cc68 | ||
|
|
b1015ecdc5 | ||
|
|
f1b928a037 | ||
|
|
16c312c90b | ||
|
|
110ffd0118 | ||
|
|
35ad872419 | ||
|
|
9b943cf2b8 | ||
|
|
9d1b357e64 | ||
|
|
9fc2fb4d17 | ||
|
|
641fa8a3d9 | ||
|
|
add9269706 | ||
|
|
1a01c4a344 | ||
|
|
b4e7feed06 | ||
|
|
4b96c650eb | ||
|
|
107aef3785 | ||
|
|
b49807824f | ||
|
|
e5ef2ef8b5 | ||
|
|
88779ed56c | ||
|
|
8b59fb6adc | ||
|
|
7945647b0b | ||
|
|
2d39b84806 | ||
|
|
e151a19fcf | ||
|
|
99d2ba26b9 | ||
|
|
396924f4cc | ||
|
|
7545312229 | ||
|
|
26f9779fbf | ||
|
|
0bd62eef3a | ||
|
|
e06d15f508 | ||
|
|
aa1ee96bc9 | ||
|
|
355c73512d | ||
|
|
0daf9d92ff | ||
|
|
37de26ce25 | ||
|
|
0eaef7e7a0 | ||
|
|
8063cee3cd | ||
|
|
cbb25b4ac0 | ||
|
|
c62206a157 | ||
|
|
09832141d0 | ||
|
|
bf8e121a10 | ||
|
|
68568073ec | ||
|
|
ec36524c35 | ||
|
|
67acd9fd2c | ||
|
|
f7be5c8d25 | ||
|
|
ceacac75e0 | ||
|
|
bae66f94e8 | ||
|
|
ddf132bd78 | ||
|
|
afb012029f | ||
|
|
651e14c8c3 | ||
|
|
e7c626eb5f | ||
|
|
a0b0d40a19 | ||
|
|
42e3ab9e27 | ||
|
|
6e5f333364 | ||
|
|
f33a9abe60 | ||
|
|
7f1bbdd615 | ||
|
|
d3bf8eaceb | ||
|
|
b9c9d602de | ||
|
|
b25fbd6e24 | ||
|
|
6052608a4e | ||
|
|
a073b82751 | ||
|
|
8250acdfb5 | ||
|
|
8e1f73a34e | ||
|
|
50704bc882 | ||
|
|
35d34e3513 | ||
|
|
ea834f3de6 | ||
|
|
11aedde72f | ||
|
|
488654abc8 | ||
|
|
da1be0dc65 | ||
|
|
d0c728a339 | ||
|
|
66c66c4d9b | ||
|
|
4882721387 | ||
|
|
06a8850c0c | ||
|
|
370aa06c67 | ||
|
|
c9fa0564e7 | ||
|
|
2ba7a0ceba | ||
|
|
276aedfbb9 | ||
|
|
c193c75674 | ||
|
|
a562ba3746 | ||
|
|
2fedd572ff | ||
|
|
db0b49c427 | ||
|
|
03a6f8111c | ||
|
|
925ad7b3e0 | ||
|
|
bf793d5b8b | ||
|
|
64a906ca5e | ||
|
|
99b36442bb | ||
|
|
3c5164d510 | ||
|
|
ec4b5a4d45 | ||
|
|
78e1901779 | ||
|
|
cb539314de | ||
|
|
c7627fe0de | ||
|
|
84bfad7ce5 | ||
|
|
3e06938b05 | ||
|
|
4f712fec14 | ||
|
|
c5c9659c76 | ||
|
|
d6e175c1f1 | ||
|
|
88088e1071 | ||
|
|
958ddbca86 | ||
|
|
6670fd28f4 | ||
|
|
1e59c31de3 | ||
|
|
c966dbbbbc | ||
|
|
af8f5ba04e | ||
|
|
b741ed0b3b | ||
|
|
01ba3c14f8 | ||
|
|
d13b1a83ad | ||
|
|
303477db70 | ||
|
|
311e89e9e7 | ||
|
|
8546cfe714 | ||
|
|
e6f4d84b9a | ||
|
|
ce7e422169 | ||
|
|
e5aec80984 | ||
|
|
6d97817390 | ||
|
|
d516f22159 | ||
|
|
e918c18ca2 | ||
|
|
5dd8d905fa | ||
|
|
1121d1ee6c | ||
|
|
4793f096af | ||
|
|
7b5b4ce082 | ||
|
|
fa08c9c3e4 | ||
|
|
d0d5eb956a | ||
|
|
969f949330 | ||
|
|
9169bbd04d | ||
|
|
99463ad01c | ||
|
|
f1d6b0feda | ||
|
|
e33da50278 | ||
|
|
4034eb3221 | ||
|
|
75a95f0109 | ||
|
|
92fdc16fe6 | ||
|
|
23fa2995c8 | ||
|
|
59aefdff77 | ||
|
|
e92ab9e3cc | ||
|
|
e3bf1f763c | ||
|
|
1c6e9d0b69 | ||
|
|
bfd4eb3e11 | ||
|
|
c9f902a8af | ||
|
|
0b67510ec9 | ||
|
|
b5cd320e8b | ||
|
|
deb25b4987 | ||
|
|
4612da264a | ||
|
|
59b67e1e10 | ||
|
|
5fad936b27 | ||
|
|
e376a45dea | ||
|
|
fd593bb61d | ||
|
|
71b97d5974 | ||
|
|
2b405ae164 | ||
|
|
2fe4736b69 | ||
|
|
184f8ca6cf | ||
|
|
1ff2019dde | ||
|
|
a3d8261686 | ||
|
|
7d0600976e | ||
|
|
e1e6e4f3dc | ||
|
|
fba2853773 | ||
|
|
48df7e1078 | ||
|
|
235dcd5fa6 | ||
|
|
2027db7411 | ||
|
|
611dd33c75 | ||
|
|
ec1c92a714 | ||
|
|
6ac78156ac | ||
|
|
e94b74e92d | ||
|
|
2bbec47f63 | ||
|
|
b5ddf4c953 | ||
|
|
44be75aeef | ||
|
|
2c03759b5d | ||
|
|
2e3da03723 | ||
|
|
6e96fbcda7 | ||
|
|
d1fd5b7f27 | ||
|
|
9dbcc105e7 | ||
|
|
5cd5a82ddc | ||
|
|
88c1892dc9 | ||
|
|
3c1b181675 | ||
|
|
6777dc16ca | ||
|
|
3833647dfe | ||
|
|
b6c47f0cce | ||
|
|
d308c7ac60 | ||
|
|
947c757aa5 | ||
|
|
5ee5bd7d36 | ||
|
|
d9c4ae92cd | ||
|
|
e1efff19f0 | ||
|
|
61f723a1f5 | ||
|
|
b32756932b | ||
|
|
cb5e64d26b | ||
|
|
f36febf10a | ||
|
|
26d9a9caa6 | ||
|
|
cb876cf77e | ||
|
|
4789711910 | ||
|
|
4064980505 | ||
|
|
f9b8f2d22c | ||
|
|
6a95aadc53 | ||
|
|
f9f08f082d | ||
|
|
0817901bef | ||
|
|
ac22172e53 | ||
|
|
fd87fbf31e | ||
|
|
554be0908f | ||
|
|
eaec4e5f13 | ||
|
|
0e7ba27a7d | ||
|
|
c551f5c23b | ||
|
|
5159657ae5 | ||
|
|
d35db7df72 | ||
|
|
2b5399c559 | ||
|
|
9e61bbbd8e | ||
|
|
7ce5857cd5 | ||
|
|
38fbae99fd | ||
|
|
b0a9d44b0c | ||
|
|
b4e22cd375 | ||
|
|
9bc92736a7 | ||
|
|
111b34d05c | ||
|
|
07d9599a2f | ||
|
|
d8194f211d | ||
|
|
51a6374c33 | ||
|
|
aa6c6035b6 | ||
|
|
44b4a7ffbb | ||
|
|
e5bb018d22 | ||
|
|
79b8a6536e | ||
|
|
3de31cd06a | ||
|
|
c579b54d40 | ||
|
|
0a52575e8b | ||
|
|
23c9a98f66 | ||
|
|
796fc33b5b | ||
|
|
dc4c11ddd2 | ||
|
|
d389e4d5d4 | ||
|
|
8cb78ad931 | ||
|
|
85f987d15c | ||
|
|
b12079e0f6 | ||
|
|
dcf5c6167a | ||
|
|
b395d3f487 | ||
|
|
37662cad10 | ||
|
|
aa1673063d | ||
|
|
f51f49eb60 | ||
|
|
54c9bac961 | ||
|
|
e70fd73bdd | ||
|
|
9bb9e7b64d | ||
|
|
f64c03543a | ||
|
|
51374de1a1 | ||
|
|
afcc12f263 | ||
|
|
88c5482366 | ||
|
|
bbf7295c32 | ||
|
|
ca5e23e68c | ||
|
|
eadb1487ae | ||
|
|
1faa70fc77 | ||
|
|
30d7c007de | ||
|
|
f54f6a4402 | ||
|
|
7b41cdec65 | ||
|
|
fb6a652a57 | ||
|
|
ea34d753c1 | ||
|
|
2bc46e708e | ||
|
|
96e3b5b7b3 | ||
|
|
fafbafa5e1 | ||
|
|
be8605d8c6 | ||
|
|
061660d47a | ||
|
|
2ed6dbb344 | ||
|
|
4766b45746 | ||
|
|
0734252e98 | ||
|
|
91b4827c1d | ||
|
|
df6d56ce66 | ||
|
|
f0203c96ab | ||
|
|
bccabe40c0 | ||
|
|
c2f599b4ff | ||
|
|
5fd069d70d | ||
|
|
32d34d1748 | ||
|
|
18eb605605 | ||
|
|
4fdc88e9e1 | ||
|
|
4c69d8d3a8 | ||
|
|
d4b2dd0ec1 | ||
|
|
181f78421b | ||
|
|
8ed38527d0 | ||
|
|
c4c926070d | ||
|
|
ed87411e0d | ||
|
|
4ec2a448ab | ||
|
|
73d01da94e | ||
|
|
df8e02157a | ||
|
|
6e513ed32a | ||
|
|
325ef6327d | ||
|
|
46700e5ad0 | ||
|
|
d1e21fa345 | ||
|
|
cede387783 | ||
|
|
b206427d50 | ||
|
|
47d96e2037 | ||
|
|
e51f7cc1a7 | ||
|
|
40381d4b11 | ||
|
|
76fc9e5a3d | ||
|
|
9822f2c614 | ||
|
|
8854334ab5 | ||
|
|
53080844d2 | ||
|
|
76fd722e33 | ||
|
|
fa27513f76 | ||
|
|
72c6f91130 | ||
|
|
5918f35b8b | ||
|
|
0b11e6e6d0 | ||
|
|
a043b487bd | ||
|
|
3982489e67 | ||
|
|
5f3c515323 | ||
|
|
6e1297d734 | ||
|
|
8f3cbdd257 | ||
|
|
2fc06ae64e | ||
|
|
515aa1d2bd | ||
|
|
ff7a36394a | ||
|
|
5261ab249a | ||
|
|
c3192351da | ||
|
|
ce30d067a6 | ||
|
|
e84a8a72c5 | ||
|
|
10a4fe04d1 | ||
|
|
d5ce6441e3 | ||
|
|
a8d21fb1d6 | ||
|
|
9277d8d8f8 | ||
|
|
0618541527 | ||
|
|
1db49a4dd4 | ||
|
|
3df96034a1 | ||
|
|
e991dc061d | ||
|
|
56670066c7 | ||
|
|
31d27ff3fa | ||
|
|
297ff0dd25 | ||
|
|
b0a5b48fb2 | ||
|
|
ac244e6ad9 | ||
|
|
7393e92b21 | ||
|
|
86810d9f03 | ||
|
|
18aa8d11ad | ||
|
|
fafec56f09 | ||
|
|
129ca9da81 | ||
|
|
cbfb9ac87c | ||
|
|
42309edef4 | ||
|
|
559e57ca46 | ||
|
|
311bf1f157 | ||
|
|
131c3cc324 | ||
|
|
152ec0da0d | ||
|
|
ee04df40c3 | ||
|
|
252e90a633 | ||
|
|
048d486fa6 | ||
|
|
8fdfb68741 | ||
|
|
64c9e4aeca | ||
|
|
08b90e8767 | ||
|
|
0206613f9e | ||
|
|
ae0629628e | ||
|
|
785b2e7287 | ||
|
|
43e3d0552e | ||
|
|
801aa2e876 | ||
|
|
bddc7a438d | ||
|
|
b8c78a68e7 | ||
|
|
49219f4447 | ||
|
|
59b1abb719 | ||
|
|
3e2cfb552b | ||
|
|
779be1b8d0 | ||
|
|
faf74de238 | ||
|
|
50a51c2e79 | ||
|
|
d31e641496 | ||
|
|
f2d36f5be9 | ||
|
|
0b55f61fac | ||
|
|
4156dcbafd | ||
|
|
36e6ac2362 | ||
|
|
9613199152 | ||
|
|
14328d7496 | ||
|
|
6af12d1acc | ||
|
|
9b44e49879 | ||
|
|
afee18f146 | ||
|
|
f007369a66 | ||
|
|
9a9c166dbe | ||
|
|
2f90e32dbf | ||
|
|
26355ccb79 | ||
|
|
27ea3c0c8e | ||
|
|
5aa35b211a | ||
|
|
92450385d2 | ||
|
|
8d15e23f3c | ||
|
|
73686d4146 | ||
|
|
0499ca1300 | ||
|
|
234c942f34 | ||
|
|
aec218ba00 | ||
|
|
b508f51fcf | ||
|
|
435628ea59 | ||
|
|
4933dbfb87 | ||
|
|
5a93c40b79 | ||
|
|
a8ec5af037 | ||
|
|
27db60ce68 | ||
|
|
195866b00d | ||
|
|
60575b6546 | ||
|
|
350b81d678 | ||
|
|
cc95314dae | ||
|
|
3f97087abb | ||
|
|
f04af2de21 | ||
|
|
e7871bf843 | ||
|
|
8e3308039a | ||
|
|
b65350b7cb | ||
|
|
069ebce895 | ||
|
|
63aa4e188e | ||
|
|
c31c9c16cf | ||
|
|
5a8a402fdc | ||
|
|
85c3e33343 | ||
|
|
1420ab31a2 | ||
|
|
fd1435537f | ||
|
|
4e0473ce11 | ||
|
|
450592b0d4 | ||
|
|
7cae0ee169 | ||
|
|
ecd0e05f79 | ||
|
|
6e3b4178ac | ||
|
|
ba18cbabfd | ||
|
|
dec757c23b | ||
|
|
0459710c9b | ||
|
|
83582ef8a3 | ||
|
|
0dc396e148 | ||
|
|
86958e1420 | ||
|
|
c5b8e629fb | ||
|
|
b0a495b4f6 | ||
|
|
7d2809467b | ||
|
|
af90eeaf37 | ||
|
|
509e513f3a | ||
|
|
80671e474c | ||
|
|
a166d859e7 | ||
|
|
6af1e0aeb7 | ||
|
|
370ffb5d7c | ||
|
|
0ba288d09e | ||
|
|
008d86983b | ||
|
|
205bdfce5c | ||
|
|
27248b197d | ||
|
|
e216b4c455 | ||
|
|
c402f53258 | ||
|
|
93329abe8b | ||
|
|
f69b3d96b6 | ||
|
|
8690a8f11a | ||
|
|
6aa2342be1 | ||
|
|
042153329b | ||
|
|
2b67091986 | ||
|
|
3da35cf0db | ||
|
|
e566484a17 | ||
|
|
e7dffbbb1e | ||
|
|
a31712ad1f | ||
|
|
2958f81adc | ||
|
|
95380fbbfb | ||
|
|
4cc6996406 | ||
|
|
372d74ec71 | ||
|
|
19ef73a07f | ||
|
|
bb3d73b87c | ||
|
|
30e9e7168f | ||
|
|
fce58f3206 | ||
|
|
b3e5ac395f | ||
|
|
3ebe9d159a | ||
|
|
ff95274757 | ||
|
|
8e653e2173 | ||
|
|
4bff17aa1a | ||
|
|
d4f300645d | ||
|
|
4ee32f02c5 | ||
|
|
2cf4440a1e | ||
|
|
644ee31654 | ||
|
|
34078d8a60 | ||
|
|
5cfae7198d | ||
|
|
6a10cda61f | ||
|
|
c149e73ef7 | ||
|
|
b11757c913 | ||
|
|
607ab35cce | ||
|
|
19ff2ebfe1 | ||
|
|
4a47dc2073 | ||
|
|
addf92d966 | ||
|
|
c987338c84 | ||
|
|
a88b0239eb | ||
|
|
caf5b1528c | ||
|
|
90f74018ae | ||
|
|
d7a253cba3 | ||
|
|
8a28846bac | ||
|
|
04545c5706 | ||
|
|
32fa81cf93 | ||
|
|
7924e4000c | ||
|
|
f9c54690b0 | ||
|
|
c3aaef3916 | ||
|
|
03dfe13769 | ||
|
|
f38b51b85a | ||
|
|
0017a6cce5 | ||
|
|
541ad624c5 | ||
|
|
7c56825f9b | ||
|
|
8a871ae643 | ||
|
|
e2191ab4b4 | ||
|
|
4264dd19a8 | ||
|
|
78f8d4ecc7 | ||
|
|
e2cc3145de | ||
|
|
710857dd41 | ||
|
|
1bfe12a288 | ||
|
|
14a88e2cfa | ||
|
|
0580130d47 | ||
|
|
a4ee82b51f | ||
|
|
1034282161 | ||
|
|
b0a8b0cc6f | ||
|
|
3f38764a0e | ||
|
|
3338c17e8f | ||
|
|
22085e5174 | ||
|
|
d7c643ee9b | ||
|
|
406284a045 | ||
|
|
50babfd471 | ||
|
|
edd36427ac | ||
|
|
9f2289329c | ||
|
|
9a1fe19cc8 | ||
|
|
09f5e2961e | ||
|
|
756ad399bf | ||
|
|
02adced7b8 | ||
|
|
9059795816 | ||
|
|
6920944724 | ||
|
|
c76b287aed | ||
|
|
5c62ec1177 | ||
|
|
09b2fdfc59 | ||
|
|
e498c9ce29 | ||
|
|
9bb4d7078e | ||
|
|
5e4d2c7760 | ||
|
|
426e84cfa3 | ||
|
|
b77df8f89f | ||
|
|
f7c946778d | ||
|
|
81599b8f43 | ||
|
|
9c0dcb2853 | ||
|
|
d3e4534673 | ||
|
|
dd81c86540 | ||
|
|
3620376c3c | ||
|
|
444e8004c7 | ||
|
|
0b0caa1142 | ||
|
|
e7233c147d | ||
|
|
004c203ef2 | ||
|
|
db04c349a7 | ||
|
|
e57a72d12b | ||
|
|
c88388da67 | ||
|
|
2ea0fa8471 | ||
|
|
7f088e58bc | ||
|
|
e992ace11c | ||
|
|
0cad6b5cbc | ||
|
|
e9a703451c | ||
|
|
03ddd51a91 | ||
|
|
9142cc4cde | ||
|
|
8e5e16ce68 | ||
|
|
d69406c4cb | ||
|
|
250e8445bb | ||
|
|
e6aafe8773 |
5
.github/FUNDING.yml
vendored
Normal file
5
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
patreon: PixelPawsAI
|
||||
ko_fi: pixelpawsai
|
||||
custom: ['paypal.me/pixelpawsai']
|
||||
1
.github/copilot-instructions.md
vendored
Normal file
1
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1 @@
|
||||
Always use English for comments.
|
||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -1,2 +1,10 @@
|
||||
__pycache__/
|
||||
settings.json
|
||||
settings.json
|
||||
path_mappings.yaml
|
||||
output/*
|
||||
py/run_test.py
|
||||
.vscode/
|
||||
cache/
|
||||
civitai/
|
||||
node_modules/
|
||||
coverage/
|
||||
|
||||
19
AGENTS.md
Normal file
19
AGENTS.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# Repository Guidelines
|
||||
|
||||
## Project Structure & Module Organization
|
||||
ComfyUI LoRA Manager pairs a Python backend with lightweight browser scripts. Backend modules live in `py/`, organized by responsibility: HTTP entry points under `routes/`, feature logic in `services/`, reusable helpers within `utils/`, and custom nodes in `nodes/`. Front-end widgets that extend the ComfyUI interface sit in `web/comfyui/`, while static images and templates are in `static/` and `templates/`. Shared localization files are stored in `locales/`, with workflow examples under `example_workflows/`. Tests currently reside alongside the source (`test_i18n.py`) until a dedicated `tests/` folder is introduced.
|
||||
|
||||
## Build, Test, and Development Commands
|
||||
Install dependencies with `pip install -r requirements.txt` from the repo root. Launch the standalone server for iterative work via `python standalone.py --port 8188`; ComfyUI users can also load the extension directly through ComfyUI's custom node manager. Run backend checks with `python -m pytest test_i18n.py`, and target new test files explicitly (e.g. `python -m pytest tests/test_recipes.py` once added). Use `python scripts/sync_translation_keys.py` to reconcile locale keys after updating UI strings.
|
||||
|
||||
## Coding Style & Naming Conventions
|
||||
Follow PEP 8 with four-space indentation and descriptive snake_case module/function names, mirroring files such as `py/services/settings_manager.py`. Classes remain PascalCase, constants UPPER_SNAKE_CASE, and loggers retrieved via `logging.getLogger(__name__)`. Prefer explicit type hints for new public APIs and docstrings that clarify side effects. JavaScript in `web/comfyui/` is modern ES modules; keep imports relative, favor camelCase functions, and mirror existing file suffixes like `_widget.js` for UI components.
|
||||
|
||||
## Testing Guidelines
|
||||
Extend pytest coverage by co-locating tests near the code under test or in `tests/` with names like `test_<feature>.py`. When introducing new routes or services, add regression cases that mock ComfyUI dependencies (see the standalone mocking helpers in `standalone.py`). Prioritize deterministic fixtures for filesystem interactions and ensure translations include coverage when adding new locale keys. Always run `python -m pytest` before submitting work.
|
||||
|
||||
## Commit & Pull Request Guidelines
|
||||
Commits follow the conventional pattern seen in `git log` (`feat(scope):`, `fix(scope):`, `chore(scope):`). Keep messages imperative and scoped to a single change. Pull requests should summarize the problem, detail the solution, list manual test evidence, and link any GitHub issues. Include UI screenshots or GIFs when front-end behavior changes, and call out migration steps (e.g., settings updates) in the PR description.
|
||||
|
||||
## Configuration & Localization Tips
|
||||
Sample configuration defaults live in `settings.json.example`; copy it to `settings.json` and adjust model directories before running the standalone server. Whenever you add UI text, update `locales/<lang>.json` and run the translation sync script. Store reference assets in `civitai/` or `docs/` rather than mixing them with production templates, keeping the runtime folders (`static/`, `templates/`) deploy-ready.
|
||||
687
LICENSE
687
LICENSE
@@ -1,21 +1,674 @@
|
||||
MIT License
|
||||
GNU GENERAL PUBLIC LICENSE
|
||||
Version 3, 29 June 2007
|
||||
|
||||
Copyright (c) 2023 Will Miao
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
Preamble
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
The GNU General Public License is a free, copyleft license for
|
||||
software and other kinds of works.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
the GNU General Public License is intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users. We, the Free Software Foundation, use the
|
||||
GNU General Public License for most of our software; it applies also to
|
||||
any other work released this way by its authors. You can apply it to
|
||||
your programs, too.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
To protect your rights, we need to prevent others from denying you
|
||||
these rights or asking you to surrender the rights. Therefore, you have
|
||||
certain responsibilities if you distribute copies of the software, or if
|
||||
you modify it: responsibilities to respect the freedom of others.
|
||||
|
||||
For example, if you distribute copies of such a program, whether
|
||||
gratis or for a fee, you must pass on to the recipients the same
|
||||
freedoms that you received. You must make sure that they, too, receive
|
||||
or can get the source code. And you must show them these terms so they
|
||||
know their rights.
|
||||
|
||||
Developers that use the GNU GPL protect your rights with two steps:
|
||||
(1) assert copyright on the software, and (2) offer you this License
|
||||
giving you legal permission to copy, distribute and/or modify it.
|
||||
|
||||
For the developers' and authors' protection, the GPL clearly explains
|
||||
that there is no warranty for this free software. For both users' and
|
||||
authors' sake, the GPL requires that modified versions be marked as
|
||||
changed, so that their problems will not be attributed erroneously to
|
||||
authors of previous versions.
|
||||
|
||||
Some devices are designed to deny users access to install or run
|
||||
modified versions of the software inside them, although the manufacturer
|
||||
can do so. This is fundamentally incompatible with the aim of
|
||||
protecting users' freedom to change the software. The systematic
|
||||
pattern of such abuse occurs in the area of products for individuals to
|
||||
use, which is precisely where it is most unacceptable. Therefore, we
|
||||
have designed this version of the GPL to prohibit the practice for those
|
||||
products. If such problems arise substantially in other domains, we
|
||||
stand ready to extend this provision to those domains in future versions
|
||||
of the GPL, as needed to protect the freedom of users.
|
||||
|
||||
Finally, every program is threatened constantly by software patents.
|
||||
States should not allow patents to restrict development and use of
|
||||
software on general-purpose computers, but in those that do, we wish to
|
||||
avoid the special danger that patents applied to a free program could
|
||||
make it effectively proprietary. To prevent this, the GPL assures that
|
||||
patents cannot be used to render the program non-free.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Use with the GNU Affero General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU Affero General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the special requirements of the GNU Affero General Public License,
|
||||
section 13, concerning interaction through a network will apply to the
|
||||
combination as such.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU General Public License from time to time. Such new versions will
|
||||
be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
ComfyUI Lora Manager - A ComfyUI custom node for managing models
|
||||
Copyright (C) 2025 Will Miao
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If the program does terminal interaction, make it output a short
|
||||
notice like this when it starts in an interactive mode:
|
||||
|
||||
ComfyUI Lora Manager Copyright (C) 2025 Will Miao
|
||||
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
||||
This is free software, and you are welcome to redistribute it
|
||||
under certain conditions; type `show c' for details.
|
||||
|
||||
The hypothetical commands `show w' and `show c' should show the appropriate
|
||||
parts of the General Public License. Of course, your program's commands
|
||||
might be different; for a GUI interface, you would use an "about box".
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU GPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
|
||||
The GNU General Public License does not permit incorporating your program
|
||||
into proprietary programs. If your program is a subroutine library, you
|
||||
may consider it more useful to permit linking proprietary applications with
|
||||
the library. If this is what you want to do, use the GNU Lesser General
|
||||
Public License instead of this License. But first, please read
|
||||
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
||||
232
README.md
232
README.md
@@ -1,56 +1,86 @@
|
||||
# ComfyUI LoRA Manager
|
||||
|
||||
A web-based management interface designed to help you organize and manage your local LoRA models in ComfyUI. Access the interface at: `http://localhost:8188/loras`
|
||||
> **Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!**
|
||||
|
||||

|
||||
[](https://discord.gg/vcqNrWVFvM)
|
||||
[](https://github.com/willmiao/ComfyUI-Lora-Manager/releases)
|
||||
[](https://github.com/willmiao/ComfyUI-Lora-Manager/releases)
|
||||
|
||||
A comprehensive toolset that streamlines organizing, downloading, and applying LoRA models in ComfyUI. With powerful features like recipe management, checkpoint organization, and one-click workflow integration, working with models becomes faster, smoother, and significantly easier. Access the interface at: `http://localhost:8188/loras`
|
||||
|
||||

|
||||
|
||||
## 📺 Tutorial: One-Click LoRA Integration
|
||||
Watch this quick tutorial to learn how to use the new one-click LoRA integration feature:
|
||||
|
||||
[](https://youtu.be/qS95OjX3e70)
|
||||
[](https://youtu.be/hvKw31YpE-U)
|
||||
|
||||
## 🌐 Browser Extension
|
||||
Enhance your Civitai browsing experience with our companion browser extension! See which models you already have, download new ones with a single click, and manage your downloads efficiently.
|
||||
|
||||

|
||||
|
||||
<div>
|
||||
<a href="https://chromewebstore.google.com/detail/lm-civitai-extension/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb" style="display: inline-block; background-color: #4285F4; color: white; padding: 8px 16px; text-decoration: none; border-radius: 4px; font-weight: bold; margin: 10px 0;">
|
||||
<img src="https://www.google.com/chrome/static/images/chrome-logo.svg" width="20" style="vertical-align: middle; margin-right: 8px;"> Get Extension from Chrome Web Store
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div id="firefox-install" class="install-ok"><a href="https://github.com/willmiao/lm-civitai-extension-firefox/releases/latest/download/extension.xpi">📦 Install Firefox Extension (reviewed and verified by Mozilla)</a></div>
|
||||
|
||||
📚 [Learn More: Complete Tutorial](https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/LoRA-Manager-Civitai-Extension-(Chrome-Extension))
|
||||
|
||||
---
|
||||
|
||||
## Release Notes
|
||||
|
||||
### v0.7.36
|
||||
* Enhanced LoRA details view with model descriptions and tags display
|
||||
* Added tag filtering system for improved model discovery
|
||||
* Implemented editable trigger words functionality
|
||||
* Improved TriggerWord Toggle node with new group mode option for granular control
|
||||
* Added new Lora Stacker node with cross-compatibility support (works with efficiency nodes, ComfyRoll, easy-use, etc.)
|
||||
* Fixed several bugs
|
||||
### v0.9.3
|
||||
* **Metadata Archive Database Support** - Added the ability to download and utilize a metadata archive database, enabling access to metadata for models that have been deleted from CivitAI.
|
||||
* **App-Level Proxy Settings** - Introduced support for configuring a global proxy within the application, making it easier to use the manager behind network restrictions.
|
||||
* **Bug Fixes** - Various bug fixes for improved stability and reliability.
|
||||
|
||||
### v0.7.35-beta
|
||||
* Added base model filtering
|
||||
* Implemented bulk operations (copy syntax, move multiple LoRAs)
|
||||
* Added ability to edit LoRA model names in details view
|
||||
* Added update checker with notification system
|
||||
* Added support modal for user feedback and community links
|
||||
### v0.9.2
|
||||
* **Bulk Auto-Organization Action** - Added a new bulk auto-organization feature. You can now select multiple models and automatically organize them according to your current path template settings for streamlined management.
|
||||
* **Bug Fixes** - Addressed several bugs to improve stability and reliability.
|
||||
|
||||
### v0.7.33
|
||||
* Enhanced LoRA Loader node with visual strength adjustment widgets
|
||||
* Added toggle switches for LoRA enable/disable
|
||||
* Implemented image tooltips for LoRA preview
|
||||
* Added TriggerWord Toggle node with visual word selection
|
||||
* Fixed various bugs and improved stability
|
||||
### v0.9.1
|
||||
* **Enhanced Bulk Operations** - Improved bulk operations with Marquee Selection and a bulk operation context menu, providing a more intuitive, desktop-application-like user experience.
|
||||
* **New Bulk Actions** - Added bulk operations for adding tags and setting base models to multiple models simultaneously.
|
||||
|
||||
### v0.7.3
|
||||
* Added "Lora Loader (LoraManager)" custom node for workflows
|
||||
* Implemented one-click LoRA integration
|
||||
* Added direct copying of LoRA syntax from manager interface
|
||||
* Added automatic preset strength value application
|
||||
* Added automatic trigger word loading
|
||||
### v0.9.0
|
||||
* **UI Overhaul for Enhanced Navigation** - Replaced the top flat folder tags with a new folder sidebar and breadcrumb navigation system for a more intuitive folder browsing and selection experience.
|
||||
* **Dual-Mode Folder Sidebar** - The new folder sidebar offers two display modes: 'List Mode,' which mirrors the classic folder view, and 'Tree Mode,' which presents a hierarchical folder structure for effortless navigation through nested directories.
|
||||
* **Internationalization Support** - Introduced multi-language support, now available in English, Simplified Chinese, Traditional Chinese, Spanish, Japanese, Korean, French, Russian, and German. Feedback from native speakers is welcome to improve the translations.
|
||||
* **Automatic Filename Conflict Resolution** - Implemented automatic file renaming (`original name + short hash`) to prevent conflicts when downloading or moving models.
|
||||
* **Performance Optimizations & Bug Fixes** - Various performance improvements and bug fixes for a more stable and responsive experience.
|
||||
|
||||
### v0.7.0
|
||||
* Added direct CivitAI integration for downloading LoRAs
|
||||
* Implemented version selection for model downloads
|
||||
* Added target folder selection for downloads
|
||||
* Added context menu with quick actions
|
||||
* Added force refresh for CivitAI data
|
||||
* Implemented LoRA movement between folders
|
||||
* Added personal usage tips and notes for LoRAs
|
||||
* Improved performance for details window
|
||||
### v0.8.30
|
||||
* **Automatic Model Path Correction** - Added auto-correction for model paths in built-in nodes such as Load Checkpoint, Load Diffusion Model, Load LoRA, and other custom nodes with similar functionality. Workflows containing outdated or incorrect model paths will now be automatically updated to reflect the current location of your models.
|
||||
* **Node UI Enhancements** - Improved node interface for a smoother and more intuitive user experience.
|
||||
* **Bug Fixes** - Addressed various bugs to enhance stability and reliability.
|
||||
|
||||
### v0.8.29
|
||||
* **Enhanced Recipe Imports** - Improved recipe importing with new target folder selection, featuring path input autocomplete and interactive folder tree navigation. Added a "Use Default Path" option when downloading missing LoRAs.
|
||||
* **WanVideo Lora Select Node Update** - Updated the WanVideo Lora Select node with a 'merge_loras' option to match the counterpart node in the WanVideoWrapper node package.
|
||||
* **Autocomplete Conflict Resolution** - Resolved an autocomplete feature conflict in LoRA nodes with pysssss autocomplete.
|
||||
* **Improved Download Functionality** - Enhanced download functionality with resumable downloads and improved error handling.
|
||||
* **Bug Fixes** - Addressed several bugs for improved stability and performance.
|
||||
|
||||
### v0.8.28
|
||||
* **Autocomplete for Node Inputs** - Instantly find and add LoRAs by filename directly in Lora Loader, Lora Stacker, and WanVideo Lora Select nodes. Autocomplete suggestions include preview tooltips and preset weights, allowing you to quickly select LoRAs without opening the LoRA Manager UI.
|
||||
* **Duplicate Notification Control** - Added a switch to duplicates mode, enabling users to turn off duplicate model notifications for a more streamlined experience.
|
||||
* **Download Example Images from Context Menu** - Introduced a new context menu option to download example images for individual models.
|
||||
|
||||
### v0.8.27
|
||||
* **User Experience Enhancements** - Improved the model download target folder selection with path input autocomplete and interactive folder tree navigation, making it easier and faster to choose where models are saved.
|
||||
* **Default Path Option for Downloads** - Added a "Use Default Path" option when downloading models. When enabled, models are automatically organized and stored according to your configured path template settings.
|
||||
* **Advanced Download Path Templates** - Expanded path template settings, allowing users to set individual templates for LoRA, checkpoint, and embedding models for greater flexibility. Introduced the `{author}` placeholder, enabling automatic organization of model files by creator name.
|
||||
* **Bug Fixes & Stability Improvements** - Addressed various bugs and improved overall stability for a smoother experience.
|
||||
|
||||
### v0.8.26
|
||||
* **Creator Search Option** - Added ability to search models by creator name, making it easier to find models from specific authors.
|
||||
* **Enhanced Node Usability** - Improved user experience for Lora Loader, Lora Stacker, and WanVideo Lora Select nodes by fixing the maximum height of the text input area. Users can now freely and conveniently adjust the LoRA region within these nodes.
|
||||
* **Compatibility Fixes** - Resolved compatibility issues with ComfyUI and certain custom nodes, including ComfyUI-Custom-Scripts, ensuring smoother integration and operation.
|
||||
|
||||
[View Update History](./update_logs.md)
|
||||
|
||||
@@ -69,13 +99,6 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
|
||||
- 🚀 **High Performance**
|
||||
- Fast model loading and browsing
|
||||
- Smooth scrolling through large collections
|
||||
- Real-time updates when files change
|
||||
|
||||
- 📂 **Advanced Organization**
|
||||
- Quick search with fuzzy matching
|
||||
- Folder-based categorization
|
||||
- Move LoRAs between folders
|
||||
- Sort by name or date
|
||||
|
||||
- 🌐 **Rich Model Integration**
|
||||
- Direct download from CivitAI
|
||||
@@ -84,29 +107,50 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
|
||||
- Trigger words at a glance
|
||||
- One-click workflow integration with preset values
|
||||
|
||||
- 🔄 **Checkpoint Management**
|
||||
- Scan and organize checkpoint models
|
||||
- Filter and search your collection
|
||||
- View and edit metadata
|
||||
- Clean up and manage disk space
|
||||
|
||||
- 🧩 **LoRA Recipes**
|
||||
- Save and share favorite LoRA combinations
|
||||
- Preserve generation parameters for future reference
|
||||
- Quick application to workflows
|
||||
- Import/export functionality for community sharing
|
||||
|
||||
- 💻 **User Friendly**
|
||||
- One-click access from ComfyUI menu
|
||||
- Context menu for quick actions
|
||||
- Custom notes and usage tips
|
||||
- Multi-folder support
|
||||
- Visual progress indicators during initialization
|
||||
|
||||
---
|
||||
|
||||
## Installation
|
||||
|
||||
### Option 1: **ComfyUI Manager** (Recommended)
|
||||
### Option 1: **ComfyUI Manager** (Recommended for ComfyUI users)
|
||||
|
||||
1. Open **ComfyUI**.
|
||||
2. Go to **Manager > Custom Node Manager**.
|
||||
3. Search for `lora-manager`.
|
||||
4. Click **Install**.
|
||||
|
||||
### Option 2: **Manual Installation**
|
||||
### Option 2: **Portable Standalone Edition** (No ComfyUI required)
|
||||
|
||||
1. Download the [Portable Package](https://github.com/willmiao/ComfyUI-Lora-Manager/releases/download/v0.9.2/lora_manager_portable.7z)
|
||||
2. Copy the provided `settings.json.example` file to create a new file named `settings.json` in `comfyui-lora-manager` folder
|
||||
3. Edit `settings.json` to include your correct model folder paths and CivitAI API key
|
||||
4. Run run.bat
|
||||
- To change the startup port, edit `run.bat` and modify the parameter (e.g. `--port 9001`)
|
||||
|
||||
### Option 3: **Manual Installation**
|
||||
|
||||
```bash
|
||||
git clone https://github.com/willmiao/ComfyUI-Lora-Manager.git
|
||||
cd ComfyUI-Lora-Manager
|
||||
pip install requirements.txt
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -127,11 +171,89 @@ pip install requirements.txt
|
||||
- Paste into the Lora Loader node's text input
|
||||
- The node will automatically apply preset strength and trigger words
|
||||
|
||||
### Filename Format Patterns for Save Image Node
|
||||
|
||||
The Save Image Node supports dynamic filename generation using pattern codes. You can customize how your images are named using the following format patterns:
|
||||
|
||||
#### Available Pattern Codes
|
||||
|
||||
- `%seed%` - Inserts the generation seed number
|
||||
- `%width%` - Inserts the image width
|
||||
- `%height%` - Inserts the image height
|
||||
- `%pprompt:N%` - Inserts the positive prompt (limited to N characters)
|
||||
- `%nprompt:N%` - Inserts the negative prompt (limited to N characters)
|
||||
- `%model:N%` - Inserts the model/checkpoint name (limited to N characters)
|
||||
- `%date%` - Inserts current date/time as "yyyyMMddhhmmss"
|
||||
- `%date:FORMAT%` - Inserts date using custom format with:
|
||||
- `yyyy` - 4-digit year
|
||||
- `yy` - 2-digit year
|
||||
- `MM` - 2-digit month
|
||||
- `dd` - 2-digit day
|
||||
- `hh` - 2-digit hour
|
||||
- `mm` - 2-digit minute
|
||||
- `ss` - 2-digit second
|
||||
|
||||
#### Examples
|
||||
|
||||
- `image_%seed%` → `image_1234567890`
|
||||
- `gen_%width%x%height%` → `gen_512x768`
|
||||
- `%model:10%_%seed%` → `dreamshape_1234567890`
|
||||
- `%date:yyyy-MM-dd%` → `2025-04-28`
|
||||
- `%pprompt:20%_%seed%` → `beautiful landscape_1234567890`
|
||||
- `%model%_%date:yyMMdd%_%seed%` → `dreamshaper_v8_250428_1234567890`
|
||||
|
||||
You can combine multiple patterns to create detailed, organized filenames for your generated images.
|
||||
|
||||
### Standalone Mode
|
||||
|
||||
You can now run LoRA Manager independently from ComfyUI:
|
||||
|
||||
1. **For ComfyUI users**:
|
||||
- Launch ComfyUI with LoRA Manager at least once to initialize the necessary path information in the `settings.json` file.
|
||||
- Make sure dependencies are installed: `pip install -r requirements.txt`
|
||||
- From your ComfyUI root directory, run:
|
||||
```bash
|
||||
python custom_nodes\comfyui-lora-manager\standalone.py
|
||||
```
|
||||
- Access the interface at: `http://localhost:8188/loras`
|
||||
- You can specify a different host or port with arguments:
|
||||
```bash
|
||||
python custom_nodes\comfyui-lora-manager\standalone.py --host 127.0.0.1 --port 9000
|
||||
```
|
||||
|
||||
2. **For non-ComfyUI users**:
|
||||
- Copy the provided `settings.json.example` file to create a new file named `settings.json`
|
||||
- Edit `settings.json` to include your correct model folder paths and CivitAI API key
|
||||
- Install required dependencies: `pip install -r requirements.txt`
|
||||
- Run standalone mode:
|
||||
```bash
|
||||
python standalone.py
|
||||
```
|
||||
- Access the interface through your browser at: `http://localhost:8188/loras`
|
||||
|
||||
This standalone mode provides a lightweight option for managing your model and recipe collection without needing to run the full ComfyUI environment, making it useful even for users who primarily use other stable diffusion interfaces.
|
||||
|
||||
---
|
||||
|
||||
## Contributing
|
||||
|
||||
If you have suggestions, bug reports, or improvements, feel free to open an issue or contribute directly to the codebase. Pull requests are always welcome!
|
||||
Thank you for your interest in contributing to ComfyUI LoRA Manager! As this project is currently in its early stages and undergoing rapid development and refactoring, we are temporarily not accepting pull requests.
|
||||
|
||||
However, your feedback and ideas are extremely valuable to us:
|
||||
- Please feel free to open issues for any bugs you encounter
|
||||
- Submit feature requests through GitHub issues
|
||||
- Share your suggestions for improvements
|
||||
|
||||
We appreciate your understanding and look forward to potentially accepting code contributions once the project architecture stabilizes.
|
||||
|
||||
---
|
||||
|
||||
## Credits
|
||||
|
||||
This project has been inspired by and benefited from other excellent ComfyUI extensions:
|
||||
|
||||
- [ComfyUI-SaveImageWithMetaData](https://github.com/nkchocoai/ComfyUI-SaveImageWithMetaData) - For the image metadata functionality
|
||||
- [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) - For the lora loader functionality
|
||||
|
||||
---
|
||||
|
||||
@@ -141,18 +263,16 @@ If you find this project helpful, consider supporting its development:
|
||||
|
||||
[](https://ko-fi.com/pixelpawsai)
|
||||
|
||||
[](https://patreon.com/PixelPawsAI)
|
||||
|
||||
WeChat: [Click to view QR code](https://raw.githubusercontent.com/willmiao/ComfyUI-Lora-Manager/main/static/images/wechat-qr.webp)
|
||||
|
||||
## 💬 Community
|
||||
|
||||
Join our Discord community for support, discussions, and updates:
|
||||
[Discord Server](https://discord.gg/vcqNrWVFvM)
|
||||
|
||||
---
|
||||
## Star History
|
||||
|
||||
## 🗺️ Roadmap
|
||||
|
||||
- ✅ One-click integration of LoRAs into ComfyUI workflows with preset strength values
|
||||
- 🤝 Improved usage tips retrieval from CivitAI model pages
|
||||
- 🔌 Integration with Power LoRA Loader and other management tools
|
||||
- 🛡️ Configurable NSFW level settings for content filtering
|
||||
|
||||
---
|
||||
[](https://star-history.com/#willmiao/ComfyUI-Lora-Manager&Date)
|
||||
|
||||
45
__init__.py
45
__init__.py
@@ -1,16 +1,49 @@
|
||||
from .py.lora_manager import LoraManager
|
||||
from .py.nodes.lora_loader import LoraManagerLoader
|
||||
from .py.nodes.trigger_word_toggle import TriggerWordToggle
|
||||
from .py.nodes.lora_stacker import LoraStacker
|
||||
try: # pragma: no cover - import fallback for pytest collection
|
||||
from .py.lora_manager import LoraManager
|
||||
from .py.nodes.lora_loader import LoraManagerLoader, LoraManagerTextLoader
|
||||
from .py.nodes.trigger_word_toggle import TriggerWordToggle
|
||||
from .py.nodes.lora_stacker import LoraStacker
|
||||
from .py.nodes.save_image import SaveImage
|
||||
from .py.nodes.debug_metadata import DebugMetadata
|
||||
from .py.nodes.wanvideo_lora_select import WanVideoLoraSelect
|
||||
from .py.nodes.wanvideo_lora_select_from_text import WanVideoLoraSelectFromText
|
||||
from .py.metadata_collector import init as init_metadata_collector
|
||||
except ImportError: # pragma: no cover - allows running under pytest without package install
|
||||
import importlib
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
package_root = pathlib.Path(__file__).resolve().parent
|
||||
if str(package_root) not in sys.path:
|
||||
sys.path.append(str(package_root))
|
||||
|
||||
LoraManager = importlib.import_module("py.lora_manager").LoraManager
|
||||
LoraManagerLoader = importlib.import_module("py.nodes.lora_loader").LoraManagerLoader
|
||||
LoraManagerTextLoader = importlib.import_module("py.nodes.lora_loader").LoraManagerTextLoader
|
||||
TriggerWordToggle = importlib.import_module("py.nodes.trigger_word_toggle").TriggerWordToggle
|
||||
LoraStacker = importlib.import_module("py.nodes.lora_stacker").LoraStacker
|
||||
SaveImage = importlib.import_module("py.nodes.save_image").SaveImage
|
||||
DebugMetadata = importlib.import_module("py.nodes.debug_metadata").DebugMetadata
|
||||
WanVideoLoraSelect = importlib.import_module("py.nodes.wanvideo_lora_select").WanVideoLoraSelect
|
||||
WanVideoLoraSelectFromText = importlib.import_module("py.nodes.wanvideo_lora_select_from_text").WanVideoLoraSelectFromText
|
||||
init_metadata_collector = importlib.import_module("py.metadata_collector").init
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
LoraManagerLoader.NAME: LoraManagerLoader,
|
||||
LoraManagerTextLoader.NAME: LoraManagerTextLoader,
|
||||
TriggerWordToggle.NAME: TriggerWordToggle,
|
||||
LoraStacker.NAME: LoraStacker
|
||||
LoraStacker.NAME: LoraStacker,
|
||||
SaveImage.NAME: SaveImage,
|
||||
DebugMetadata.NAME: DebugMetadata,
|
||||
WanVideoLoraSelect.NAME: WanVideoLoraSelect,
|
||||
WanVideoLoraSelectFromText.NAME: WanVideoLoraSelectFromText
|
||||
}
|
||||
|
||||
WEB_DIRECTORY = "./web/comfyui"
|
||||
|
||||
# Initialize metadata collector
|
||||
init_metadata_collector()
|
||||
|
||||
# Register routes on import
|
||||
LoraManager.add_routes()
|
||||
__all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY']
|
||||
__all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY']
|
||||
|
||||
180
docs/LM-Extension-Wiki.md
Normal file
180
docs/LM-Extension-Wiki.md
Normal file
@@ -0,0 +1,180 @@
|
||||
## Overview
|
||||
|
||||
The **LoRA Manager Civitai Extension** is a Browser extension designed to work seamlessly with [LoRA Manager](https://github.com/willmiao/ComfyUI-Lora-Manager) to significantly enhance your browsing experience on [Civitai](https://civitai.com).
|
||||
It also supports browsing on [CivArchive](https://civarchive.com/) (formerly CivitaiArchive).
|
||||
|
||||
With this extension, you can:
|
||||
|
||||
✅ Instantly see which models are already present in your local library
|
||||
✅ Download new models with a single click
|
||||
✅ Manage downloads efficiently with queue and parallel download support
|
||||
✅ Keep your downloaded models automatically organized according to your custom settings
|
||||
|
||||

|
||||

|
||||
|
||||
---
|
||||
|
||||
## Why Are All Features for Supporters Only?
|
||||
|
||||
I love building tools for the Stable Diffusion and ComfyUI communities, and LoRA Manager is a passion project that I've poured countless hours into. When I created this companion extension, my hope was to offer its core features for free, as a thank-you to all of you.
|
||||
|
||||
Unfortunately, I've reached a point where I need to be realistic. The level of support from the free model has been far lower than what's needed to justify the continuous development and maintenance for both projects. It was a difficult decision, but I've chosen to make the extension's features exclusive to supporters.
|
||||
|
||||
This change is crucial for me to be able to continue dedicating my time to improving the free and open-source LoRA Manager, which I'm committed to keeping available for everyone.
|
||||
|
||||
Your support does more than just unlock a few features—it allows me to keep innovating and ensures the core LoRA Manager project thrives. I'm incredibly grateful for your understanding and any support you can offer. ❤️
|
||||
|
||||
(_For those who previously supported me on Ko-fi with a one-time donation, I'll be sending out license keys individually as a thank-you._)
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Installation
|
||||
|
||||
### Supported Browsers & Installation Methods
|
||||
|
||||
| Browser | Installation Method |
|
||||
|--------------------|-------------------------------------------------------------------------------------|
|
||||
| **Google Chrome** | [Chrome Web Store link](https://chromewebstore.google.com/detail/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb) |
|
||||
| **Microsoft Edge** | Install via Chrome Web Store (compatible) |
|
||||
| **Brave Browser** | Install via Chrome Web Store (compatible) |
|
||||
| **Opera** | Install via Chrome Web Store (compatible) |
|
||||
| **Firefox** | <div id="firefox-install" class="install-ok"><a href="https://github.com/willmiao/lm-civitai-extension-firefox/releases/latest/download/extension.xpi">📦 Install Firefox Extension (reviewed and verified by Mozilla)</a></div> |
|
||||
|
||||
For non-Chrome browsers (e.g., Microsoft Edge), you can typically install extensions from the Chrome Web Store by following these steps: open the extension’s Chrome Web Store page, click 'Get extension', then click 'Allow' when prompted to enable installations from other stores, and finally click 'Add extension' to complete the installation.
|
||||
|
||||
---
|
||||
|
||||
## Privacy & Security
|
||||
|
||||
I understand concerns around browser extensions and privacy, and I want to be fully transparent about how the **LM Civitai Extension** works:
|
||||
|
||||
- **Reviewed and Verified**
|
||||
This extension has been **manually reviewed and approved by the Chrome Web Store**. The Firefox version uses the **exact same code** (only the packaging format differs) and has passed **Mozilla’s Add-on review**.
|
||||
|
||||
- **Minimal Network Access**
|
||||
The only external server this extension connects to is:
|
||||
**`https://willmiao.shop`** — used solely for **license validation**.
|
||||
|
||||
It does **not collect, transmit, or store any personal or usage data**.
|
||||
No browsing history, no user IDs, no analytics, no hidden trackers.
|
||||
|
||||
- **Local-Only Model Detection**
|
||||
Model detection and LoRA Manager communication all happen **locally** within your browser, directly interacting with your local LoRA Manager backend.
|
||||
|
||||
I value your trust and are committed to keeping your local setup private and secure. If you have any questions, feel free to reach out!
|
||||
|
||||
---
|
||||
|
||||
## How to Use
|
||||
|
||||
After installing the extension, you'll automatically receive a **7-day trial** to explore all features.
|
||||
|
||||
When the extension is correctly installed and your license is valid:
|
||||
|
||||
- Open **Civitai**, and you'll see visual indicators added by the extension on model cards, showing:
|
||||
- ✅ Models already present in your local library
|
||||
- ⬇️ A download button for models not in your library
|
||||
|
||||
Clicking the download button adds the corresponding model version to the download queue, waiting to be downloaded. You can set up to **5 models to download simultaneously**.
|
||||
|
||||
### Visual Indicators Appear On:
|
||||
|
||||
- **Home Page** — Featured models
|
||||
- **Models Page**
|
||||
- **Creator Profiles** — If the creator has set their models to be visible
|
||||
- **Recommended Resources** — On individual model pages
|
||||
|
||||
### Version Buttons on Model Pages
|
||||
|
||||
On a specific model page, visual indicators also appear on version buttons, showing which versions are already in your local library.
|
||||
|
||||
When switching to a specific version by clicking a version button:
|
||||
|
||||
- Clicking the download button will open a dropdown:
|
||||
- Download via **LoRA Manager**
|
||||
- Download via **Original Download** (browser download)
|
||||
|
||||
You can check **Remember my choice** to set your preferred default. You can change this setting anytime in the extension's settings.
|
||||
|
||||

|
||||
|
||||
### Resources on Image Pages (2025-08-05) — now shows in-library indicators for image resources. ‘Import image as recipe’ coming soon!
|
||||
|
||||

|
||||
|
||||
---
|
||||
|
||||
## Model Download Location & LoRA Manager Settings
|
||||
|
||||
To use the **one-click download function**, you must first set:
|
||||
|
||||
- Your **Default LoRAs Root**
|
||||
- Your **Default Checkpoints Root**
|
||||
|
||||
These are set within LoRA Manager's settings.
|
||||
|
||||
When everything is configured, downloaded model files will be placed in:
|
||||
|
||||
`<Default_Models_Root>/<Base_Model_of_the_Model>/<First_Tag_of_the_Model>`
|
||||
|
||||
|
||||
### Update: Default Path Customization (2025-07-21)
|
||||
|
||||
A new setting to customize the default download path has been added in the nightly version. You can now personalize where models are saved when downloading via the LM Civitai Extension.
|
||||
|
||||

|
||||
|
||||
The previous YAML path mapping file will be deprecated—settings will now be unified in settings.json to simplify configuration.
|
||||
|
||||
---
|
||||
|
||||
## Backend Port Configuration
|
||||
|
||||
If your **ComfyUI** or **LoRA Manager** backend is running on a port **other than the default 8188**, you must configure the backend port in the extension's settings.
|
||||
|
||||
After correctly setting and saving the port, you'll see in the extension's header area:
|
||||
- A **Healthy** status with the tooltip: `Connected to LoRA Manager on port xxxx`
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Connecting to a Remote LoRA Manager
|
||||
|
||||
If your LoRA Manager is running on another computer, you can still connect from your browser using port forwarding.
|
||||
|
||||
> **Why can't you set a remote IP directly?**
|
||||
>
|
||||
> For privacy and security, the extension only requests access to `http://127.0.0.1/*`. Supporting remote IPs would require much broader permissions, which may be rejected by browser stores and could raise user concerns.
|
||||
|
||||
**Solution: Port Forwarding with `socat`**
|
||||
|
||||
On your browser computer, run:
|
||||
|
||||
`socat TCP-LISTEN:8188,bind=127.0.0.1,fork TCP:REMOTE.IP.ADDRESS.HERE:8188`
|
||||
|
||||
- Replace `REMOTE.IP.ADDRESS.HERE` with the IP of the machine running LoRA Manager.
|
||||
- Adjust the port if needed.
|
||||
|
||||
This lets the extension connect to `127.0.0.1:8188` as usual, with traffic forwarded to your remote server.
|
||||
|
||||
_Thanks to user **Temikus** for sharing this solution!_
|
||||
|
||||
---
|
||||
|
||||
## Roadmap
|
||||
|
||||
The extension will evolve alongside **LoRA Manager** improvements. Planned features include:
|
||||
|
||||
- [x] Support for **additional model types** (e.g., embeddings)
|
||||
- [ ] One-click **Recipe Import**
|
||||
- [x] Display of in-library status for all resources in the **Resources Used** section of the image page
|
||||
- [x] One-click **Auto-organize Models**
|
||||
|
||||
**Stay tuned — and thank you for your support!**
|
||||
|
||||
---
|
||||
|
||||
93
docs/architecture/example_images_routes.md
Normal file
93
docs/architecture/example_images_routes.md
Normal file
@@ -0,0 +1,93 @@
|
||||
# Example image route architecture
|
||||
|
||||
The example image routing stack mirrors the layered model route stack described in
|
||||
[`docs/architecture/model_routes.md`](model_routes.md). HTTP wiring, controller setup,
|
||||
handler orchestration, and long-running workflows now live in clearly separated modules so
|
||||
we can extend download/import behaviour without touching the entire feature surface.
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph HTTP
|
||||
A[ExampleImagesRouteRegistrar] -->|binds| B[ExampleImagesRoutes controller]
|
||||
end
|
||||
subgraph Application
|
||||
B --> C[ExampleImagesHandlerSet]
|
||||
C --> D1[Handlers]
|
||||
D1 --> E1[Use cases]
|
||||
E1 --> F1[Download manager / processor / file manager]
|
||||
end
|
||||
subgraph Side Effects
|
||||
F1 --> G1[Filesystem]
|
||||
F1 --> G2[Model metadata]
|
||||
F1 --> G3[WebSocket progress]
|
||||
end
|
||||
```
|
||||
|
||||
## Layer responsibilities
|
||||
|
||||
| Layer | Module(s) | Responsibility |
|
||||
| --- | --- | --- |
|
||||
| Registrar | `py/routes/example_images_route_registrar.py` | Declarative catalogue of every example image endpoint plus helpers that bind them to an `aiohttp` router. Keeps HTTP concerns symmetrical with the model registrar. |
|
||||
| Controller | `py/routes/example_images_routes.py` | Lazily constructs `ExampleImagesHandlerSet`, injects defaults for the download manager, processor, and file manager, and exposes the registrar-ready mapping just like `BaseModelRoutes`. |
|
||||
| Handler set | `py/routes/handlers/example_images_handlers.py` | Groups HTTP adapters by concern (downloads, imports/deletes, filesystem access). Each handler translates domain errors into HTTP responses and defers to a use case or utility service. |
|
||||
| Use cases | `py/services/use_cases/example_images/*.py` | Encapsulate orchestration for downloads and imports. They validate input, translate concurrency/configuration errors, and keep handler logic declarative. |
|
||||
| Supporting services | `py/utils/example_images_download_manager.py`, `py/utils/example_images_processor.py`, `py/utils/example_images_file_manager.py` | Execute long-running work: pull assets from Civitai, persist uploads, clean metadata, expose filesystem actions with guardrails, and broadcast progress snapshots. |
|
||||
|
||||
## Handler responsibilities & invariants
|
||||
|
||||
`ExampleImagesHandlerSet` flattens the handler objects into the `{"handler_name": coroutine}`
|
||||
mapping consumed by the registrar. The table below outlines how each handler collaborates
|
||||
with the use cases and utilities.
|
||||
|
||||
| Handler | Key endpoints | Collaborators | Contracts |
|
||||
| --- | --- | --- | --- |
|
||||
| `ExampleImagesDownloadHandler` | `/api/lm/download-example-images`, `/api/lm/example-images-status`, `/api/lm/pause-example-images`, `/api/lm/resume-example-images`, `/api/lm/force-download-example-images` | `DownloadExampleImagesUseCase`, `DownloadManager` | Delegates payload validation and concurrency checks to the use case; progress/status endpoints expose the same snapshot used for WebSocket broadcasts; pause/resume surface `DownloadNotRunningError` as HTTP 400 instead of 500. |
|
||||
| `ExampleImagesManagementHandler` | `/api/lm/import-example-images`, `/api/lm/delete-example-image` | `ImportExampleImagesUseCase`, `ExampleImagesProcessor` | Multipart uploads are streamed to disk via the use case; validation failures return HTTP 400 with no filesystem side effects; deletion funnels through the processor to prune metadata and cached images consistently. |
|
||||
| `ExampleImagesFileHandler` | `/api/lm/open-example-images-folder`, `/api/lm/example-image-files`, `/api/lm/has-example-images` | `ExampleImagesFileManager` | Centralises filesystem access, enforcing settings-based root paths and returning HTTP 400/404 for missing configuration or folders; responses always include `success`/`has_images` booleans for UI consumption. |
|
||||
|
||||
## Use case boundaries
|
||||
|
||||
| Use case | Entry point | Dependencies | Guarantees |
|
||||
| --- | --- | --- | --- |
|
||||
| `DownloadExampleImagesUseCase` | `execute(payload)` | `DownloadManager.start_download`, download configuration errors | Raises `DownloadExampleImagesInProgressError` when the manager reports an active job, rewraps configuration errors into `DownloadExampleImagesConfigurationError`, and lets `ExampleImagesDownloadError` bubble as 500s so handlers do not duplicate logging. |
|
||||
| `ImportExampleImagesUseCase` | `execute(request)` | `ExampleImagesProcessor.import_images`, temporary file helpers | Supports multipart or JSON payloads, normalises file paths into a single list, cleans up temp files even on failure, and maps validation issues to `ImportExampleImagesValidationError` for HTTP 400 responses. |
|
||||
|
||||
## Maintaining critical invariants
|
||||
|
||||
* **Shared progress snapshots** - The download handler returns the same snapshot built by
|
||||
`DownloadManager`, guaranteeing parity between HTTP polling endpoints and WebSocket
|
||||
progress events.
|
||||
* **Safe filesystem access** - All folder/file actions flow through
|
||||
`ExampleImagesFileManager`, which validates the configured example image root and ensures
|
||||
responses never leak absolute paths outside the allowed directory.
|
||||
* **Metadata hygiene** - Import/delete operations run through `ExampleImagesProcessor`,
|
||||
which updates model metadata via `MetadataManager` and notifies the relevant scanners so
|
||||
cache state stays in sync.
|
||||
|
||||
## Migration notes
|
||||
|
||||
The refactor brings the example image stack in line with the model/recipe stacks:
|
||||
|
||||
1. `ExampleImagesRouteRegistrar` now owns the declarative route list. Downstream projects
|
||||
should rely on `ExampleImagesRoutes.to_route_mapping()` instead of manually wiring
|
||||
handler callables.
|
||||
2. `ExampleImagesRoutes` caches its `ExampleImagesHandlerSet` just like
|
||||
`BaseModelRoutes`. If you previously instantiated handlers directly, inject custom
|
||||
collaborators via the controller constructor (`download_manager`, `processor`,
|
||||
`file_manager`) to keep test seams predictable.
|
||||
3. Tests that mocked `ExampleImagesRoutes.setup_routes` should switch to patching
|
||||
`DownloadExampleImagesUseCase`/`ImportExampleImagesUseCase` at import time. The handlers
|
||||
expect those abstractions to surface validation/concurrency errors, and bypassing them
|
||||
will skip the HTTP-friendly error mapping.
|
||||
|
||||
## Extending the stack
|
||||
|
||||
1. Add the endpoint to `ROUTE_DEFINITIONS` with a unique `handler_name`.
|
||||
2. Expose the coroutine on an existing handler class (or create a new handler and extend
|
||||
`ExampleImagesHandlerSet`).
|
||||
3. Wire additional services or factories inside `_build_handler_set` on
|
||||
`ExampleImagesRoutes`, mirroring how the model stack introduces new use cases.
|
||||
|
||||
`tests/routes/test_example_images_routes.py` exercises registrar binding, download pause
|
||||
flows, and import validations. Use it as a template when introducing new handler
|
||||
collaborators or error mappings.
|
||||
100
docs/architecture/model_routes.md
Normal file
100
docs/architecture/model_routes.md
Normal file
@@ -0,0 +1,100 @@
|
||||
# Base model route architecture
|
||||
|
||||
The model routing stack now splits HTTP wiring, orchestration logic, and
|
||||
business rules into discrete layers. The goal is to make it obvious where a
|
||||
new collaborator should live and which contract it must honour. The diagram
|
||||
below captures the end-to-end flow for a typical request:
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph HTTP
|
||||
A[ModelRouteRegistrar] -->|binds| B[BaseModelRoutes handler proxy]
|
||||
end
|
||||
subgraph Application
|
||||
B --> C[ModelHandlerSet]
|
||||
C --> D1[Handlers]
|
||||
D1 --> E1[Use cases]
|
||||
E1 --> F1[Services / scanners]
|
||||
end
|
||||
subgraph Side Effects
|
||||
F1 --> G1[Cache & metadata]
|
||||
F1 --> G2[Filesystem]
|
||||
F1 --> G3[WebSocket state]
|
||||
end
|
||||
```
|
||||
|
||||
Every box maps to a concrete module:
|
||||
|
||||
| Layer | Module(s) | Responsibility |
|
||||
| --- | --- | --- |
|
||||
| Registrar | `py/routes/model_route_registrar.py` | Declarative list of routes shared by every model type and helper methods for binding them to an `aiohttp` application. |
|
||||
| Route controller | `py/routes/base_model_routes.py` | Constructs the handler graph, injects shared services, exposes proxies that surface `503 Service not ready` when the model service has not been attached. |
|
||||
| Handler set | `py/routes/handlers/model_handlers.py` | Thin HTTP adapters grouped by concern (page rendering, listings, mutations, queries, downloads, CivitAI integration, move operations, auto-organize). |
|
||||
| Use cases | `py/services/use_cases/*.py` | Encapsulate long-running flows (`DownloadModelUseCase`, `BulkMetadataRefreshUseCase`, `AutoOrganizeUseCase`). They normalise validation errors and concurrency constraints before returning control to the handlers. |
|
||||
| Services | `py/services/*.py` | Existing services and scanners that mutate caches, write metadata, move files, and broadcast WebSocket updates. |
|
||||
|
||||
## Handler responsibilities & contracts
|
||||
|
||||
`ModelHandlerSet` flattens the handler objects into the exact callables used by
|
||||
the registrar. The table below highlights the separation of concerns within
|
||||
the set and the invariants that must hold after each handler returns.
|
||||
|
||||
| Handler | Key endpoints | Collaborators | Contracts |
|
||||
| --- | --- | --- | --- |
|
||||
| `ModelPageView` | `/{prefix}` | `SettingsManager`, `server_i18n`, Jinja environment, `service.scanner` | Template is rendered with `is_initializing` flag when caches are cold; i18n filter is registered exactly once per environment instance. |
|
||||
| `ModelListingHandler` | `/api/lm/{prefix}/list` | `service.get_paginated_data`, `service.format_response` | Listings respect pagination query parameters and cap `page_size` at 100; every item is formatted before response. |
|
||||
| `ModelManagementHandler` | Mutations (delete, exclude, metadata, preview, tags, rename, bulk delete, duplicate verification) | `ModelLifecycleService`, `MetadataSyncService`, `PreviewAssetService`, `TagUpdateService`, scanner cache/index | Cache state mirrors filesystem changes: deletes prune cache & hash index, preview replacements synchronise metadata and cache NSFW levels, metadata saves trigger cache resort when names change. |
|
||||
| `ModelQueryHandler` | Read-only queries (top tags, folders, duplicates, metadata, URLs) | Service query helpers & scanner cache | Outputs always wrapped in `{"success": True}` when no error; duplicate/filename grouping omits empty entries; invalid parameters (e.g. missing `model_root`) return HTTP 400. |
|
||||
| `ModelDownloadHandler` | `/api/lm/download-model`, `/download-model-get`, `/download-progress/{id}`, `/cancel-download-get` | `DownloadModelUseCase`, `DownloadCoordinator`, `WebSocketManager` | Payload validation errors become HTTP 400 without mutating download progress cache; early-access failures surface as HTTP 401; successful downloads cache progress snapshots that back both WebSocket broadcasts and polling endpoints. |
|
||||
| `ModelCivitaiHandler` | CivitAI metadata routes | `MetadataSyncService`, metadata provider factory, `BulkMetadataRefreshUseCase` | `fetch_all_civitai` streams progress via `WebSocketBroadcastCallback`; version lookups validate model type before returning; local availability fields derive from hash lookups without mutating cache state. |
|
||||
| `ModelMoveHandler` | `move_model`, `move_models_bulk` | `ModelMoveService` | Moves execute atomically per request; bulk operations aggregate success/failure per file set. |
|
||||
| `ModelAutoOrganizeHandler` | `/api/lm/{prefix}/auto-organize` (GET/POST), `/auto-organize-progress` | `AutoOrganizeUseCase`, `WebSocketProgressCallback`, `WebSocketManager` | Enforces single-flight execution using the shared lock; progress broadcasts remain available to polling clients until explicitly cleared; conflicts return HTTP 409 with a descriptive error. |
|
||||
|
||||
## Use case boundaries
|
||||
|
||||
Each use case exposes a narrow asynchronous API that hides the underlying
|
||||
services. Their error mapping is essential for predictable HTTP responses.
|
||||
|
||||
| Use case | Entry point | Dependencies | Guarantees |
|
||||
| --- | --- | --- | --- |
|
||||
| `DownloadModelUseCase` | `execute(payload)` | `DownloadCoordinator.schedule_download` | Translates `ValueError` into `DownloadModelValidationError` for HTTP 400, recognises early-access errors (`"401"` in message) and surfaces them as `DownloadModelEarlyAccessError`, forwards success dictionaries untouched. |
|
||||
| `AutoOrganizeUseCase` | `execute(file_paths, progress_callback)` | `ModelFileService.auto_organize_models`, `WebSocketManager` lock | Guarded by `ws_manager` lock + status checks; raises `AutoOrganizeInProgressError` before invoking the file service when another run is already active. |
|
||||
| `BulkMetadataRefreshUseCase` | `execute_with_error_handling(progress_callback)` | `MetadataSyncService`, `SettingsManager`, `WebSocketBroadcastCallback` | Iterates through cached models, applies metadata sync, emits progress snapshots that handlers broadcast unchanged. |
|
||||
|
||||
## Maintaining legacy contracts
|
||||
|
||||
The refactor preserves the invariants called out in the previous architecture
|
||||
notes. The most critical ones are reiterated here to emphasise the
|
||||
collaboration points:
|
||||
|
||||
1. **Cache mutations** – Delete, exclude, rename, and bulk delete operations are
|
||||
channelled through `ModelManagementHandler`. The handler delegates to
|
||||
`ModelLifecycleService` or `MetadataSyncService`, and the scanner cache is
|
||||
mutated in-place before the handler returns. The accompanying tests assert
|
||||
that `scanner._cache.raw_data` and `scanner._hash_index` stay in sync after
|
||||
each mutation.
|
||||
2. **Preview updates** – `PreviewAssetService.replace_preview` writes the new
|
||||
asset, `MetadataSyncService` persists the JSON metadata, and
|
||||
`scanner.update_preview_in_cache` mirrors the change. The handler returns
|
||||
the static URL produced by `config.get_preview_static_url`, keeping browser
|
||||
clients in lockstep with disk state.
|
||||
3. **Download progress** – `DownloadCoordinator.schedule_download` generates the
|
||||
download identifier, registers a WebSocket progress callback, and caches the
|
||||
latest numeric progress via `WebSocketManager`. Both `download_model`
|
||||
responses and `/download-progress/{id}` polling read from the same cache to
|
||||
guarantee consistent progress reporting across transports.
|
||||
|
||||
## Extending the stack
|
||||
|
||||
To add a new shared route:
|
||||
|
||||
1. Declare it in `COMMON_ROUTE_DEFINITIONS` using a unique handler name.
|
||||
2. Implement the corresponding coroutine on one of the handlers inside
|
||||
`ModelHandlerSet` (or introduce a new handler class when the concern does not
|
||||
fit existing ones).
|
||||
3. Inject additional dependencies in `BaseModelRoutes._create_handler_set` by
|
||||
wiring services or use cases through the constructor parameters.
|
||||
|
||||
Model-specific routes should continue to be registered inside the subclass
|
||||
implementation of `setup_specific_routes`, reusing the shared registrar where
|
||||
possible.
|
||||
89
docs/architecture/recipe_routes.md
Normal file
89
docs/architecture/recipe_routes.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# Recipe route architecture
|
||||
|
||||
The recipe routing stack now mirrors the modular model route design. HTTP
|
||||
bindings, controller wiring, handler orchestration, and business rules live in
|
||||
separate layers so new behaviours can be added without re-threading the entire
|
||||
feature. The diagram below outlines the flow for a typical request:
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph HTTP
|
||||
A[RecipeRouteRegistrar] -->|binds| B[RecipeRoutes controller]
|
||||
end
|
||||
subgraph Application
|
||||
B --> C[RecipeHandlerSet]
|
||||
C --> D1[Handlers]
|
||||
D1 --> E1[Use cases]
|
||||
E1 --> F1[Services / scanners]
|
||||
end
|
||||
subgraph Side Effects
|
||||
F1 --> G1[Cache & fingerprint index]
|
||||
F1 --> G2[Metadata files]
|
||||
F1 --> G3[Temporary shares]
|
||||
end
|
||||
```
|
||||
|
||||
## Layer responsibilities
|
||||
|
||||
| Layer | Module(s) | Responsibility |
|
||||
| --- | --- | --- |
|
||||
| Registrar | `py/routes/recipe_route_registrar.py` | Declarative list of every recipe endpoint and helper methods that bind them to an `aiohttp` application. |
|
||||
| Controller | `py/routes/base_recipe_routes.py`, `py/routes/recipe_routes.py` | Lazily resolves scanners/clients from the service registry, wires shared templates/i18n, instantiates `RecipeHandlerSet`, and exposes a `{handler_name: coroutine}` mapping for the registrar. |
|
||||
| Handler set | `py/routes/handlers/recipe_handlers.py` | Thin HTTP adapters grouped by concern (page view, listings, queries, mutations, sharing). They normalise responses and translate service exceptions into HTTP status codes. |
|
||||
| Services & scanners | `py/services/recipes/*.py`, `py/services/recipe_scanner.py`, `py/services/service_registry.py` | Concrete business logic: metadata parsing, persistence, sharing, fingerprint/index maintenance, and cache refresh. |
|
||||
|
||||
## Handler responsibilities & invariants
|
||||
|
||||
`RecipeHandlerSet` flattens purpose-built handler objects into the callables the
|
||||
registrar binds. Each handler is responsible for a narrow concern and enforces a
|
||||
set of invariants before returning:
|
||||
|
||||
| Handler | Key endpoints | Collaborators | Contracts |
|
||||
| --- | --- | --- | --- |
|
||||
| `RecipePageView` | `/loras/recipes` | `SettingsManager`, `server_i18n`, Jinja environment, recipe scanner getter | Template rendered with `is_initializing` flag when caches are still warming; i18n filter registered exactly once per environment instance. |
|
||||
| `RecipeListingHandler` | `/api/lm/recipes`, `/api/lm/recipe/{id}` | `recipe_scanner.get_paginated_data`, `recipe_scanner.get_recipe_by_id` | Listings respect pagination and search filters; every item receives a `file_url` fallback even when metadata is incomplete; missing recipes become HTTP 404. |
|
||||
| `RecipeQueryHandler` | Tag/base-model stats, syntax, LoRA lookups | Recipe scanner cache, `format_recipe_file_url` helper | Cache snapshots are reused without forcing refresh; duplicate lookups collapse groups by fingerprint; syntax lookups return helpful errors when LoRAs are absent. |
|
||||
| `RecipeManagementHandler` | Save, update, reconnect, bulk delete, widget ingest | `RecipePersistenceService`, `RecipeAnalysisService`, recipe scanner | Persistence results propagate HTTP status codes; fingerprint/index updates flow through the scanner before returning; validation errors surface as HTTP 400 without touching disk. |
|
||||
| `RecipeAnalysisHandler` | Uploaded/local/remote analysis | `RecipeAnalysisService`, `civitai_client`, recipe scanner | Unsupported content types map to HTTP 400; download errors (`RecipeDownloadError`) are not retried; every response includes a `loras` array for client compatibility. |
|
||||
| `RecipeSharingHandler` | Share + download | `RecipeSharingService`, recipe scanner | Share responses provide a stable download URL and filename; expired shares surface as HTTP 404; downloads stream via `web.FileResponse` with attachment headers. |
|
||||
|
||||
## Use case boundaries
|
||||
|
||||
The dedicated services encapsulate long-running work so handlers stay thin.
|
||||
|
||||
| Use case | Entry point | Dependencies | Guarantees |
|
||||
| --- | --- | --- | --- |
|
||||
| `RecipeAnalysisService` | `analyze_uploaded_image`, `analyze_remote_image`, `analyze_local_image`, `analyze_widget_metadata` | `ExifUtils`, `RecipeParserFactory`, downloader factory, optional metadata collector/processor | Normalises missing/invalid payloads into `RecipeValidationError`; generates consistent fingerprint data to keep duplicate detection stable; temporary files are cleaned up after every analysis path. |
|
||||
| `RecipePersistenceService` | `save_recipe`, `delete_recipe`, `update_recipe`, `reconnect_lora`, `bulk_delete`, `save_recipe_from_widget` | `ExifUtils`, recipe scanner, card preview sizing constants | Writes images/JSON metadata atomically; updates scanner caches and hash indices before returning; recalculates fingerprints whenever LoRA assignments change. |
|
||||
| `RecipeSharingService` | `share_recipe`, `prepare_download` | `tempfile`, recipe scanner | Copies originals to TTL-managed temp files; metadata lookups re-use the scanner; expired shares trigger cleanup and `RecipeNotFoundError`. |
|
||||
|
||||
## Maintaining critical invariants
|
||||
|
||||
* **Cache updates** – Mutations (`save`, `delete`, `bulk_delete`, `update`) call
|
||||
back into the recipe scanner to mutate the in-memory cache and fingerprint
|
||||
index before returning a response. Tests assert that these methods are invoked
|
||||
even when stubbing persistence.
|
||||
* **Fingerprint management** – `RecipePersistenceService` recomputes
|
||||
fingerprints whenever LoRA metadata changes and duplicate lookups use those
|
||||
fingerprints to group recipes. Handlers bubble the resulting IDs so clients
|
||||
can merge duplicates without an extra fetch.
|
||||
* **Metadata synchronisation** – Saving or reconnecting a recipe updates the
|
||||
JSON sidecar, refreshes embedded metadata via `ExifUtils`, and instructs the
|
||||
scanner to resort its cache. Sharing relies on this metadata to generate
|
||||
filenames and ensure downloads stay in sync with on-disk state.
|
||||
|
||||
## Extending the stack
|
||||
|
||||
1. Declare the new endpoint in `ROUTE_DEFINITIONS` with a unique handler name.
|
||||
2. Implement the coroutine on an existing handler or introduce a new handler
|
||||
class inside `py/routes/handlers/recipe_handlers.py` when the concern does
|
||||
not fit existing ones.
|
||||
3. Wire additional collaborators inside
|
||||
`BaseRecipeRoutes._create_handler_set` (inject new services or factories) and
|
||||
expose helper getters on the handler owner if the handler needs to share
|
||||
utilities.
|
||||
|
||||
Integration tests in `tests/routes/test_recipe_routes.py` exercise the listing,
|
||||
mutation, analysis-error, and sharing paths end-to-end, ensuring the controller
|
||||
and handler wiring remains valid as new capabilities are added.
|
||||
|
||||
23
docs/frontend-testing-roadmap.md
Normal file
23
docs/frontend-testing-roadmap.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# Frontend Automation Testing Roadmap
|
||||
|
||||
This roadmap tracks the planned rollout of automated testing for the ComfyUI LoRA Manager frontend. Each phase builds on the infrastructure introduced in this change set and records progress so future contributors can quickly identify the next tasks.
|
||||
|
||||
## Phase Overview
|
||||
|
||||
| Phase | Goal | Primary Focus | Status | Notes |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| Phase 0 | Establish baseline tooling | Add Node test runner, jsdom environment, and seed smoke tests | ✅ Complete | Vitest + jsdom configured, example state tests committed |
|
||||
| Phase 1 | Cover state management logic | Unit test selectors, derived data helpers, and storage utilities under `static/js/state` and `static/js/utils` | ✅ Complete | Storage helpers and state selectors now exercised via deterministic suites |
|
||||
| Phase 2 | Test AppCore orchestration | Simulate page bootstrapping, infinite scroll hooks, and manager registration using JSDOM DOM fixtures | 🟡 In Progress | AppCore initialization specs landed; expand to additional page wiring and scroll hooks |
|
||||
| Phase 3 | Validate page-specific managers | Add focused suites for `loras`, `checkpoints`, `embeddings`, and `recipes` managers covering filtering, sorting, and bulk actions | ⚪ Not Started | Consider shared helpers for mocking API modules and storage |
|
||||
| Phase 4 | Interaction-level regression tests | Exercise template fragments, modals, and menus to ensure UI wiring remains intact | ⚪ Not Started | Evaluate Playwright component testing or happy-path DOM snapshots |
|
||||
| Phase 5 | Continuous integration & coverage | Integrate frontend tests into CI workflow and track coverage metrics | ⚪ Not Started | Align reporting directories with backend coverage for unified reporting |
|
||||
|
||||
## Next Steps Checklist
|
||||
|
||||
- [x] Expand unit tests for `storageHelpers` covering migrations and namespace behavior.
|
||||
- [ ] Document DOM fixture strategy for reproducing template structures in tests.
|
||||
- [x] Prototype AppCore initialization test that verifies manager bootstrapping with stubbed dependencies.
|
||||
- [ ] Evaluate integrating coverage reporting once test surface grows (> 20 specs).
|
||||
|
||||
Maintaining this roadmap alongside code changes will make it easier to append new automated test tasks and update their progress.
|
||||
BIN
example_workflows/Flux Example.jpg
Normal file
BIN
example_workflows/Flux Example.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 669 KiB |
1
example_workflows/Flux Example.json
Normal file
1
example_workflows/Flux Example.json
Normal file
File diff suppressed because one or more lines are too long
BIN
example_workflows/Illustrious Pony Example.jpg
Normal file
BIN
example_workflows/Illustrious Pony Example.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 669 KiB |
1
example_workflows/Illustrious Pony Example.json
Normal file
1
example_workflows/Illustrious Pony Example.json
Normal file
File diff suppressed because one or more lines are too long
BIN
example_workflows/nunchaku-flux.1-dev.jpg
Normal file
BIN
example_workflows/nunchaku-flux.1-dev.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 68 KiB |
1
example_workflows/nunchaku-flux.1-dev.json
Normal file
1
example_workflows/nunchaku-flux.1-dev.json
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1242
locales/de.json
Normal file
1242
locales/de.json
Normal file
File diff suppressed because it is too large
Load Diff
1242
locales/en.json
Normal file
1242
locales/en.json
Normal file
File diff suppressed because it is too large
Load Diff
1242
locales/es.json
Normal file
1242
locales/es.json
Normal file
File diff suppressed because it is too large
Load Diff
1242
locales/fr.json
Normal file
1242
locales/fr.json
Normal file
File diff suppressed because it is too large
Load Diff
1242
locales/ja.json
Normal file
1242
locales/ja.json
Normal file
File diff suppressed because it is too large
Load Diff
1242
locales/ko.json
Normal file
1242
locales/ko.json
Normal file
File diff suppressed because it is too large
Load Diff
1242
locales/ru.json
Normal file
1242
locales/ru.json
Normal file
File diff suppressed because it is too large
Load Diff
1242
locales/zh-CN.json
Normal file
1242
locales/zh-CN.json
Normal file
File diff suppressed because it is too large
Load Diff
1242
locales/zh-TW.json
Normal file
1242
locales/zh-TW.json
Normal file
File diff suppressed because it is too large
Load Diff
2572
package-lock.json
generated
Normal file
2572
package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
14
package.json
Normal file
14
package.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"name": "comfyui-lora-manager-frontend",
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest"
|
||||
},
|
||||
"devDependencies": {
|
||||
"jsdom": "^24.0.0",
|
||||
"vitest": "^1.6.0"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
"""Project namespace package."""
|
||||
|
||||
# pytest's internal compatibility layer still imports ``py.path.local`` from the
|
||||
# historical ``py`` dependency. Because this project reuses the ``py`` package
|
||||
# name, we expose a minimal shim so ``py.path.local`` resolves to ``pathlib.Path``
|
||||
# during test runs without pulling in the external dependency.
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
path = SimpleNamespace(local=Path)
|
||||
|
||||
__all__ = ["path"]
|
||||
|
||||
243
py/config.py
243
py/config.py
@@ -3,6 +3,11 @@ import platform
|
||||
import folder_paths # type: ignore
|
||||
from typing import List
|
||||
import logging
|
||||
import json
|
||||
import urllib.parse
|
||||
|
||||
# Use an environment variable to control standalone mode
|
||||
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -12,13 +17,60 @@ class Config:
|
||||
def __init__(self):
|
||||
self.templates_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'templates')
|
||||
self.static_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static')
|
||||
# 路径映射字典, target to link mapping
|
||||
self.i18n_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'locales')
|
||||
# Path mapping dictionary, target to link mapping
|
||||
self._path_mappings = {}
|
||||
# 静态路由映射字典, target to route mapping
|
||||
# Static route mapping dictionary, target to route mapping
|
||||
self._route_mappings = {}
|
||||
self.loras_roots = self._init_lora_paths()
|
||||
# 在初始化时扫描符号链接
|
||||
self.checkpoints_roots = None
|
||||
self.unet_roots = None
|
||||
self.embeddings_roots = None
|
||||
self.base_models_roots = self._init_checkpoint_paths()
|
||||
self.embeddings_roots = self._init_embedding_paths()
|
||||
# Scan symbolic links during initialization
|
||||
self._scan_symbolic_links()
|
||||
|
||||
if not standalone_mode:
|
||||
# Save the paths to settings.json when running in ComfyUI mode
|
||||
self.save_folder_paths_to_settings()
|
||||
|
||||
def save_folder_paths_to_settings(self):
|
||||
"""Save folder paths to settings.json for standalone mode to use later"""
|
||||
try:
|
||||
# Check if we're running in ComfyUI mode (not standalone)
|
||||
# Load existing settings
|
||||
settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
|
||||
settings = {}
|
||||
if os.path.exists(settings_path):
|
||||
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||
settings = json.load(f)
|
||||
|
||||
# Update settings with paths
|
||||
settings['folder_paths'] = {
|
||||
'loras': self.loras_roots,
|
||||
'checkpoints': self.checkpoints_roots,
|
||||
'unet': self.unet_roots,
|
||||
'embeddings': self.embeddings_roots,
|
||||
}
|
||||
|
||||
# Add default roots if there's only one item and key doesn't exist
|
||||
if len(self.loras_roots) == 1 and "default_lora_root" not in settings:
|
||||
settings["default_lora_root"] = self.loras_roots[0]
|
||||
|
||||
if self.checkpoints_roots and len(self.checkpoints_roots) == 1 and "default_checkpoint_root" not in settings:
|
||||
settings["default_checkpoint_root"] = self.checkpoints_roots[0]
|
||||
|
||||
if self.embeddings_roots and len(self.embeddings_roots) == 1 and "default_embedding_root" not in settings:
|
||||
settings["default_embedding_root"] = self.embeddings_roots[0]
|
||||
|
||||
# Save settings
|
||||
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(settings, f, indent=2)
|
||||
|
||||
logger.info("Saved folder paths to settings.json")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save folder paths: {e}")
|
||||
|
||||
def _is_link(self, path: str) -> bool:
|
||||
try:
|
||||
@@ -38,12 +90,18 @@ class Config:
|
||||
return False
|
||||
|
||||
def _scan_symbolic_links(self):
|
||||
"""扫描所有 LoRA 根目录中的符号链接"""
|
||||
"""Scan all symbolic links in LoRA, Checkpoint, and Embedding root directories"""
|
||||
for root in self.loras_roots:
|
||||
self._scan_directory_links(root)
|
||||
|
||||
for root in self.base_models_roots:
|
||||
self._scan_directory_links(root)
|
||||
|
||||
for root in self.embeddings_roots:
|
||||
self._scan_directory_links(root)
|
||||
|
||||
def _scan_directory_links(self, root: str):
|
||||
"""递归扫描目录中的符号链接"""
|
||||
"""Recursively scan symbolic links in a directory"""
|
||||
try:
|
||||
with os.scandir(root) as it:
|
||||
for entry in it:
|
||||
@@ -58,63 +116,178 @@ class Config:
|
||||
logger.error(f"Error scanning links in {root}: {e}")
|
||||
|
||||
def add_path_mapping(self, link_path: str, target_path: str):
|
||||
"""添加符号链接路径映射
|
||||
target_path: 实际目标路径
|
||||
link_path: 符号链接路径
|
||||
"""Add a symbolic link path mapping
|
||||
target_path: actual target path
|
||||
link_path: symbolic link path
|
||||
"""
|
||||
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
|
||||
normalized_target = os.path.normpath(target_path).replace(os.sep, '/')
|
||||
# 保持原有的映射关系:目标路径 -> 链接路径
|
||||
# Keep the original mapping: target path -> link path
|
||||
self._path_mappings[normalized_target] = normalized_link
|
||||
logger.info(f"Added path mapping: {normalized_target} -> {normalized_link}")
|
||||
|
||||
def add_route_mapping(self, path: str, route: str):
|
||||
"""添加静态路由映射"""
|
||||
"""Add a static route mapping"""
|
||||
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
||||
self._route_mappings[normalized_path] = route
|
||||
logger.info(f"Added route mapping: {normalized_path} -> {route}")
|
||||
# logger.info(f"Added route mapping: {normalized_path} -> {route}")
|
||||
|
||||
def map_path_to_link(self, path: str) -> str:
|
||||
"""将目标路径映射回符号链接路径"""
|
||||
"""Map a target path back to its symbolic link path"""
|
||||
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
||||
# 检查路径是否包含在任何映射的目标路径中
|
||||
# Check if the path is contained in any mapped target path
|
||||
for target_path, link_path in self._path_mappings.items():
|
||||
if normalized_path.startswith(target_path):
|
||||
# 如果路径以目标路径开头,则替换为链接路径
|
||||
# If the path starts with the target path, replace with link path
|
||||
mapped_path = normalized_path.replace(target_path, link_path, 1)
|
||||
return mapped_path
|
||||
return path
|
||||
|
||||
def map_link_to_path(self, link_path: str) -> str:
|
||||
"""Map a symbolic link path back to the actual path"""
|
||||
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
|
||||
# Check if the path is contained in any mapped target path
|
||||
for target_path, link_path in self._path_mappings.items():
|
||||
if normalized_link.startswith(target_path):
|
||||
# If the path starts with the target path, replace with actual path
|
||||
mapped_path = normalized_link.replace(target_path, link_path, 1)
|
||||
return mapped_path
|
||||
return link_path
|
||||
|
||||
def _init_lora_paths(self) -> List[str]:
|
||||
"""Initialize and validate LoRA paths from ComfyUI settings"""
|
||||
paths = list(set(path.replace(os.sep, "/")
|
||||
for path in folder_paths.get_folder_paths("loras")
|
||||
if os.path.exists(path)))
|
||||
print("Found LoRA roots:", "\n - " + "\n - ".join(paths))
|
||||
|
||||
if not paths:
|
||||
raise ValueError("No valid loras folders found in ComfyUI configuration")
|
||||
|
||||
# 初始化路径映射
|
||||
for path in paths:
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
if real_path != path:
|
||||
self.add_path_mapping(path, real_path)
|
||||
|
||||
return paths
|
||||
try:
|
||||
raw_paths = folder_paths.get_folder_paths("loras")
|
||||
|
||||
# Normalize and resolve symlinks, store mapping from resolved -> original
|
||||
path_map = {}
|
||||
for path in raw_paths:
|
||||
if os.path.exists(path):
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
path_map[real_path] = path_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
|
||||
|
||||
# Now sort and use only the deduplicated real paths
|
||||
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
|
||||
logger.info("Found LoRA roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
|
||||
|
||||
if not unique_paths:
|
||||
logger.warning("No valid loras folders found in ComfyUI configuration")
|
||||
return []
|
||||
|
||||
for original_path in unique_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return unique_paths
|
||||
except Exception as e:
|
||||
logger.warning(f"Error initializing LoRA paths: {e}")
|
||||
return []
|
||||
|
||||
def _init_checkpoint_paths(self) -> List[str]:
|
||||
"""Initialize and validate checkpoint paths from ComfyUI settings"""
|
||||
try:
|
||||
# Get checkpoint paths from folder_paths
|
||||
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
raw_unet_paths = folder_paths.get_folder_paths("unet")
|
||||
|
||||
# Normalize and resolve symlinks for checkpoints, store mapping from resolved -> original
|
||||
checkpoint_map = {}
|
||||
for path in raw_checkpoint_paths:
|
||||
if os.path.exists(path):
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
checkpoint_map[real_path] = checkpoint_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
|
||||
|
||||
# Normalize and resolve symlinks for unet, store mapping from resolved -> original
|
||||
unet_map = {}
|
||||
for path in raw_unet_paths:
|
||||
if os.path.exists(path):
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
unet_map[real_path] = unet_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
|
||||
|
||||
# Merge both maps and deduplicate by real path
|
||||
merged_map = {}
|
||||
for real_path, orig_path in {**checkpoint_map, **unet_map}.items():
|
||||
if real_path not in merged_map:
|
||||
merged_map[real_path] = orig_path
|
||||
|
||||
# Now sort and use only the deduplicated real paths
|
||||
unique_paths = sorted(merged_map.values(), key=lambda p: p.lower())
|
||||
|
||||
# Split back into checkpoints and unet roots for class properties
|
||||
self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_map.values()]
|
||||
self.unet_roots = [p for p in unique_paths if p in unet_map.values()]
|
||||
|
||||
all_paths = unique_paths
|
||||
|
||||
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(all_paths) if all_paths else "[]"))
|
||||
|
||||
if not all_paths:
|
||||
logger.warning("No valid checkpoint folders found in ComfyUI configuration")
|
||||
return []
|
||||
|
||||
# Initialize path mappings
|
||||
for original_path in all_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return all_paths
|
||||
except Exception as e:
|
||||
logger.warning(f"Error initializing checkpoint paths: {e}")
|
||||
return []
|
||||
|
||||
def _init_embedding_paths(self) -> List[str]:
|
||||
"""Initialize and validate embedding paths from ComfyUI settings"""
|
||||
try:
|
||||
raw_paths = folder_paths.get_folder_paths("embeddings")
|
||||
|
||||
# Normalize and resolve symlinks, store mapping from resolved -> original
|
||||
path_map = {}
|
||||
for path in raw_paths:
|
||||
if os.path.exists(path):
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
path_map[real_path] = path_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
|
||||
|
||||
# Now sort and use only the deduplicated real paths
|
||||
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
|
||||
logger.info("Found embedding roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
|
||||
|
||||
if not unique_paths:
|
||||
logger.warning("No valid embeddings folders found in ComfyUI configuration")
|
||||
return []
|
||||
|
||||
for original_path in unique_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return unique_paths
|
||||
except Exception as e:
|
||||
logger.warning(f"Error initializing embedding paths: {e}")
|
||||
return []
|
||||
|
||||
def get_preview_static_url(self, preview_path: str) -> str:
|
||||
"""Convert local preview path to static URL"""
|
||||
if not preview_path:
|
||||
return ""
|
||||
|
||||
real_path = os.path.realpath(preview_path).replace(os.sep, '/')
|
||||
|
||||
|
||||
# Find longest matching path (most specific match)
|
||||
best_match = ""
|
||||
best_route = ""
|
||||
|
||||
for path, route in self._route_mappings.items():
|
||||
if real_path.startswith(path):
|
||||
relative_path = os.path.relpath(real_path, path)
|
||||
return f'{route}/{relative_path.replace(os.sep, "/")}'
|
||||
|
||||
if real_path.startswith(path) and len(path) > len(best_match):
|
||||
best_match = path
|
||||
best_route = route
|
||||
|
||||
if best_match:
|
||||
relative_path = os.path.relpath(real_path, best_match).replace(os.sep, '/')
|
||||
safe_parts = [urllib.parse.quote(part) for part in relative_path.split('/')]
|
||||
safe_path = '/'.join(safe_parts)
|
||||
return f'{best_route}/{safe_path}'
|
||||
|
||||
return ""
|
||||
|
||||
# Global config instance
|
||||
|
||||
@@ -1,25 +1,63 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
from server import PromptServer # type: ignore
|
||||
from .config import config
|
||||
from .routes.lora_routes import LoraRoutes
|
||||
from .routes.api_routes import ApiRoutes
|
||||
from .services.lora_scanner import LoraScanner
|
||||
from .services.file_monitor import LoraFileMonitor
|
||||
from .services.lora_cache import LoraCache
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from server import PromptServer # type: ignore
|
||||
|
||||
from .config import config
|
||||
from .services.model_service_factory import ModelServiceFactory, register_default_model_types
|
||||
from .routes.recipe_routes import RecipeRoutes
|
||||
from .routes.stats_routes import StatsRoutes
|
||||
from .routes.update_routes import UpdateRoutes
|
||||
from .routes.misc_routes import MiscRoutes
|
||||
from .routes.example_images_routes import ExampleImagesRoutes
|
||||
from .services.service_registry import ServiceRegistry
|
||||
from .services.settings_manager import settings
|
||||
from .utils.example_images_migration import ExampleImagesMigration
|
||||
from .services.websocket_manager import ws_manager
|
||||
from .services.example_images_cleanup_service import ExampleImagesCleanupService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Check if we're in standalone mode
|
||||
STANDALONE_MODE = 'nodes' not in sys.modules
|
||||
|
||||
class LoraManager:
|
||||
"""Main entry point for LoRA Manager plugin"""
|
||||
|
||||
@classmethod
|
||||
def add_routes(cls):
|
||||
"""Initialize and register all routes"""
|
||||
"""Initialize and register all routes using the new refactored architecture"""
|
||||
app = PromptServer.instance.app
|
||||
|
||||
added_targets = set() # 用于跟踪已添加的目标路径
|
||||
# Configure aiohttp access logger to be less verbose
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# Add specific suppression for connection reset errors
|
||||
class ConnectionResetFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
# Filter out connection reset errors that are not critical
|
||||
if "ConnectionResetError" in str(record.getMessage()):
|
||||
return False
|
||||
if "_call_connection_lost" in str(record.getMessage()):
|
||||
return False
|
||||
if "WinError 10054" in str(record.getMessage()):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Apply the filter to asyncio logger
|
||||
asyncio_logger = logging.getLogger("asyncio")
|
||||
asyncio_logger.addFilter(ConnectionResetFilter())
|
||||
|
||||
added_targets = set() # Track already added target paths
|
||||
|
||||
# Add static route for example images if the path exists in settings
|
||||
example_images_path = settings.get('example_images_path')
|
||||
logger.info(f"Example images path: {example_images_path}")
|
||||
if example_images_path and os.path.exists(example_images_path):
|
||||
app.router.add_static('/example_images_static', example_images_path)
|
||||
logger.info(f"Added static route for example images: /example_images_static -> {example_images_path}")
|
||||
|
||||
# Add static routes for each lora root
|
||||
for idx, root in enumerate(config.loras_roots, start=1):
|
||||
@@ -31,77 +69,326 @@ class LoraManager:
|
||||
if link == root:
|
||||
real_root = target
|
||||
break
|
||||
# 为原始路径添加静态路由
|
||||
# Add static route for original path
|
||||
app.router.add_static(preview_path, real_root)
|
||||
logger.info(f"Added static route {preview_path} -> {real_root}")
|
||||
|
||||
# 记录路由映射
|
||||
# Record route mapping
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(real_root)
|
||||
|
||||
# 为符号链接的目标路径添加额外的静态路由
|
||||
link_idx = 1
|
||||
# Add static routes for each checkpoint root
|
||||
for idx, root in enumerate(config.base_models_roots, start=1):
|
||||
preview_path = f'/checkpoints_static/root{idx}/preview'
|
||||
|
||||
real_root = root
|
||||
if root in config._path_mappings.values():
|
||||
for target, link in config._path_mappings.items():
|
||||
if link == root:
|
||||
real_root = target
|
||||
break
|
||||
# Add static route for original path
|
||||
app.router.add_static(preview_path, real_root)
|
||||
logger.info(f"Added static route {preview_path} -> {real_root}")
|
||||
|
||||
# Record route mapping
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Add static routes for each embedding root
|
||||
for idx, root in enumerate(config.embeddings_roots, start=1):
|
||||
preview_path = f'/embeddings_static/root{idx}/preview'
|
||||
|
||||
real_root = root
|
||||
if root in config._path_mappings.values():
|
||||
for target, link in config._path_mappings.items():
|
||||
if link == root:
|
||||
real_root = target
|
||||
break
|
||||
# Add static route for original path
|
||||
app.router.add_static(preview_path, real_root)
|
||||
logger.info(f"Added static route {preview_path} -> {real_root}")
|
||||
|
||||
# Record route mapping
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Add static routes for symlink target paths
|
||||
link_idx = {
|
||||
'lora': 1,
|
||||
'checkpoint': 1,
|
||||
'embedding': 1
|
||||
}
|
||||
|
||||
for target_path, link_path in config._path_mappings.items():
|
||||
if target_path not in added_targets:
|
||||
route_path = f'/loras_static/link_{link_idx}/preview'
|
||||
app.router.add_static(route_path, target_path)
|
||||
logger.info(f"Added static route for link target {route_path} -> {target_path}")
|
||||
config.add_route_mapping(target_path, route_path)
|
||||
added_targets.add(target_path)
|
||||
link_idx += 1
|
||||
|
||||
# Determine if this is a checkpoint, lora, or embedding link based on path
|
||||
is_checkpoint = any(cp_root in link_path for cp_root in config.base_models_roots)
|
||||
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.base_models_roots)
|
||||
is_embedding = any(emb_root in link_path for emb_root in config.embeddings_roots)
|
||||
is_embedding = is_embedding or any(emb_root in target_path for emb_root in config.embeddings_roots)
|
||||
|
||||
if is_checkpoint:
|
||||
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
|
||||
link_idx["checkpoint"] += 1
|
||||
elif is_embedding:
|
||||
route_path = f'/embeddings_static/link_{link_idx["embedding"]}/preview'
|
||||
link_idx["embedding"] += 1
|
||||
else:
|
||||
route_path = f'/loras_static/link_{link_idx["lora"]}/preview'
|
||||
link_idx["lora"] += 1
|
||||
|
||||
try:
|
||||
app.router.add_static(route_path, Path(target_path).resolve(strict=False))
|
||||
logger.info(f"Added static route for link target {route_path} -> {target_path}")
|
||||
config.add_route_mapping(target_path, route_path)
|
||||
added_targets.add(target_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to add static route on initialization for {target_path}: {e}")
|
||||
continue
|
||||
|
||||
# Add static route for locales JSON files
|
||||
if os.path.exists(config.i18n_path):
|
||||
app.router.add_static('/locales', config.i18n_path)
|
||||
logger.info(f"Added static route for locales: /locales -> {config.i18n_path}")
|
||||
|
||||
# Add static route for plugin assets
|
||||
app.router.add_static('/loras_static', config.static_path)
|
||||
|
||||
# Setup feature routes
|
||||
routes = LoraRoutes()
|
||||
# Register default model types with the factory
|
||||
register_default_model_types()
|
||||
|
||||
# Setup file monitoring
|
||||
monitor = LoraFileMonitor(routes.scanner, config.loras_roots)
|
||||
monitor.start()
|
||||
# Setup all model routes using the factory
|
||||
ModelServiceFactory.setup_all_routes(app)
|
||||
|
||||
routes.setup_routes(app)
|
||||
ApiRoutes.setup_routes(app, monitor)
|
||||
# Setup non-model-specific routes
|
||||
stats_routes = StatsRoutes()
|
||||
stats_routes.setup_routes(app)
|
||||
RecipeRoutes.setup_routes(app)
|
||||
UpdateRoutes.setup_routes(app)
|
||||
MiscRoutes.setup_routes(app)
|
||||
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
|
||||
|
||||
# Store monitor in app for cleanup
|
||||
app['lora_monitor'] = monitor
|
||||
# Setup WebSocket routes that are shared across all model types
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
|
||||
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
|
||||
|
||||
# Schedule cache initialization using the application's startup handler
|
||||
app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner))
|
||||
# Schedule service initialization
|
||||
app.on_startup.append(lambda app: cls._initialize_services())
|
||||
|
||||
# Add cleanup
|
||||
app.on_shutdown.append(cls._cleanup)
|
||||
app.on_shutdown.append(ApiRoutes.cleanup)
|
||||
|
||||
logger.info(f"LoRA Manager: Set up routes for {len(ModelServiceFactory.get_registered_types())} model types: {', '.join(ModelServiceFactory.get_registered_types())}")
|
||||
|
||||
@classmethod
|
||||
async def _schedule_cache_init(cls, scanner: LoraScanner):
|
||||
"""Schedule cache initialization in the running event loop"""
|
||||
async def _initialize_services(cls):
|
||||
"""Initialize all services using the ServiceRegistry"""
|
||||
try:
|
||||
# 创建低优先级的初始化任务
|
||||
asyncio.create_task(cls._initialize_cache(scanner), name='lora_cache_init')
|
||||
except Exception as e:
|
||||
print(f"LoRA Manager: Error scheduling cache initialization: {e}")
|
||||
|
||||
@classmethod
|
||||
async def _initialize_cache(cls, scanner: LoraScanner):
|
||||
"""Initialize cache in background"""
|
||||
try:
|
||||
# 设置初始缓存占位
|
||||
scanner._cache = LoraCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
# Initialize CivitaiClient first to ensure it's ready for other services
|
||||
await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Register DownloadManager with ServiceRegistry
|
||||
await ServiceRegistry.get_download_manager()
|
||||
|
||||
from .services.metadata_service import initialize_metadata_providers
|
||||
await initialize_metadata_providers()
|
||||
|
||||
# Initialize WebSocket manager
|
||||
await ServiceRegistry.get_websocket_manager()
|
||||
|
||||
# Initialize scanners in background
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
# Initialize recipe scanner if needed
|
||||
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
|
||||
# Create low-priority initialization tasks
|
||||
init_tasks = [
|
||||
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init'),
|
||||
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init'),
|
||||
asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init'),
|
||||
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
|
||||
]
|
||||
|
||||
await ExampleImagesMigration.check_and_run_migrations()
|
||||
|
||||
# Schedule post-initialization tasks to run after scanners complete
|
||||
asyncio.create_task(
|
||||
cls._run_post_initialization_tasks(init_tasks),
|
||||
name='post_init_tasks'
|
||||
)
|
||||
|
||||
# 分阶段加载缓存
|
||||
await scanner.get_cached_data(force_refresh=True)
|
||||
logger.debug("LoRA Manager: All services initialized and background tasks scheduled")
|
||||
|
||||
except Exception as e:
|
||||
print(f"LoRA Manager: Error initializing cache: {e}")
|
||||
logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
async def _run_post_initialization_tasks(cls, init_tasks):
|
||||
"""Run post-initialization tasks after all scanners complete"""
|
||||
try:
|
||||
logger.debug("LoRA Manager: Waiting for scanner initialization to complete...")
|
||||
|
||||
# Wait for all scanner initialization tasks to complete
|
||||
await asyncio.gather(*init_tasks, return_exceptions=True)
|
||||
|
||||
logger.debug("LoRA Manager: Scanner initialization completed, starting post-initialization tasks...")
|
||||
|
||||
# Run post-initialization tasks
|
||||
post_tasks = [
|
||||
asyncio.create_task(cls._cleanup_backup_files(), name='cleanup_bak_files'),
|
||||
# Add more post-initialization tasks here as needed
|
||||
# asyncio.create_task(cls._another_post_task(), name='another_task'),
|
||||
]
|
||||
|
||||
# Run all post-initialization tasks
|
||||
results = await asyncio.gather(*post_tasks, return_exceptions=True)
|
||||
|
||||
# Log results
|
||||
for i, result in enumerate(results):
|
||||
task_name = post_tasks[i].get_name()
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Post-initialization task '{task_name}' failed: {result}")
|
||||
else:
|
||||
logger.debug(f"Post-initialization task '{task_name}' completed successfully")
|
||||
|
||||
logger.debug("LoRA Manager: All post-initialization tasks completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LoRA Manager: Error in post-initialization tasks: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
async def _cleanup_backup_files(cls):
|
||||
"""Clean up .bak files in all model roots"""
|
||||
try:
|
||||
logger.debug("Starting cleanup of .bak files in model directories...")
|
||||
|
||||
# Collect all model roots
|
||||
all_roots = set()
|
||||
all_roots.update(config.loras_roots)
|
||||
all_roots.update(config.base_models_roots)
|
||||
all_roots.update(config.embeddings_roots)
|
||||
|
||||
total_deleted = 0
|
||||
total_size_freed = 0
|
||||
|
||||
for root_path in all_roots:
|
||||
if not os.path.exists(root_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
deleted_count, size_freed = await cls._cleanup_backup_files_in_directory(root_path)
|
||||
total_deleted += deleted_count
|
||||
total_size_freed += size_freed
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"Cleaned up {deleted_count} .bak files in {root_path} (freed {size_freed / (1024*1024):.2f} MB)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up .bak files in {root_path}: {e}")
|
||||
|
||||
# Yield control periodically
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
if total_deleted > 0:
|
||||
logger.debug(f"Backup cleanup completed: removed {total_deleted} .bak files, freed {total_size_freed / (1024*1024):.2f} MB total")
|
||||
else:
|
||||
logger.debug("Backup cleanup completed: no .bak files found")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during backup file cleanup: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
async def _cleanup_backup_files_in_directory(cls, directory_path: str):
|
||||
"""Clean up .bak files in a specific directory recursively
|
||||
|
||||
Args:
|
||||
directory_path: Path to the directory to clean
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: (number of files deleted, total size freed in bytes)
|
||||
"""
|
||||
deleted_count = 0
|
||||
size_freed = 0
|
||||
visited_paths = set()
|
||||
|
||||
def cleanup_recursive(path):
|
||||
nonlocal deleted_count, size_freed
|
||||
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
for entry in it:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True) and entry.name.endswith('.bak'):
|
||||
file_size = entry.stat().st_size
|
||||
os.remove(entry.path)
|
||||
deleted_count += 1
|
||||
size_freed += file_size
|
||||
logger.debug(f"Deleted .bak file: {entry.path}")
|
||||
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
cleanup_recursive(entry.path)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not delete .bak file {entry.path}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning directory {path} for .bak files: {e}")
|
||||
|
||||
# Run the recursive cleanup in a thread pool to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, cleanup_recursive, directory_path)
|
||||
|
||||
return deleted_count, size_freed
|
||||
|
||||
@classmethod
|
||||
async def _cleanup_example_images_folders(cls):
|
||||
"""Invoke the example images cleanup service for manual execution."""
|
||||
try:
|
||||
service = ExampleImagesCleanupService()
|
||||
result = await service.cleanup_example_image_folders()
|
||||
|
||||
if result.get('success'):
|
||||
logger.debug(
|
||||
"Manual example images cleanup completed: moved=%s",
|
||||
result.get('moved_total'),
|
||||
)
|
||||
elif result.get('partial_success'):
|
||||
logger.warning(
|
||||
"Manual example images cleanup partially succeeded: moved=%s failures=%s",
|
||||
result.get('moved_total'),
|
||||
result.get('move_failures'),
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Manual example images cleanup skipped or failed: %s",
|
||||
result.get('error', 'no changes'),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e: # pragma: no cover - defensive guard
|
||||
logger.error(f"Error during example images cleanup: {e}", exc_info=True)
|
||||
return {
|
||||
'success': False,
|
||||
'error': str(e),
|
||||
'error_code': 'unexpected_error',
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def _cleanup(cls, app):
|
||||
"""Cleanup resources"""
|
||||
if 'lora_monitor' in app:
|
||||
app['lora_monitor'].stop()
|
||||
"""Cleanup resources using ServiceRegistry"""
|
||||
try:
|
||||
logger.info("LoRA Manager: Cleaning up services")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during cleanup: {e}", exc_info=True)
|
||||
|
||||
30
py/metadata_collector/__init__.py
Normal file
30
py/metadata_collector/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import os
|
||||
|
||||
# Check if running in standalone mode
|
||||
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
|
||||
if not standalone_mode:
|
||||
from .metadata_hook import MetadataHook
|
||||
from .metadata_registry import MetadataRegistry
|
||||
|
||||
def init():
|
||||
# Install hooks to collect metadata during execution
|
||||
MetadataHook.install()
|
||||
|
||||
# Initialize registry
|
||||
registry = MetadataRegistry()
|
||||
|
||||
print("ComfyUI Metadata Collector initialized")
|
||||
|
||||
def get_metadata(prompt_id=None):
|
||||
"""Helper function to get metadata from the registry"""
|
||||
registry = MetadataRegistry()
|
||||
return registry.get_metadata(prompt_id)
|
||||
else:
|
||||
# Standalone mode - provide dummy implementations
|
||||
def init():
|
||||
print("ComfyUI Metadata Collector disabled in standalone mode")
|
||||
|
||||
def get_metadata(prompt_id=None):
|
||||
"""Dummy implementation for standalone mode"""
|
||||
return {}
|
||||
13
py/metadata_collector/constants.py
Normal file
13
py/metadata_collector/constants.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Constants used by the metadata collector"""
|
||||
|
||||
# Metadata categories
|
||||
MODELS = "models"
|
||||
PROMPTS = "prompts"
|
||||
SAMPLING = "sampling"
|
||||
LORAS = "loras"
|
||||
SIZE = "size"
|
||||
IMAGES = "images"
|
||||
IS_SAMPLER = "is_sampler" # New constant to mark sampler nodes
|
||||
|
||||
# Complete list of categories to track
|
||||
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES]
|
||||
204
py/metadata_collector/metadata_hook.py
Normal file
204
py/metadata_collector/metadata_hook.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import sys
|
||||
import inspect
|
||||
from .metadata_registry import MetadataRegistry
|
||||
|
||||
class MetadataHook:
|
||||
"""Install hooks for metadata collection"""
|
||||
|
||||
@staticmethod
|
||||
def install():
|
||||
"""Install hooks to collect metadata during execution"""
|
||||
try:
|
||||
# Import ComfyUI's execution module
|
||||
execution = None
|
||||
try:
|
||||
# Try direct import first
|
||||
import execution # type: ignore
|
||||
except ImportError:
|
||||
# Try to locate from system modules
|
||||
for module_name in sys.modules:
|
||||
if module_name.endswith('.execution'):
|
||||
execution = sys.modules[module_name]
|
||||
break
|
||||
|
||||
# If we can't find the execution module, we can't install hooks
|
||||
if execution is None:
|
||||
print("Could not locate ComfyUI execution module, metadata collection disabled")
|
||||
return
|
||||
|
||||
# Detect whether we're using the new async version of ComfyUI
|
||||
is_async = False
|
||||
map_node_func_name = '_map_node_over_list'
|
||||
|
||||
if hasattr(execution, '_async_map_node_over_list'):
|
||||
is_async = inspect.iscoroutinefunction(execution._async_map_node_over_list)
|
||||
map_node_func_name = '_async_map_node_over_list'
|
||||
elif hasattr(execution, '_map_node_over_list'):
|
||||
is_async = inspect.iscoroutinefunction(execution._map_node_over_list)
|
||||
|
||||
if is_async:
|
||||
print("Detected async ComfyUI execution, installing async metadata hooks")
|
||||
MetadataHook._install_async_hooks(execution, map_node_func_name)
|
||||
else:
|
||||
print("Detected sync ComfyUI execution, installing sync metadata hooks")
|
||||
MetadataHook._install_sync_hooks(execution)
|
||||
|
||||
print("Metadata collection hooks installed for runtime values")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error installing metadata hooks: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _install_sync_hooks(execution):
|
||||
"""Install hooks for synchronous execution model"""
|
||||
# Store the original _map_node_over_list function
|
||||
original_map_node_over_list = execution._map_node_over_list
|
||||
|
||||
# Define the wrapped _map_node_over_list function
|
||||
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
# Only collect metadata when calling the main function of nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
prompt_id = registry.current_prompt_id
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Unique ID might be available through the obj if it has a unique_id field
|
||||
node_id = getattr(obj, 'unique_id', None)
|
||||
if node_id is None and pre_execute_cb:
|
||||
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if 'unique_id' in frame.f_locals:
|
||||
node_id = frame.f_locals['unique_id']
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
# Record inputs before execution
|
||||
if node_id is not None:
|
||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
|
||||
# Execute the original function
|
||||
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
|
||||
|
||||
# After execution, collect outputs for relevant nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
prompt_id = registry.current_prompt_id
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Unique ID might be available through the obj if it has a unique_id field
|
||||
node_id = getattr(obj, 'unique_id', None)
|
||||
if node_id is None and pre_execute_cb:
|
||||
# Try to extract node_id through reflection
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if 'unique_id' in frame.f_locals:
|
||||
node_id = frame.f_locals['unique_id']
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
# Record outputs after execution
|
||||
if node_id is not None:
|
||||
registry.update_node_execution(node_id, class_type, results)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
# Also hook the execute function to track the current prompt_id
|
||||
original_execute = execution.execute
|
||||
|
||||
def execute_with_prompt_tracking(*args, **kwargs):
|
||||
if len(args) >= 7: # Check if we have enough arguments
|
||||
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||
registry = MetadataRegistry()
|
||||
|
||||
# Start collection if this is a new prompt
|
||||
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||
registry.start_collection(prompt_id)
|
||||
|
||||
# Store the dynprompt reference for node lookups
|
||||
if hasattr(prompt, 'original_prompt'):
|
||||
registry.set_current_prompt(prompt)
|
||||
|
||||
# Execute the original function
|
||||
return original_execute(*args, **kwargs)
|
||||
|
||||
# Replace the functions
|
||||
execution._map_node_over_list = map_node_over_list_with_metadata
|
||||
execution.execute = execute_with_prompt_tracking
|
||||
|
||||
@staticmethod
|
||||
def _install_async_hooks(execution, map_node_func_name='_async_map_node_over_list'):
|
||||
"""Install hooks for asynchronous execution model"""
|
||||
# Store the original _async_map_node_over_list function
|
||||
original_map_node_over_list = getattr(execution, map_node_func_name)
|
||||
|
||||
# Wrapped async function, compatible with both stable and nightly
|
||||
async def async_map_node_over_list_with_metadata(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, *args, **kwargs):
|
||||
hidden_inputs = kwargs.get('hidden_inputs', None)
|
||||
# Only collect metadata when calling the main function of nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
registry = MetadataRegistry()
|
||||
if prompt_id is not None:
|
||||
class_type = obj.__class__.__name__
|
||||
node_id = unique_id
|
||||
if node_id is not None:
|
||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
|
||||
# Call original function with all args/kwargs
|
||||
results = await original_map_node_over_list(
|
||||
prompt_id, unique_id, obj, input_data_all, func,
|
||||
allow_interrupt, execution_block_cb, pre_execute_cb, *args, **kwargs
|
||||
)
|
||||
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
registry = MetadataRegistry()
|
||||
if prompt_id is not None:
|
||||
class_type = obj.__class__.__name__
|
||||
node_id = unique_id
|
||||
if node_id is not None:
|
||||
registry.update_node_execution(node_id, class_type, results)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
# Also hook the execute function to track the current prompt_id
|
||||
original_execute = execution.execute
|
||||
|
||||
async def async_execute_with_prompt_tracking(*args, **kwargs):
|
||||
if len(args) >= 7: # Check if we have enough arguments
|
||||
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||
registry = MetadataRegistry()
|
||||
|
||||
# Start collection if this is a new prompt
|
||||
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||
registry.start_collection(prompt_id)
|
||||
|
||||
# Store the dynprompt reference for node lookups
|
||||
if hasattr(prompt, 'original_prompt'):
|
||||
registry.set_current_prompt(prompt)
|
||||
|
||||
# Execute the original function
|
||||
return await original_execute(*args, **kwargs)
|
||||
|
||||
# Replace the functions with async versions
|
||||
setattr(execution, map_node_func_name, async_map_node_over_list_with_metadata)
|
||||
execution.execute = async_execute_with_prompt_tracking
|
||||
479
py/metadata_collector/metadata_processor.py
Normal file
479
py/metadata_collector/metadata_processor.py
Normal file
@@ -0,0 +1,479 @@
|
||||
import json
|
||||
import os
|
||||
from .constants import IMAGES
|
||||
|
||||
# Check if running in standalone mode
|
||||
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
|
||||
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IS_SAMPLER
|
||||
|
||||
class MetadataProcessor:
|
||||
"""Process and format collected metadata"""
|
||||
|
||||
@staticmethod
|
||||
def find_primary_sampler(metadata, downstream_id=None):
|
||||
"""
|
||||
Find the primary KSampler node that executed before the given downstream node
|
||||
|
||||
Parameters:
|
||||
- metadata: The workflow metadata
|
||||
- downstream_id: Optional ID of a downstream node to help identify the specific primary sampler
|
||||
"""
|
||||
if downstream_id is None:
|
||||
if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
|
||||
downstream_id = metadata[IMAGES]["first_decode"]["node_id"]
|
||||
|
||||
# If we have a downstream_id and execution_order, use it to narrow down potential samplers
|
||||
if downstream_id and "execution_order" in metadata:
|
||||
execution_order = metadata["execution_order"]
|
||||
|
||||
# Find the index of the downstream node in the execution order
|
||||
if downstream_id in execution_order:
|
||||
downstream_index = execution_order.index(downstream_id)
|
||||
|
||||
# Extract all sampler nodes that executed before the downstream node
|
||||
candidate_samplers = {}
|
||||
for i in range(downstream_index):
|
||||
node_id = execution_order[i]
|
||||
# Use IS_SAMPLER flag to identify true sampler nodes
|
||||
if node_id in metadata.get(SAMPLING, {}) and metadata[SAMPLING][node_id].get(IS_SAMPLER, False):
|
||||
candidate_samplers[node_id] = metadata[SAMPLING][node_id]
|
||||
|
||||
# If we found candidate samplers, apply primary sampler logic to these candidates only
|
||||
if candidate_samplers:
|
||||
# Collect potential primary samplers based on different criteria
|
||||
custom_advanced_samplers = []
|
||||
advanced_add_noise_samplers = []
|
||||
high_denoise_samplers = []
|
||||
max_denoise = -1
|
||||
high_denoise_id = None
|
||||
|
||||
# First, check for SamplerCustomAdvanced among candidates
|
||||
prompt = metadata.get("current_prompt")
|
||||
if prompt and prompt.original_prompt:
|
||||
for node_id in candidate_samplers:
|
||||
node_info = prompt.original_prompt.get(node_id, {})
|
||||
if node_info.get("class_type") == "SamplerCustomAdvanced":
|
||||
custom_advanced_samplers.append(node_id)
|
||||
|
||||
# Next, check for KSamplerAdvanced with add_noise="enable" among candidates
|
||||
for node_id, sampler_info in candidate_samplers.items():
|
||||
parameters = sampler_info.get("parameters", {})
|
||||
add_noise = parameters.get("add_noise")
|
||||
if add_noise == "enable":
|
||||
advanced_add_noise_samplers.append(node_id)
|
||||
|
||||
# Find the sampler with highest denoise value among candidates
|
||||
for node_id, sampler_info in candidate_samplers.items():
|
||||
parameters = sampler_info.get("parameters", {})
|
||||
denoise = parameters.get("denoise")
|
||||
if denoise is not None and denoise > max_denoise:
|
||||
max_denoise = denoise
|
||||
high_denoise_id = node_id
|
||||
|
||||
if high_denoise_id:
|
||||
high_denoise_samplers.append(high_denoise_id)
|
||||
|
||||
# Combine all potential primary samplers
|
||||
potential_samplers = custom_advanced_samplers + advanced_add_noise_samplers + high_denoise_samplers
|
||||
|
||||
# Find the most recent potential primary sampler (closest to downstream node)
|
||||
for i in range(downstream_index - 1, -1, -1):
|
||||
node_id = execution_order[i]
|
||||
if node_id in potential_samplers:
|
||||
return node_id, candidate_samplers[node_id]
|
||||
|
||||
# If no potential sampler found from our criteria, return the most recent sampler
|
||||
if candidate_samplers:
|
||||
for i in range(downstream_index - 1, -1, -1):
|
||||
node_id = execution_order[i]
|
||||
if node_id in candidate_samplers:
|
||||
return node_id, candidate_samplers[node_id]
|
||||
|
||||
# If no downstream_id provided or no suitable sampler found, fall back to original logic
|
||||
primary_sampler = None
|
||||
primary_sampler_id = None
|
||||
max_denoise = -1
|
||||
|
||||
# First, check for SamplerCustomAdvanced
|
||||
prompt = metadata.get("current_prompt")
|
||||
if prompt and prompt.original_prompt:
|
||||
for node_id, node_info in prompt.original_prompt.items():
|
||||
if node_info.get("class_type") == "SamplerCustomAdvanced":
|
||||
# Check if the node is in SAMPLING and has IS_SAMPLER flag
|
||||
if node_id in metadata.get(SAMPLING, {}) and metadata[SAMPLING][node_id].get(IS_SAMPLER, False):
|
||||
return node_id, metadata[SAMPLING][node_id]
|
||||
|
||||
# Next, check for KSamplerAdvanced with add_noise="enable" using IS_SAMPLER flag
|
||||
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
|
||||
# Skip if not marked as a sampler
|
||||
if not sampler_info.get(IS_SAMPLER, False):
|
||||
continue
|
||||
|
||||
parameters = sampler_info.get("parameters", {})
|
||||
add_noise = parameters.get("add_noise")
|
||||
if add_noise == "enable":
|
||||
primary_sampler = sampler_info
|
||||
primary_sampler_id = node_id
|
||||
break
|
||||
|
||||
# If no specialized sampler found, find the sampler with highest denoise value
|
||||
if primary_sampler is None:
|
||||
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
|
||||
# Skip if not marked as a sampler
|
||||
if not sampler_info.get(IS_SAMPLER, False):
|
||||
continue
|
||||
|
||||
parameters = sampler_info.get("parameters", {})
|
||||
denoise = parameters.get("denoise")
|
||||
if denoise is not None and denoise > max_denoise:
|
||||
max_denoise = denoise
|
||||
primary_sampler = sampler_info
|
||||
primary_sampler_id = node_id
|
||||
|
||||
return primary_sampler_id, primary_sampler
|
||||
|
||||
@staticmethod
|
||||
def trace_node_input(prompt, node_id, input_name, target_class=None, max_depth=10):
|
||||
"""
|
||||
Trace an input connection from a node to find the source node
|
||||
|
||||
Parameters:
|
||||
- prompt: The prompt object containing node connections
|
||||
- node_id: ID of the starting node
|
||||
- input_name: Name of the input to trace
|
||||
- target_class: Optional class name to search for (e.g., "CLIPTextEncode")
|
||||
- max_depth: Maximum depth to follow the node chain to prevent infinite loops
|
||||
|
||||
Returns:
|
||||
- node_id of the found node, or None if not found
|
||||
"""
|
||||
if not prompt or not prompt.original_prompt or node_id not in prompt.original_prompt:
|
||||
return None
|
||||
|
||||
# For depth tracking
|
||||
current_depth = 0
|
||||
|
||||
current_node_id = node_id
|
||||
current_input = input_name
|
||||
|
||||
# If we're just tracing to origin (no target_class), keep track of the last valid node
|
||||
last_valid_node = None
|
||||
|
||||
while current_depth < max_depth:
|
||||
if current_node_id not in prompt.original_prompt:
|
||||
return last_valid_node if not target_class else None
|
||||
|
||||
node_inputs = prompt.original_prompt[current_node_id].get("inputs", {})
|
||||
if current_input not in node_inputs:
|
||||
# We've reached a node without the specified input - this is our origin node
|
||||
# if we're not looking for a specific target_class
|
||||
return current_node_id if not target_class else None
|
||||
|
||||
input_value = node_inputs[current_input]
|
||||
# Input connections are formatted as [node_id, output_index]
|
||||
if isinstance(input_value, list) and len(input_value) >= 2:
|
||||
found_node_id = input_value[0] # Connected node_id
|
||||
|
||||
# If we're looking for a specific node class
|
||||
if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class:
|
||||
return found_node_id
|
||||
|
||||
# If we're not looking for a specific class, update the last valid node
|
||||
if not target_class:
|
||||
last_valid_node = found_node_id
|
||||
|
||||
# Continue tracing through intermediate nodes
|
||||
current_node_id = found_node_id
|
||||
# For most conditioning nodes, the input we want to follow is named "conditioning"
|
||||
if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}):
|
||||
current_input = "conditioning"
|
||||
else:
|
||||
# If there's no "conditioning" input, return the current node
|
||||
# if we're not looking for a specific target_class
|
||||
return found_node_id if not target_class else None
|
||||
else:
|
||||
# We've reached a node with no further connections
|
||||
return last_valid_node if not target_class else None
|
||||
|
||||
current_depth += 1
|
||||
|
||||
# If we've reached max depth without finding target_class
|
||||
return last_valid_node if not target_class else None
|
||||
|
||||
@staticmethod
|
||||
def find_primary_checkpoint(metadata):
|
||||
"""Find the primary checkpoint model in the workflow"""
|
||||
if not metadata.get(MODELS):
|
||||
return None
|
||||
|
||||
# In most workflows, there's only one checkpoint, so we can just take the first one
|
||||
for node_id, model_info in metadata.get(MODELS, {}).items():
|
||||
if model_info.get("type") == "checkpoint":
|
||||
return model_info.get("name")
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def match_conditioning_to_prompts(metadata, sampler_id):
|
||||
"""
|
||||
Match conditioning objects from a sampler to prompts in metadata
|
||||
|
||||
Parameters:
|
||||
- metadata: The workflow metadata
|
||||
- sampler_id: ID of the sampler node to match
|
||||
|
||||
Returns:
|
||||
- Dictionary with 'prompt' and 'negative_prompt' if found
|
||||
"""
|
||||
result = {
|
||||
"prompt": "",
|
||||
"negative_prompt": ""
|
||||
}
|
||||
|
||||
# Check if we have stored conditioning objects for this sampler
|
||||
if sampler_id in metadata.get(PROMPTS, {}) and (
|
||||
"pos_conditioning" in metadata[PROMPTS][sampler_id] or
|
||||
"neg_conditioning" in metadata[PROMPTS][sampler_id]):
|
||||
|
||||
pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning")
|
||||
neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning")
|
||||
|
||||
# Helper function to recursively find prompt text for a conditioning object
|
||||
def find_prompt_text_for_conditioning(conditioning_obj, is_positive=True):
|
||||
if conditioning_obj is None:
|
||||
return ""
|
||||
|
||||
# Try to match conditioning objects with those stored by extractors
|
||||
for prompt_node_id, prompt_data in metadata[PROMPTS].items():
|
||||
# For nodes with single conditioning output
|
||||
if "conditioning" in prompt_data:
|
||||
if id(prompt_data["conditioning"]) == id(conditioning_obj):
|
||||
return prompt_data.get("text", "")
|
||||
|
||||
# For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader)
|
||||
if is_positive and "positive_encoded" in prompt_data:
|
||||
if id(prompt_data["positive_encoded"]) == id(conditioning_obj):
|
||||
if "positive_text" in prompt_data:
|
||||
return prompt_data["positive_text"]
|
||||
else:
|
||||
orig_conditioning = prompt_data.get("orig_pos_cond", None)
|
||||
if orig_conditioning is not None:
|
||||
# Recursively find the prompt text for the original conditioning
|
||||
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=True)
|
||||
|
||||
if not is_positive and "negative_encoded" in prompt_data:
|
||||
if id(prompt_data["negative_encoded"]) == id(conditioning_obj):
|
||||
if "negative_text" in prompt_data:
|
||||
return prompt_data["negative_text"]
|
||||
else:
|
||||
orig_conditioning = prompt_data.get("orig_neg_cond", None)
|
||||
if orig_conditioning is not None:
|
||||
# Recursively find the prompt text for the original conditioning
|
||||
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=False)
|
||||
|
||||
return ""
|
||||
|
||||
# Find prompt texts using the helper function
|
||||
result["prompt"] = find_prompt_text_for_conditioning(pos_conditioning, is_positive=True)
|
||||
result["negative_prompt"] = find_prompt_text_for_conditioning(neg_conditioning, is_positive=False)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def extract_generation_params(metadata, id=None):
|
||||
"""
|
||||
Extract generation parameters from metadata using node relationships
|
||||
|
||||
Parameters:
|
||||
- metadata: The workflow metadata
|
||||
- id: Optional ID of a downstream node to help identify the specific primary sampler
|
||||
"""
|
||||
params = {
|
||||
"prompt": "",
|
||||
"negative_prompt": "",
|
||||
"seed": None,
|
||||
"steps": None,
|
||||
"cfg_scale": None,
|
||||
# "guidance": None, # Add guidance parameter
|
||||
"sampler": None,
|
||||
"scheduler": None,
|
||||
"checkpoint": None,
|
||||
"loras": "",
|
||||
"size": None,
|
||||
"clip_skip": None
|
||||
}
|
||||
|
||||
# Get the prompt object for node relationship tracing
|
||||
prompt = metadata.get("current_prompt")
|
||||
|
||||
# Find the primary KSampler node
|
||||
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id)
|
||||
|
||||
# Directly get checkpoint from metadata instead of tracing
|
||||
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata)
|
||||
if checkpoint:
|
||||
params["checkpoint"] = checkpoint
|
||||
|
||||
# Check if guidance parameter exists in any sampling node
|
||||
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
|
||||
parameters = sampler_info.get("parameters", {})
|
||||
if "guidance" in parameters and parameters["guidance"] is not None:
|
||||
params["guidance"] = parameters["guidance"]
|
||||
break
|
||||
|
||||
if primary_sampler:
|
||||
# Extract sampling parameters
|
||||
sampling_params = primary_sampler.get("parameters", {})
|
||||
# Handle both seed and noise_seed
|
||||
params["seed"] = sampling_params.get("seed") if sampling_params.get("seed") is not None else sampling_params.get("noise_seed")
|
||||
params["steps"] = sampling_params.get("steps")
|
||||
params["cfg_scale"] = sampling_params.get("cfg")
|
||||
params["sampler"] = sampling_params.get("sampler_name")
|
||||
params["scheduler"] = sampling_params.get("scheduler")
|
||||
|
||||
if prompt and primary_sampler_id:
|
||||
# Check if this is a SamplerCustomAdvanced node
|
||||
is_custom_advanced = False
|
||||
if prompt.original_prompt and primary_sampler_id in prompt.original_prompt:
|
||||
is_custom_advanced = prompt.original_prompt[primary_sampler_id].get("class_type") == "SamplerCustomAdvanced"
|
||||
|
||||
if is_custom_advanced:
|
||||
# For SamplerCustomAdvanced, use the new handler method
|
||||
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
|
||||
|
||||
else:
|
||||
# For standard samplers, match conditioning objects to prompts
|
||||
prompt_results = MetadataProcessor.match_conditioning_to_prompts(metadata, primary_sampler_id)
|
||||
params["prompt"] = prompt_results["prompt"]
|
||||
params["negative_prompt"] = prompt_results["negative_prompt"]
|
||||
|
||||
# If prompts were still not found, fall back to tracing connections
|
||||
if not params["prompt"]:
|
||||
# Original tracing for standard samplers
|
||||
# Trace positive prompt - look specifically for CLIPTextEncode
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
else:
|
||||
# If CLIPTextEncode is not found, try to find CLIPTextEncodeFlux
|
||||
positive_flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncodeFlux", max_depth=10)
|
||||
if positive_flux_node_id and positive_flux_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_flux_node_id].get("text", "")
|
||||
|
||||
# Trace negative prompt - look specifically for CLIPTextEncode
|
||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", max_depth=10)
|
||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||
|
||||
# For SamplerCustom, handle any additional parameters
|
||||
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
|
||||
|
||||
# Size extraction is same for all sampler types
|
||||
# Check if the sampler itself has size information (from latent_image)
|
||||
if primary_sampler_id in metadata.get(SIZE, {}):
|
||||
width = metadata[SIZE][primary_sampler_id].get("width")
|
||||
height = metadata[SIZE][primary_sampler_id].get("height")
|
||||
if width and height:
|
||||
params["size"] = f"{width}x{height}"
|
||||
|
||||
# Extract LoRAs using the standardized format
|
||||
lora_parts = []
|
||||
for node_id, lora_info in metadata.get(LORAS, {}).items():
|
||||
# Access the lora_list from the standardized format
|
||||
lora_list = lora_info.get("lora_list", [])
|
||||
for lora in lora_list:
|
||||
name = lora.get("name", "unknown")
|
||||
strength = lora.get("strength", 1.0)
|
||||
lora_parts.append(f"<lora:{name}:{strength}>")
|
||||
|
||||
params["loras"] = " ".join(lora_parts)
|
||||
|
||||
# Set default clip_skip value
|
||||
params["clip_skip"] = "1" # Common default
|
||||
|
||||
return params
|
||||
|
||||
@staticmethod
|
||||
def to_dict(metadata, id=None):
|
||||
"""
|
||||
Convert extracted metadata to the ComfyUI output.json format
|
||||
|
||||
Parameters:
|
||||
- metadata: The workflow metadata
|
||||
- id: Optional ID of a downstream node to help identify the specific primary sampler
|
||||
"""
|
||||
if standalone_mode:
|
||||
# Return empty dictionary in standalone mode
|
||||
return {}
|
||||
|
||||
params = MetadataProcessor.extract_generation_params(metadata, id)
|
||||
|
||||
# Convert all values to strings to match output.json format
|
||||
for key in params:
|
||||
if params[key] is not None:
|
||||
params[key] = str(params[key])
|
||||
|
||||
return params
|
||||
|
||||
@staticmethod
|
||||
def to_json(metadata, id=None):
|
||||
"""Convert metadata to JSON string"""
|
||||
params = MetadataProcessor.to_dict(metadata, id)
|
||||
return json.dumps(params, indent=4)
|
||||
|
||||
@staticmethod
|
||||
def handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params):
|
||||
"""
|
||||
Handle parameter extraction for SamplerCustomAdvanced nodes
|
||||
|
||||
Parameters:
|
||||
- metadata: The workflow metadata
|
||||
- prompt: The prompt object containing node connections
|
||||
- primary_sampler_id: ID of the SamplerCustomAdvanced node
|
||||
- params: Parameters dictionary to update
|
||||
"""
|
||||
if not prompt.original_prompt or primary_sampler_id not in prompt.original_prompt:
|
||||
return
|
||||
|
||||
sampler_inputs = prompt.original_prompt[primary_sampler_id].get("inputs", {})
|
||||
|
||||
# 1. Trace sigmas input to find BasicScheduler (only if sigmas input exists)
|
||||
if "sigmas" in sampler_inputs:
|
||||
scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", None, max_depth=5)
|
||||
if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}):
|
||||
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
|
||||
params["steps"] = scheduler_params.get("steps")
|
||||
params["scheduler"] = scheduler_params.get("scheduler")
|
||||
|
||||
# 2. Trace sampler input to find KSamplerSelect (only if sampler input exists)
|
||||
if "sampler" in sampler_inputs:
|
||||
sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
|
||||
if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}):
|
||||
sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
|
||||
params["sampler"] = sampler_params.get("sampler_name")
|
||||
|
||||
# 3. Trace guider input for CFGGuider and CLIPTextEncode
|
||||
if "guider" in sampler_inputs:
|
||||
guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
|
||||
if guider_node_id and guider_node_id in prompt.original_prompt:
|
||||
# Check if the guider node is a CFGGuider
|
||||
if prompt.original_prompt[guider_node_id].get("class_type") == "CFGGuider":
|
||||
# Extract cfg value from the CFGGuider
|
||||
if guider_node_id in metadata.get(SAMPLING, {}):
|
||||
cfg_params = metadata[SAMPLING][guider_node_id].get("parameters", {})
|
||||
params["cfg_scale"] = cfg_params.get("cfg")
|
||||
|
||||
# Find CLIPTextEncode for positive prompt
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", "CLIPTextEncode", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
|
||||
# Find CLIPTextEncode for negative prompt
|
||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", "CLIPTextEncode", max_depth=10)
|
||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||
else:
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
275
py/metadata_collector/metadata_registry.py
Normal file
275
py/metadata_collector/metadata_registry.py
Normal file
@@ -0,0 +1,275 @@
|
||||
import time
|
||||
from nodes import NODE_CLASS_MAPPINGS
|
||||
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
|
||||
from .constants import METADATA_CATEGORIES, IMAGES
|
||||
|
||||
class MetadataRegistry:
|
||||
"""A singleton registry to store and retrieve workflow metadata"""
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._reset()
|
||||
return cls._instance
|
||||
|
||||
def _reset(self):
|
||||
self.current_prompt_id = None
|
||||
self.current_prompt = None
|
||||
self.metadata = {}
|
||||
self.prompt_metadata = {}
|
||||
self.executed_nodes = set()
|
||||
|
||||
# Node-level cache for metadata
|
||||
self.node_cache = {}
|
||||
|
||||
# Limit the number of stored prompts
|
||||
self.max_prompt_history = 3
|
||||
|
||||
# Categories we want to track and retrieve from cache
|
||||
self.metadata_categories = METADATA_CATEGORIES
|
||||
|
||||
def _clean_old_prompts(self):
|
||||
"""Clean up old prompt metadata, keeping only recent ones"""
|
||||
if len(self.prompt_metadata) <= self.max_prompt_history:
|
||||
return
|
||||
|
||||
# Sort all prompt_ids by timestamp
|
||||
sorted_prompts = sorted(
|
||||
self.prompt_metadata.keys(),
|
||||
key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0)
|
||||
)
|
||||
|
||||
# Remove oldest records
|
||||
prompts_to_remove = sorted_prompts[:len(sorted_prompts) - self.max_prompt_history]
|
||||
for pid in prompts_to_remove:
|
||||
del self.prompt_metadata[pid]
|
||||
|
||||
def start_collection(self, prompt_id):
|
||||
"""Begin metadata collection for a new prompt"""
|
||||
self.current_prompt_id = prompt_id
|
||||
self.executed_nodes = set()
|
||||
self.prompt_metadata[prompt_id] = {
|
||||
category: {} for category in METADATA_CATEGORIES
|
||||
}
|
||||
# Add additional metadata fields
|
||||
self.prompt_metadata[prompt_id].update({
|
||||
"execution_order": [],
|
||||
"current_prompt": None, # Will store the prompt object
|
||||
"timestamp": time.time()
|
||||
})
|
||||
|
||||
# Clean up old prompt data
|
||||
self._clean_old_prompts()
|
||||
|
||||
def set_current_prompt(self, prompt):
|
||||
"""Set the current prompt object reference"""
|
||||
self.current_prompt = prompt
|
||||
if self.current_prompt_id and self.current_prompt_id in self.prompt_metadata:
|
||||
# Store the prompt in the metadata for later relationship tracing
|
||||
self.prompt_metadata[self.current_prompt_id]["current_prompt"] = prompt
|
||||
|
||||
def get_metadata(self, prompt_id=None):
|
||||
"""Get collected metadata for a prompt"""
|
||||
key = prompt_id if prompt_id is not None else self.current_prompt_id
|
||||
if key not in self.prompt_metadata:
|
||||
return {}
|
||||
|
||||
metadata = self.prompt_metadata[key]
|
||||
|
||||
# If we have a current prompt object, check for non-executed nodes
|
||||
prompt_obj = metadata.get("current_prompt")
|
||||
if prompt_obj and hasattr(prompt_obj, "original_prompt"):
|
||||
original_prompt = prompt_obj.original_prompt
|
||||
|
||||
# Fill in missing metadata from cache for nodes that weren't executed
|
||||
self._fill_missing_metadata(key, original_prompt)
|
||||
|
||||
return self.prompt_metadata.get(key, {})
|
||||
|
||||
def _fill_missing_metadata(self, prompt_id, original_prompt):
|
||||
"""Fill missing metadata from cache for non-executed nodes"""
|
||||
if not original_prompt:
|
||||
return
|
||||
|
||||
executed_nodes = self.executed_nodes
|
||||
metadata = self.prompt_metadata[prompt_id]
|
||||
|
||||
# Iterate through nodes in the original prompt
|
||||
for node_id, node_data in original_prompt.items():
|
||||
# Skip if already executed in this run
|
||||
if node_id in executed_nodes:
|
||||
continue
|
||||
|
||||
# Get the node type from the prompt (this is the key in NODE_CLASS_MAPPINGS)
|
||||
prompt_class_type = node_data.get("class_type")
|
||||
if not prompt_class_type:
|
||||
continue
|
||||
|
||||
# Convert to actual class name (which is what we use in our cache)
|
||||
class_type = prompt_class_type
|
||||
if prompt_class_type in NODE_CLASS_MAPPINGS:
|
||||
class_obj = NODE_CLASS_MAPPINGS[prompt_class_type]
|
||||
class_type = class_obj.__name__
|
||||
|
||||
# Create cache key using the actual class name
|
||||
cache_key = f"{node_id}:{class_type}"
|
||||
|
||||
# Check if this node type is relevant for metadata collection
|
||||
if class_type in NODE_EXTRACTORS:
|
||||
# Check if we have cached metadata for this node
|
||||
if cache_key in self.node_cache:
|
||||
cached_data = self.node_cache[cache_key]
|
||||
|
||||
# Apply cached metadata to the current metadata
|
||||
for category in self.metadata_categories:
|
||||
if category in cached_data and node_id in cached_data[category]:
|
||||
if node_id not in metadata[category]:
|
||||
metadata[category][node_id] = cached_data[category][node_id]
|
||||
|
||||
def record_node_execution(self, node_id, class_type, inputs, outputs):
|
||||
"""Record information about a node's execution"""
|
||||
if not self.current_prompt_id:
|
||||
return
|
||||
|
||||
# Add to execution order and mark as executed
|
||||
if node_id not in self.executed_nodes:
|
||||
self.executed_nodes.add(node_id)
|
||||
self.prompt_metadata[self.current_prompt_id]["execution_order"].append(node_id)
|
||||
|
||||
# Process inputs to simplify working with them
|
||||
processed_inputs = {}
|
||||
for input_name, input_values in inputs.items():
|
||||
if isinstance(input_values, list) and len(input_values) > 0:
|
||||
# For single values, just use the first one (most common case)
|
||||
processed_inputs[input_name] = input_values[0]
|
||||
else:
|
||||
processed_inputs[input_name] = input_values
|
||||
|
||||
# Extract node-specific metadata
|
||||
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
|
||||
extractor.extract(
|
||||
node_id,
|
||||
processed_inputs,
|
||||
outputs,
|
||||
self.prompt_metadata[self.current_prompt_id]
|
||||
)
|
||||
|
||||
# Cache this node's metadata
|
||||
self._cache_node_metadata(node_id, class_type)
|
||||
|
||||
def update_node_execution(self, node_id, class_type, outputs):
|
||||
"""Update node metadata with output information"""
|
||||
if not self.current_prompt_id:
|
||||
return
|
||||
|
||||
# Process outputs to make them more usable
|
||||
processed_outputs = outputs
|
||||
|
||||
# Use the same extractor to update with outputs
|
||||
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
|
||||
if hasattr(extractor, 'update'):
|
||||
extractor.update(
|
||||
node_id,
|
||||
processed_outputs,
|
||||
self.prompt_metadata[self.current_prompt_id]
|
||||
)
|
||||
|
||||
# Update the cached metadata for this node
|
||||
self._cache_node_metadata(node_id, class_type)
|
||||
|
||||
def _cache_node_metadata(self, node_id, class_type):
|
||||
"""Cache the metadata for a specific node"""
|
||||
if not self.current_prompt_id or not node_id or not class_type:
|
||||
return
|
||||
|
||||
# Create a cache key combining node_id and class_type
|
||||
cache_key = f"{node_id}:{class_type}"
|
||||
|
||||
# Create a shallow copy of the node's metadata
|
||||
node_metadata = {}
|
||||
current_metadata = self.prompt_metadata[self.current_prompt_id]
|
||||
|
||||
for category in self.metadata_categories:
|
||||
if category in current_metadata and node_id in current_metadata[category]:
|
||||
if category not in node_metadata:
|
||||
node_metadata[category] = {}
|
||||
node_metadata[category][node_id] = current_metadata[category][node_id]
|
||||
|
||||
# Save to cache if we have any metadata for this node
|
||||
if any(node_metadata.values()):
|
||||
self.node_cache[cache_key] = node_metadata
|
||||
|
||||
def clear_unused_cache(self):
|
||||
"""Clean up node_cache entries that are no longer in use"""
|
||||
# Collect all node_ids currently in prompt_metadata
|
||||
active_node_ids = set()
|
||||
for prompt_data in self.prompt_metadata.values():
|
||||
for category in self.metadata_categories:
|
||||
if category in prompt_data:
|
||||
active_node_ids.update(prompt_data[category].keys())
|
||||
|
||||
# Find cache keys that are no longer needed
|
||||
keys_to_remove = []
|
||||
for cache_key in self.node_cache:
|
||||
node_id = cache_key.split(':')[0]
|
||||
if node_id not in active_node_ids:
|
||||
keys_to_remove.append(cache_key)
|
||||
|
||||
# Remove cache entries that are no longer needed
|
||||
for key in keys_to_remove:
|
||||
del self.node_cache[key]
|
||||
|
||||
def clear_metadata(self, prompt_id=None):
|
||||
"""Clear metadata for a specific prompt or reset all data"""
|
||||
if prompt_id is not None:
|
||||
if prompt_id in self.prompt_metadata:
|
||||
del self.prompt_metadata[prompt_id]
|
||||
# Clean up cache after removing prompt
|
||||
self.clear_unused_cache()
|
||||
else:
|
||||
# Reset all data
|
||||
self._reset()
|
||||
|
||||
def get_first_decoded_image(self, prompt_id=None):
|
||||
"""Get the first decoded image result"""
|
||||
key = prompt_id if prompt_id is not None else self.current_prompt_id
|
||||
if key not in self.prompt_metadata:
|
||||
return None
|
||||
|
||||
metadata = self.prompt_metadata[key]
|
||||
if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
|
||||
image_data = metadata[IMAGES]["first_decode"]["image"]
|
||||
|
||||
# If it's an image batch or tuple, handle various formats
|
||||
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
|
||||
# Return first element of list/tuple
|
||||
return image_data[0]
|
||||
|
||||
# If it's a tensor, return as is for processing in the route handler
|
||||
return image_data
|
||||
|
||||
# If no image is found in the current metadata, try to find it in the cache
|
||||
# This handles the case where VAEDecode was cached by ComfyUI and not executed
|
||||
prompt_obj = metadata.get("current_prompt")
|
||||
if prompt_obj and hasattr(prompt_obj, "original_prompt"):
|
||||
original_prompt = prompt_obj.original_prompt
|
||||
for node_id, node_data in original_prompt.items():
|
||||
class_type = node_data.get("class_type")
|
||||
if class_type and class_type in NODE_CLASS_MAPPINGS:
|
||||
class_obj = NODE_CLASS_MAPPINGS[class_type]
|
||||
class_name = class_obj.__name__
|
||||
# Check if this is a VAEDecode node
|
||||
if class_name == "VAEDecode":
|
||||
# Try to find this node in the cache
|
||||
cache_key = f"{node_id}:{class_name}"
|
||||
if cache_key in self.node_cache:
|
||||
cached_data = self.node_cache[cache_key]
|
||||
if IMAGES in cached_data and node_id in cached_data[IMAGES]:
|
||||
image_data = cached_data[IMAGES][node_id]["image"]
|
||||
# Handle different image formats
|
||||
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
|
||||
return image_data[0]
|
||||
return image_data
|
||||
|
||||
return None
|
||||
683
py/metadata_collector/node_extractors.py
Normal file
683
py/metadata_collector/node_extractors.py
Normal file
@@ -0,0 +1,683 @@
|
||||
import os
|
||||
|
||||
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES, IS_SAMPLER
|
||||
|
||||
|
||||
class NodeMetadataExtractor:
|
||||
"""Base class for node-specific metadata extraction"""
|
||||
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
"""Extract metadata from node inputs/outputs"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
"""Update metadata with node outputs after execution"""
|
||||
pass
|
||||
|
||||
class GenericNodeExtractor(NodeMetadataExtractor):
|
||||
"""Default extractor for nodes without specific handling"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
pass
|
||||
|
||||
class CheckpointLoaderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "ckpt_name" not in inputs:
|
||||
return
|
||||
|
||||
model_name = inputs.get("ckpt_name")
|
||||
if model_name:
|
||||
metadata[MODELS][node_id] = {
|
||||
"name": model_name,
|
||||
"type": "checkpoint",
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class TSCCheckpointLoaderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "ckpt_name" not in inputs:
|
||||
return
|
||||
|
||||
model_name = inputs.get("ckpt_name")
|
||||
if model_name:
|
||||
metadata[MODELS][node_id] = {
|
||||
"name": model_name,
|
||||
"type": "checkpoint",
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
# For loader node has lora_stack input, like Efficient Loader from Efficient Nodes
|
||||
active_loras = []
|
||||
|
||||
# Process lora_stack if available
|
||||
if "lora_stack" in inputs:
|
||||
lora_stack = inputs.get("lora_stack", [])
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Extract lora name from path (following the format in lora_loader.py)
|
||||
lora_name = os.path.splitext(os.path.basename(lora_path))[0]
|
||||
active_loras.append({
|
||||
"name": lora_name,
|
||||
"strength": model_strength
|
||||
})
|
||||
|
||||
if active_loras:
|
||||
metadata[LORAS][node_id] = {
|
||||
"lora_list": active_loras,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
# Extract positive and negative prompt text if available
|
||||
positive_text = inputs.get("positive", "")
|
||||
negative_text = inputs.get("negative", "")
|
||||
|
||||
if positive_text or negative_text:
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
# Store both positive and negative text
|
||||
metadata[PROMPTS][node_id]["positive_text"] = positive_text
|
||||
metadata[PROMPTS][node_id]["negative_text"] = negative_text
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
# Handle conditioning outputs from TSC_EfficientLoader
|
||||
# outputs is a list with [(model, positive_encoded, negative_encoded, {"samples":latent}, vae, clip, dependencies,)]
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, tuple) and len(first_output) >= 3:
|
||||
positive_conditioning = first_output[1]
|
||||
negative_conditioning = first_output[2]
|
||||
|
||||
# Save both conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["positive_encoded"] = positive_conditioning
|
||||
metadata[PROMPTS][node_id]["negative_encoded"] = negative_conditioning
|
||||
|
||||
class CLIPTextEncodeExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "text" not in inputs:
|
||||
return
|
||||
|
||||
text = inputs.get("text", "")
|
||||
metadata[PROMPTS][node_id] = {
|
||||
"text": text,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
if isinstance(outputs[0], tuple) and len(outputs[0]) > 0:
|
||||
conditioning = outputs[0][0]
|
||||
metadata[PROMPTS][node_id]["conditioning"] = conditioning
|
||||
|
||||
# Base Sampler Extractor to reduce code redundancy
|
||||
class BaseSamplerExtractor(NodeMetadataExtractor):
|
||||
"""Base extractor for sampler nodes with common functionality"""
|
||||
@staticmethod
|
||||
def extract_sampling_params(node_id, inputs, metadata, param_keys):
|
||||
"""Extract sampling parameters from inputs"""
|
||||
sampling_params = {}
|
||||
for key in param_keys:
|
||||
if key in inputs:
|
||||
sampling_params[key] = inputs[key]
|
||||
|
||||
metadata[SAMPLING][node_id] = {
|
||||
"parameters": sampling_params,
|
||||
"node_id": node_id,
|
||||
IS_SAMPLER: True # Add sampler flag
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def extract_conditioning(node_id, inputs, metadata):
|
||||
"""Extract conditioning objects from inputs"""
|
||||
# Store the conditioning objects directly in metadata for later matching
|
||||
pos_conditioning = inputs.get("positive", None)
|
||||
neg_conditioning = inputs.get("negative", None)
|
||||
|
||||
# Save conditioning objects in metadata for later matching
|
||||
if pos_conditioning is not None or neg_conditioning is not None:
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
|
||||
@staticmethod
|
||||
def extract_latent_dimensions(node_id, inputs, metadata):
|
||||
"""Extract dimensions from latent image"""
|
||||
# Extract latent image dimensions if available
|
||||
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
||||
latent = inputs["latent_image"]
|
||||
if isinstance(latent, dict) and "samples" in latent:
|
||||
# Extract dimensions from latent tensor
|
||||
samples = latent["samples"]
|
||||
if hasattr(samples, "shape") and len(samples.shape) >= 3:
|
||||
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
|
||||
# Multiply by 8 to get actual pixel dimensions
|
||||
height = int(samples.shape[2] * 8)
|
||||
width = int(samples.shape[3] * 8)
|
||||
|
||||
if SIZE not in metadata:
|
||||
metadata[SIZE] = {}
|
||||
|
||||
metadata[SIZE][node_id] = {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class SamplerExtractor(BaseSamplerExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]
|
||||
)
|
||||
|
||||
# Extract conditioning objects
|
||||
BaseSamplerExtractor.extract_conditioning(node_id, inputs, metadata)
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class KSamplerAdvancedExtractor(BaseSamplerExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]
|
||||
)
|
||||
|
||||
# Extract conditioning objects
|
||||
BaseSamplerExtractor.extract_conditioning(node_id, inputs, metadata)
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class KSamplerBasicPipeExtractor(BaseSamplerExtractor):
|
||||
"""Extractor for KSamplerBasicPipe and KSampler_inspire_pipe nodes"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]
|
||||
)
|
||||
|
||||
# Extract conditioning objects from basic_pipe
|
||||
if "basic_pipe" in inputs and inputs["basic_pipe"] is not None:
|
||||
basic_pipe = inputs["basic_pipe"]
|
||||
# Typically, basic_pipe structure is (model, clip, vae, positive, negative)
|
||||
if isinstance(basic_pipe, tuple) and len(basic_pipe) >= 5:
|
||||
pos_conditioning = basic_pipe[3] # positive is at index 3
|
||||
neg_conditioning = basic_pipe[4] # negative is at index 4
|
||||
|
||||
# Save conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class KSamplerAdvancedBasicPipeExtractor(BaseSamplerExtractor):
|
||||
"""Extractor for KSamplerAdvancedBasicPipe nodes"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]
|
||||
)
|
||||
|
||||
# Extract conditioning objects from basic_pipe
|
||||
if "basic_pipe" in inputs and inputs["basic_pipe"] is not None:
|
||||
basic_pipe = inputs["basic_pipe"]
|
||||
# Typically, basic_pipe structure is (model, clip, vae, positive, negative)
|
||||
if isinstance(basic_pipe, tuple) and len(basic_pipe) >= 5:
|
||||
pos_conditioning = basic_pipe[3] # positive is at index 3
|
||||
neg_conditioning = basic_pipe[4] # negative is at index 4
|
||||
|
||||
# Save conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class TSCSamplerBaseExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
# Store vae_decode setting for later use in update
|
||||
if inputs and "vae_decode" in inputs:
|
||||
if SAMPLING not in metadata:
|
||||
metadata[SAMPLING] = {}
|
||||
|
||||
if node_id not in metadata[SAMPLING]:
|
||||
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
|
||||
|
||||
# Store the vae_decode setting
|
||||
metadata[SAMPLING][node_id]["vae_decode"] = inputs["vae_decode"]
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
# Check if vae_decode was set to "true"
|
||||
should_save_image = True
|
||||
if SAMPLING in metadata and node_id in metadata[SAMPLING]:
|
||||
vae_decode = metadata[SAMPLING][node_id].get("vae_decode")
|
||||
if vae_decode is not None:
|
||||
should_save_image = (vae_decode == "true")
|
||||
|
||||
# Skip image saving if vae_decode isn't "true"
|
||||
if not should_save_image:
|
||||
return
|
||||
|
||||
# Ensure IMAGES category exists
|
||||
if IMAGES not in metadata:
|
||||
metadata[IMAGES] = {}
|
||||
|
||||
# Extract output_images from the TSC sampler format
|
||||
# outputs = [{"ui": {"images": preview_images}, "result": result}]
|
||||
# where result = (original_model, original_positive, original_negative, latent_list, optional_vae, output_images,)
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
# Get the first item in the list
|
||||
output_item = outputs[0]
|
||||
if isinstance(output_item, dict) and "result" in output_item:
|
||||
result = output_item["result"]
|
||||
if isinstance(result, tuple) and len(result) >= 6:
|
||||
# The output_images is the last element in the result tuple
|
||||
output_images = (result[5],)
|
||||
|
||||
# Save image data under node ID index to be captured by caching mechanism
|
||||
metadata[IMAGES][node_id] = {
|
||||
"node_id": node_id,
|
||||
"image": output_images
|
||||
}
|
||||
|
||||
# Only set first_decode if it hasn't been recorded yet
|
||||
if "first_decode" not in metadata[IMAGES]:
|
||||
metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id]
|
||||
|
||||
class TSCKSamplerExtractor(SamplerExtractor, TSCSamplerBaseExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
# Call parent extract methods
|
||||
SamplerExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
TSCSamplerBaseExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
|
||||
# Update method is inherited from TSCSamplerBaseExtractor
|
||||
|
||||
|
||||
class TSCKSamplerAdvancedExtractor(KSamplerAdvancedExtractor, TSCSamplerBaseExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
# Call parent extract methods
|
||||
KSamplerAdvancedExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
TSCSamplerBaseExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
|
||||
# Update method is inherited from TSCSamplerBaseExtractor
|
||||
|
||||
class LoraLoaderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "lora_name" not in inputs:
|
||||
return
|
||||
|
||||
lora_name = inputs.get("lora_name")
|
||||
# Extract base filename without extension from path
|
||||
lora_name = os.path.splitext(os.path.basename(lora_name))[0]
|
||||
strength_model = round(float(inputs.get("strength_model", 1.0)), 2)
|
||||
|
||||
# Use the standardized format with lora_list
|
||||
metadata[LORAS][node_id] = {
|
||||
"lora_list": [
|
||||
{
|
||||
"name": lora_name,
|
||||
"strength": strength_model
|
||||
}
|
||||
],
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class ImageSizeExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
width = inputs.get("width", 512)
|
||||
height = inputs.get("height", 512)
|
||||
|
||||
if SIZE not in metadata:
|
||||
metadata[SIZE] = {}
|
||||
|
||||
metadata[SIZE][node_id] = {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class LoraLoaderManagerExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
active_loras = []
|
||||
|
||||
# Process lora_stack if available
|
||||
if "lora_stack" in inputs:
|
||||
lora_stack = inputs.get("lora_stack", [])
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Extract lora name from path (following the format in lora_loader.py)
|
||||
lora_name = os.path.splitext(os.path.basename(lora_path))[0]
|
||||
active_loras.append({
|
||||
"name": lora_name,
|
||||
"strength": model_strength
|
||||
})
|
||||
|
||||
# Process loras from inputs
|
||||
if "loras" in inputs:
|
||||
loras_data = inputs.get("loras", [])
|
||||
|
||||
# Handle new format: {'loras': {'__value__': [...]}}
|
||||
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
||||
loras_list = loras_data['__value__']
|
||||
# Handle old format: {'loras': [...]}
|
||||
elif isinstance(loras_data, list):
|
||||
loras_list = loras_data
|
||||
else:
|
||||
loras_list = []
|
||||
|
||||
# Filter for active loras
|
||||
for lora in loras_list:
|
||||
if isinstance(lora, dict) and lora.get("active", True) and not lora.get("_isDummy", False):
|
||||
active_loras.append({
|
||||
"name": lora.get("name", ""),
|
||||
"strength": float(lora.get("strength", 1.0))
|
||||
})
|
||||
|
||||
if active_loras:
|
||||
metadata[LORAS][node_id] = {
|
||||
"lora_list": active_loras,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class FluxGuidanceExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "guidance" not in inputs:
|
||||
return
|
||||
|
||||
guidance_value = inputs.get("guidance")
|
||||
|
||||
# Store the guidance value in SAMPLING category
|
||||
if node_id not in metadata[SAMPLING]:
|
||||
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
|
||||
|
||||
metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value
|
||||
|
||||
class UNETLoaderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "unet_name" not in inputs:
|
||||
return
|
||||
|
||||
model_name = inputs.get("unet_name")
|
||||
if model_name:
|
||||
metadata[MODELS][node_id] = {
|
||||
"name": model_name,
|
||||
"type": "checkpoint",
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class VAEDecodeExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
# Ensure IMAGES category exists
|
||||
if IMAGES not in metadata:
|
||||
metadata[IMAGES] = {}
|
||||
|
||||
# Save image data under node ID index to be captured by caching mechanism
|
||||
metadata[IMAGES][node_id] = {
|
||||
"node_id": node_id,
|
||||
"image": outputs
|
||||
}
|
||||
|
||||
# Only set first_decode if it hasn't been recorded yet
|
||||
if "first_decode" not in metadata[IMAGES]:
|
||||
metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id]
|
||||
|
||||
class KSamplerSelectExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "sampler_name" not in inputs:
|
||||
return
|
||||
|
||||
sampling_params = {}
|
||||
if "sampler_name" in inputs:
|
||||
sampling_params["sampler_name"] = inputs["sampler_name"]
|
||||
|
||||
metadata[SAMPLING][node_id] = {
|
||||
"parameters": sampling_params,
|
||||
"node_id": node_id,
|
||||
IS_SAMPLER: False # Mark as non-primary sampler
|
||||
}
|
||||
|
||||
class BasicSchedulerExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
sampling_params = {}
|
||||
for key in ["scheduler", "steps", "denoise"]:
|
||||
if key in inputs:
|
||||
sampling_params[key] = inputs[key]
|
||||
|
||||
metadata[SAMPLING][node_id] = {
|
||||
"parameters": sampling_params,
|
||||
"node_id": node_id,
|
||||
IS_SAMPLER: False # Mark as non-primary sampler
|
||||
}
|
||||
|
||||
class SamplerCustomAdvancedExtractor(BaseSamplerExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
sampling_params = {}
|
||||
|
||||
# Handle noise.seed as seed
|
||||
if "noise" in inputs and inputs["noise"] is not None and hasattr(inputs["noise"], "seed"):
|
||||
noise = inputs["noise"]
|
||||
sampling_params["seed"] = noise.seed
|
||||
|
||||
metadata[SAMPLING][node_id] = {
|
||||
"parameters": sampling_params,
|
||||
"node_id": node_id,
|
||||
IS_SAMPLER: True # Add sampler flag
|
||||
}
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
import json
|
||||
|
||||
class CLIPTextEncodeFluxExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "clip_l" not in inputs or "t5xxl" not in inputs:
|
||||
return
|
||||
|
||||
clip_l_text = inputs.get("clip_l", "")
|
||||
t5xxl_text = inputs.get("t5xxl", "")
|
||||
|
||||
# If both are empty, use empty string
|
||||
if not clip_l_text and not t5xxl_text:
|
||||
combined_text = ""
|
||||
# If one is empty, use the non-empty one
|
||||
elif not clip_l_text:
|
||||
combined_text = t5xxl_text
|
||||
elif not t5xxl_text:
|
||||
combined_text = clip_l_text
|
||||
# If both have content, use JSON format
|
||||
else:
|
||||
combined_text = json.dumps({
|
||||
"T5": t5xxl_text,
|
||||
"CLIP-L": clip_l_text
|
||||
})
|
||||
|
||||
metadata[PROMPTS][node_id] = {
|
||||
"text": combined_text,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
# Extract guidance value if available
|
||||
if "guidance" in inputs:
|
||||
guidance_value = inputs.get("guidance")
|
||||
|
||||
# Store the guidance value in SAMPLING category
|
||||
if SAMPLING not in metadata:
|
||||
metadata[SAMPLING] = {}
|
||||
|
||||
if node_id not in metadata[SAMPLING]:
|
||||
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
|
||||
|
||||
metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
if isinstance(outputs[0], tuple) and len(outputs[0]) > 0:
|
||||
conditioning = outputs[0][0]
|
||||
metadata[PROMPTS][node_id]["conditioning"] = conditioning
|
||||
|
||||
class CFGGuiderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "cfg" not in inputs:
|
||||
return
|
||||
|
||||
cfg_value = inputs.get("cfg")
|
||||
|
||||
# Store the cfg value in SAMPLING category
|
||||
if SAMPLING not in metadata:
|
||||
metadata[SAMPLING] = {}
|
||||
|
||||
if node_id not in metadata[SAMPLING]:
|
||||
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
|
||||
|
||||
metadata[SAMPLING][node_id]["parameters"]["cfg"] = cfg_value
|
||||
|
||||
class CR_ApplyControlNetStackExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Save the original conditioning inputs
|
||||
base_positive = inputs.get("base_positive")
|
||||
base_negative = inputs.get("base_negative")
|
||||
|
||||
if base_positive is not None or base_negative is not None:
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["orig_pos_cond"] = base_positive
|
||||
metadata[PROMPTS][node_id]["orig_neg_cond"] = base_negative
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
# Extract transformed conditionings from outputs
|
||||
# outputs structure: [(base_positive, base_negative, show_help, )]
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, tuple) and len(first_output) >= 2:
|
||||
transformed_positive = first_output[0]
|
||||
transformed_negative = first_output[1]
|
||||
|
||||
# Save transformed conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["positive_encoded"] = transformed_positive
|
||||
metadata[PROMPTS][node_id]["negative_encoded"] = transformed_negative
|
||||
|
||||
# Registry of node-specific extractors
|
||||
# Keys are node class names
|
||||
NODE_EXTRACTORS = {
|
||||
# Sampling
|
||||
"KSampler": SamplerExtractor,
|
||||
"KSamplerAdvanced": KSamplerAdvancedExtractor,
|
||||
"SamplerCustom": KSamplerAdvancedExtractor,
|
||||
"SamplerCustomAdvanced": SamplerCustomAdvancedExtractor,
|
||||
"ClownsharKSampler_Beta": SamplerExtractor,
|
||||
"TSC_KSampler": TSCKSamplerExtractor, # Efficient Nodes
|
||||
"TSC_KSamplerAdvanced": TSCKSamplerAdvancedExtractor, # Efficient Nodes
|
||||
"KSamplerBasicPipe": KSamplerBasicPipeExtractor, # comfyui-impact-pack
|
||||
"KSamplerAdvancedBasicPipe": KSamplerAdvancedBasicPipeExtractor, # comfyui-impact-pack
|
||||
"KSampler_inspire_pipe": KSamplerBasicPipeExtractor, # comfyui-inspire-pack
|
||||
"KSamplerAdvanced_inspire_pipe": KSamplerAdvancedBasicPipeExtractor, # comfyui-inspire-pack
|
||||
# Sampling Selectors
|
||||
"KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect
|
||||
"BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler
|
||||
"AlignYourStepsScheduler": BasicSchedulerExtractor, # Add AlignYourStepsScheduler
|
||||
# Loaders
|
||||
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
|
||||
"comfyLoader": CheckpointLoaderExtractor, # easy comfyLoader
|
||||
"CheckpointLoaderSimpleWithImages": CheckpointLoaderExtractor, # CheckpointLoader|pysssss
|
||||
"TSC_EfficientLoader": TSCCheckpointLoaderExtractor, # Efficient Nodes
|
||||
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
"LoraLoader": LoraLoaderExtractor,
|
||||
"LoraManagerLoader": LoraLoaderManagerExtractor,
|
||||
# Conditioning
|
||||
"CLIPTextEncode": CLIPTextEncodeExtractor,
|
||||
"CLIPTextEncodeFlux": CLIPTextEncodeFluxExtractor, # Add CLIPTextEncodeFlux
|
||||
"WAS_Text_to_Conditioning": CLIPTextEncodeExtractor,
|
||||
"AdvancedCLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/BlenderNeko/ComfyUI_ADV_CLIP_emb
|
||||
"smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes
|
||||
"CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack
|
||||
"PCTextEncode": CLIPTextEncodeExtractor, # From https://github.com/asagi4/comfyui-prompt-control
|
||||
# Latent
|
||||
"EmptyLatentImage": ImageSizeExtractor,
|
||||
# Flux
|
||||
"FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance
|
||||
"CFGGuider": CFGGuiderExtractor, # Add CFGGuider
|
||||
# Image
|
||||
"VAEDecode": VAEDecodeExtractor, # Added VAEDecode extractor
|
||||
# Add other nodes as needed
|
||||
}
|
||||
1
py/middleware/__init__.py
Normal file
1
py/middleware/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Server middleware modules"""
|
||||
53
py/middleware/cache_middleware.py
Normal file
53
py/middleware/cache_middleware.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Cache control middleware for ComfyUI server"""
|
||||
|
||||
from aiohttp import web
|
||||
from typing import Callable, Awaitable
|
||||
|
||||
# Time in seconds
|
||||
ONE_HOUR: int = 3600
|
||||
ONE_DAY: int = 86400
|
||||
IMG_EXTENSIONS = (
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".ppm",
|
||||
".bmp",
|
||||
".pgm",
|
||||
".tif",
|
||||
".tiff",
|
||||
".webp",
|
||||
".mp4"
|
||||
)
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def cache_control(
|
||||
request: web.Request, handler: Callable[[web.Request], Awaitable[web.Response]]
|
||||
) -> web.Response:
|
||||
"""Cache control middleware that sets appropriate cache headers based on file type and response status"""
|
||||
response: web.Response = await handler(request)
|
||||
|
||||
if (
|
||||
request.path.endswith(".js")
|
||||
or request.path.endswith(".css")
|
||||
or request.path.endswith("index.json")
|
||||
):
|
||||
response.headers.setdefault("Cache-Control", "no-cache")
|
||||
return response
|
||||
|
||||
# Early return for non-image files - no cache headers needed
|
||||
if not request.path.lower().endswith(IMG_EXTENSIONS):
|
||||
return response
|
||||
|
||||
# Handle image files
|
||||
if response.status == 404:
|
||||
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_HOUR}")
|
||||
elif response.status in (200, 201, 202, 203, 204, 205, 206, 301, 308):
|
||||
# Success responses and permanent redirects - cache for 1 day
|
||||
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_DAY}")
|
||||
elif response.status in (302, 303, 307):
|
||||
# Temporary redirects - no cache
|
||||
response.headers.setdefault("Cache-Control", "no-cache")
|
||||
# Note: 304 Not Modified falls through - no cache headers set
|
||||
|
||||
return response
|
||||
45
py/nodes/debug_metadata.py
Normal file
45
py/nodes/debug_metadata.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import logging
|
||||
from server import PromptServer # type: ignore
|
||||
from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DebugMetadata:
|
||||
NAME = "Debug Metadata (LoraManager)"
|
||||
CATEGORY = "Lora Manager/utils"
|
||||
DESCRIPTION = "Debug node to verify metadata_processor functionality"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"images": ("IMAGE",),
|
||||
},
|
||||
"hidden": {
|
||||
"id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "process_metadata"
|
||||
|
||||
def process_metadata(self, images, id):
|
||||
try:
|
||||
# Get the current execution context's metadata
|
||||
from ..metadata_collector import get_metadata
|
||||
metadata = get_metadata()
|
||||
|
||||
# Use the MetadataProcessor to convert it to JSON string
|
||||
metadata_json = MetadataProcessor.to_json(metadata, id)
|
||||
|
||||
# Send metadata to frontend for display
|
||||
PromptServer.instance.send_sync("metadata_update", {
|
||||
"id": id,
|
||||
"metadata": metadata_json
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing metadata: {e}")
|
||||
|
||||
return ()
|
||||
@@ -1,23 +1,25 @@
|
||||
import logging
|
||||
import re
|
||||
from nodes import LoraLoader
|
||||
from comfy.comfy_types import IO # type: ignore
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
import asyncio
|
||||
import os
|
||||
from .utils import FlexibleOptionalInputType, any_type
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraManagerLoader:
|
||||
NAME = "Lora Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"clip": ("CLIP",),
|
||||
# "clip": ("CLIP",),
|
||||
"text": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"multiline": True,
|
||||
"pysssss.autocomplete": False,
|
||||
"dynamicPrompts": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
"placeholder": "LoRA syntax input: <lora:name:strength>"
|
||||
@@ -26,74 +28,244 @@ class LoraManagerLoader:
|
||||
"optional": FlexibleOptionalInputType(any_type),
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", IO.STRING)
|
||||
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words")
|
||||
RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING)
|
||||
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
|
||||
FUNCTION = "load_loras"
|
||||
|
||||
async def get_lora_info(self, lora_name):
|
||||
"""Get the lora path and trigger words from cache"""
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
file_path = item.get('file_path')
|
||||
if file_path:
|
||||
for root in config.loras_roots:
|
||||
root = root.replace(os.sep, '/')
|
||||
if file_path.startswith(root):
|
||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||
# Get trigger words from civitai metadata
|
||||
civitai = item.get('civitai', {})
|
||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||
return relative_path, trigger_words
|
||||
return lora_name, [] # Fallback if not found
|
||||
|
||||
def extract_lora_name(self, lora_path):
|
||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||
# Get the basename without extension
|
||||
basename = os.path.basename(lora_path)
|
||||
return os.path.splitext(basename)[0]
|
||||
|
||||
def load_loras(self, model, clip, text, **kwargs):
|
||||
def load_loras(self, model, text, **kwargs):
|
||||
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
|
||||
loaded_loras = []
|
||||
all_trigger_words = []
|
||||
|
||||
clip = kwargs.get('clip', None)
|
||||
lora_stack = kwargs.get('lora_stack', None)
|
||||
|
||||
# Check if model is a Nunchaku Flux model - simplified approach
|
||||
is_nunchaku_model = False
|
||||
|
||||
try:
|
||||
model_wrapper = model.model.diffusion_model
|
||||
# Check if model is a Nunchaku Flux model using only class name
|
||||
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
||||
is_nunchaku_model = True
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
except (AttributeError, TypeError):
|
||||
# Not a model with the expected structure
|
||||
pass
|
||||
|
||||
# First process lora_stack if available
|
||||
if lora_stack:
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Apply the LoRA using the provided path and strengths
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# Use our custom function for Flux models
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged for Nunchaku models
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Extract lora name for trigger words lookup
|
||||
lora_name = self.extract_lora_name(lora_path)
|
||||
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
all_trigger_words.extend(trigger_words)
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Then process loras from kwargs
|
||||
if 'loras' in kwargs:
|
||||
for lora in kwargs['loras']:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
|
||||
lora_name = lora['name']
|
||||
strength = float(lora['strength'])
|
||||
# Then process loras from kwargs with support for both old and new formats
|
||||
loras_list = get_loras_list(kwargs)
|
||||
for lora in loras_list:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||
lora_name = lora['name']
|
||||
model_strength = float(lora['strength'])
|
||||
# Get clip strength - use model strength as default if not specified
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# For Nunchaku models, use our custom function
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Include clip strength in output if different from model strength and not a Nunchaku model
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# use ',, ' to separate trigger words for group mode
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Format loaded_loras with support for both formats
|
||||
formatted_loras = []
|
||||
for item in loaded_loras:
|
||||
parts = item.split(":")
|
||||
lora_name = parts[0]
|
||||
strength_parts = parts[1].strip().split(",")
|
||||
|
||||
if len(strength_parts) > 1:
|
||||
# Different model and clip strengths
|
||||
model_str = strength_parts[0].strip()
|
||||
clip_str = strength_parts[1].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}:{clip_str}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
model_str = strength_parts[0].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}>")
|
||||
|
||||
# Apply the LoRA using the resolved path
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength)
|
||||
loaded_loras.append(f"{lora_name}: {strength}")
|
||||
formatted_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (model, clip, trigger_words_text, formatted_loras_text)
|
||||
|
||||
class LoraManagerTextLoader:
|
||||
NAME = "LoRA Text Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"lora_syntax": (IO.STRING, {
|
||||
"defaultInput": True,
|
||||
"forceInput": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation"
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"clip": ("CLIP",),
|
||||
"lora_stack": ("LORA_STACK",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING)
|
||||
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
|
||||
FUNCTION = "load_loras_from_text"
|
||||
|
||||
def parse_lora_syntax(self, text):
|
||||
"""Parse LoRA syntax from text input."""
|
||||
# Pattern to match <lora:name:strength> or <lora:name:model_strength:clip_strength>
|
||||
pattern = r'<lora:([^:>]+):([^:>]+)(?::([^:>]+))?>'
|
||||
matches = re.findall(pattern, text, re.IGNORECASE)
|
||||
|
||||
loras = []
|
||||
for match in matches:
|
||||
lora_name = match[0]
|
||||
model_strength = float(match[1])
|
||||
clip_strength = float(match[2]) if match[2] else model_strength
|
||||
|
||||
loras.append({
|
||||
'name': lora_name,
|
||||
'model_strength': model_strength,
|
||||
'clip_strength': clip_strength
|
||||
})
|
||||
|
||||
return loras
|
||||
|
||||
def load_loras_from_text(self, model, lora_syntax, clip=None, lora_stack=None):
|
||||
"""Load LoRAs based on text syntax input."""
|
||||
loaded_loras = []
|
||||
all_trigger_words = []
|
||||
|
||||
# Check if model is a Nunchaku Flux model - simplified approach
|
||||
is_nunchaku_model = False
|
||||
|
||||
try:
|
||||
model_wrapper = model.model.diffusion_model
|
||||
# Check if model is a Nunchaku Flux model using only class name
|
||||
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
||||
is_nunchaku_model = True
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
except (AttributeError, TypeError):
|
||||
# Not a model with the expected structure
|
||||
pass
|
||||
|
||||
# First process lora_stack if available
|
||||
if lora_stack:
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# Use our custom function for Flux models
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged for Nunchaku models
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Extract lora name for trigger words lookup
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Parse and process LoRAs from text syntax
|
||||
parsed_loras = self.parse_lora_syntax(lora_syntax)
|
||||
for lora in parsed_loras:
|
||||
lora_name = lora['name']
|
||||
model_strength = lora['model_strength']
|
||||
clip_strength = lora['clip_strength']
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# For Nunchaku models, use our custom function
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Include clip strength in output if different from model strength and not a Nunchaku model
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# use ',, ' to separate trigger words for group mode
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
return (model, clip, trigger_words_text)
|
||||
# Format loaded_loras with support for both formats
|
||||
formatted_loras = []
|
||||
for item in loaded_loras:
|
||||
parts = item.split(":")
|
||||
lora_name = parts[0].strip()
|
||||
strength_parts = parts[1].strip().split(",")
|
||||
|
||||
if len(strength_parts) > 1:
|
||||
# Different model and clip strengths
|
||||
model_str = strength_parts[0].strip()
|
||||
clip_str = strength_parts[1].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}:{clip_str}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
model_str = strength_parts[0].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}>")
|
||||
|
||||
formatted_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (model, clip, trigger_words_text, formatted_loras_text)
|
||||
@@ -1,9 +1,11 @@
|
||||
from comfy.comfy_types import IO # type: ignore
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
import asyncio
|
||||
import os
|
||||
from .utils import FlexibleOptionalInputType, any_type
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraStacker:
|
||||
NAME = "Lora Stacker (LoraManager)"
|
||||
@@ -15,6 +17,7 @@ class LoraStacker:
|
||||
"required": {
|
||||
"text": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"pysssss.autocomplete": False,
|
||||
"dynamicPrompts": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
"placeholder": "LoRA syntax input: <lora:name:strength>"
|
||||
@@ -23,69 +26,61 @@ class LoraStacker:
|
||||
"optional": FlexibleOptionalInputType(any_type),
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LORA_STACK", IO.STRING)
|
||||
RETURN_NAMES = ("LORA_STACK", "trigger_words")
|
||||
RETURN_TYPES = ("LORA_STACK", IO.STRING, IO.STRING)
|
||||
RETURN_NAMES = ("LORA_STACK", "trigger_words", "active_loras")
|
||||
FUNCTION = "stack_loras"
|
||||
|
||||
async def get_lora_info(self, lora_name):
|
||||
"""Get the lora path and trigger words from cache"""
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
file_path = item.get('file_path')
|
||||
if file_path:
|
||||
for root in config.loras_roots:
|
||||
root = root.replace(os.sep, '/')
|
||||
if file_path.startswith(root):
|
||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||
# Get trigger words from civitai metadata
|
||||
civitai = item.get('civitai', {})
|
||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||
return relative_path, trigger_words
|
||||
return lora_name, [] # Fallback if not found
|
||||
|
||||
def extract_lora_name(self, lora_path):
|
||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||
# Get the basename without extension
|
||||
basename = os.path.basename(lora_path)
|
||||
return os.path.splitext(basename)[0]
|
||||
|
||||
def stack_loras(self, text, **kwargs):
|
||||
"""Stacks multiple LoRAs based on the kwargs input without loading them."""
|
||||
stack = []
|
||||
active_loras = []
|
||||
all_trigger_words = []
|
||||
|
||||
# Process existing lora_stack if available
|
||||
lora_stack = kwargs.get('lora_stack', None)
|
||||
if lora_stack:
|
||||
if (lora_stack):
|
||||
stack.extend(lora_stack)
|
||||
# Get trigger words from existing stack entries
|
||||
for lora_path, _, _ in lora_stack:
|
||||
lora_name = self.extract_lora_name(lora_path)
|
||||
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
if 'loras' in kwargs:
|
||||
for lora in kwargs['loras']:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
|
||||
lora_name = lora['name']
|
||||
model_strength = float(lora['strength'])
|
||||
clip_strength = model_strength # Using same strength for both as in the original loader
|
||||
# Process loras from kwargs with support for both old and new formats
|
||||
loras_list = get_loras_list(kwargs)
|
||||
for lora in loras_list:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||
|
||||
# Add to stack without loading
|
||||
stack.append((lora_path, model_strength, clip_strength))
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
lora_name = lora['name']
|
||||
model_strength = float(lora['strength'])
|
||||
# Get clip strength - use model strength as default if not specified
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Add to stack without loading
|
||||
# replace '/' with os.sep to avoid different OS path format
|
||||
stack.append((lora_path.replace('/', os.sep), model_strength, clip_strength))
|
||||
active_loras.append((lora_name, model_strength, clip_strength))
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# use ',, ' to separate trigger words for group mode
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Format active_loras with support for both formats
|
||||
formatted_loras = []
|
||||
for name, model_strength, clip_strength in active_loras:
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
# Different model and clip strengths
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")
|
||||
|
||||
active_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (stack, trigger_words_text)
|
||||
return (stack, trigger_words_text, active_loras_text)
|
||||
|
||||
445
py/nodes/save_image.py
Normal file
445
py/nodes/save_image.py
Normal file
@@ -0,0 +1,445 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import numpy as np
|
||||
import folder_paths # type: ignore
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||
from ..metadata_collector import get_metadata
|
||||
from PIL import Image, PngImagePlugin
|
||||
import piexif
|
||||
|
||||
class SaveImage:
|
||||
NAME = "Save Image (LoraManager)"
|
||||
CATEGORY = "Lora Manager/utils"
|
||||
DESCRIPTION = "Save images with embedded generation metadata in compatible format"
|
||||
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
self.compress_level = 4
|
||||
self.counter = 0
|
||||
|
||||
# Add pattern format regex for filename substitution
|
||||
pattern_format = re.compile(r"(%[^%]+%)")
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"images": ("IMAGE",),
|
||||
"filename_prefix": ("STRING", {
|
||||
"default": "ComfyUI",
|
||||
"tooltip": "Base filename for saved images. Supports format patterns like %seed%, %width%, %height%, %model%, etc."
|
||||
}),
|
||||
"file_format": (["png", "jpeg", "webp"], {
|
||||
"tooltip": "Image format to save as. PNG preserves quality, JPEG is smaller, WebP balances size and quality."
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"lossless_webp": ("BOOLEAN", {
|
||||
"default": False,
|
||||
"tooltip": "When enabled, saves WebP images with lossless compression. Results in larger files but no quality loss."
|
||||
}),
|
||||
"quality": ("INT", {
|
||||
"default": 100,
|
||||
"min": 1,
|
||||
"max": 100,
|
||||
"tooltip": "Compression quality for JPEG and lossy WebP formats (1-100). Higher values mean better quality but larger files."
|
||||
}),
|
||||
"embed_workflow": ("BOOLEAN", {
|
||||
"default": False,
|
||||
"tooltip": "Embeds the complete workflow data into the image metadata. Only works with PNG and WebP formats."
|
||||
}),
|
||||
"add_counter_to_filename": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images."
|
||||
}),
|
||||
},
|
||||
"hidden": {
|
||||
"id": "UNIQUE_ID",
|
||||
"prompt": "PROMPT",
|
||||
"extra_pnginfo": "EXTRA_PNGINFO",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
RETURN_NAMES = ("images",)
|
||||
FUNCTION = "process_image"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def get_lora_hash(self, lora_name):
|
||||
"""Get the lora hash from cache"""
|
||||
scanner = ServiceRegistry.get_service_sync("lora_scanner")
|
||||
|
||||
# Use the new direct filename lookup method
|
||||
hash_value = scanner.get_hash_by_filename(lora_name)
|
||||
if hash_value:
|
||||
return hash_value
|
||||
|
||||
return None
|
||||
|
||||
def get_checkpoint_hash(self, checkpoint_path):
|
||||
"""Get the checkpoint hash from cache"""
|
||||
scanner = ServiceRegistry.get_service_sync("checkpoint_scanner")
|
||||
|
||||
if not checkpoint_path:
|
||||
return None
|
||||
|
||||
# Extract basename without extension
|
||||
checkpoint_name = os.path.basename(checkpoint_path)
|
||||
checkpoint_name = os.path.splitext(checkpoint_name)[0]
|
||||
|
||||
# Try direct filename lookup first
|
||||
hash_value = scanner.get_hash_by_filename(checkpoint_name)
|
||||
if hash_value:
|
||||
return hash_value
|
||||
|
||||
return None
|
||||
|
||||
def format_metadata(self, metadata_dict):
|
||||
"""Format metadata in the requested format similar to userComment example"""
|
||||
if not metadata_dict:
|
||||
return ""
|
||||
|
||||
# Helper function to only add parameter if value is not None
|
||||
def add_param_if_not_none(param_list, label, value):
|
||||
if value is not None:
|
||||
param_list.append(f"{label}: {value}")
|
||||
|
||||
# Extract the prompt and negative prompt
|
||||
prompt = metadata_dict.get('prompt', '')
|
||||
negative_prompt = metadata_dict.get('negative_prompt', '')
|
||||
|
||||
# Extract loras from the prompt if present
|
||||
loras_text = metadata_dict.get('loras', '')
|
||||
lora_hashes = {}
|
||||
|
||||
# If loras are found, add them on a new line after the prompt
|
||||
if loras_text:
|
||||
prompt_with_loras = f"{prompt}\n{loras_text}"
|
||||
|
||||
# Extract lora names from the format <lora:name:strength>
|
||||
lora_matches = re.findall(r'<lora:([^:]+):([^>]+)>', loras_text)
|
||||
|
||||
# Get hash for each lora
|
||||
for lora_name, strength in lora_matches:
|
||||
hash_value = self.get_lora_hash(lora_name)
|
||||
if hash_value:
|
||||
lora_hashes[lora_name] = hash_value
|
||||
else:
|
||||
prompt_with_loras = prompt
|
||||
|
||||
# Format the first part (prompt and loras)
|
||||
metadata_parts = [prompt_with_loras]
|
||||
|
||||
# Add negative prompt
|
||||
if negative_prompt:
|
||||
metadata_parts.append(f"Negative prompt: {negative_prompt}")
|
||||
|
||||
# Format the second part (generation parameters)
|
||||
params = []
|
||||
|
||||
# Add standard parameters in the correct order
|
||||
if 'steps' in metadata_dict:
|
||||
add_param_if_not_none(params, "Steps", metadata_dict.get('steps'))
|
||||
|
||||
# Combine sampler and scheduler information
|
||||
sampler_name = None
|
||||
scheduler_name = None
|
||||
|
||||
if 'sampler' in metadata_dict:
|
||||
sampler = metadata_dict.get('sampler')
|
||||
# Convert ComfyUI sampler names to user-friendly names
|
||||
sampler_mapping = {
|
||||
'euler': 'Euler',
|
||||
'euler_ancestral': 'Euler a',
|
||||
'dpm_2': 'DPM2',
|
||||
'dpm_2_ancestral': 'DPM2 a',
|
||||
'heun': 'Heun',
|
||||
'dpm_fast': 'DPM fast',
|
||||
'dpm_adaptive': 'DPM adaptive',
|
||||
'lms': 'LMS',
|
||||
'dpmpp_2s_ancestral': 'DPM++ 2S a',
|
||||
'dpmpp_sde': 'DPM++ SDE',
|
||||
'dpmpp_sde_gpu': 'DPM++ SDE',
|
||||
'dpmpp_2m': 'DPM++ 2M',
|
||||
'dpmpp_2m_sde': 'DPM++ 2M SDE',
|
||||
'dpmpp_2m_sde_gpu': 'DPM++ 2M SDE',
|
||||
'ddim': 'DDIM'
|
||||
}
|
||||
sampler_name = sampler_mapping.get(sampler, sampler)
|
||||
|
||||
if 'scheduler' in metadata_dict:
|
||||
scheduler = metadata_dict.get('scheduler')
|
||||
scheduler_mapping = {
|
||||
'normal': 'Simple',
|
||||
'karras': 'Karras',
|
||||
'exponential': 'Exponential',
|
||||
'sgm_uniform': 'SGM Uniform',
|
||||
'sgm_quadratic': 'SGM Quadratic'
|
||||
}
|
||||
scheduler_name = scheduler_mapping.get(scheduler, scheduler)
|
||||
|
||||
# Add combined sampler and scheduler information
|
||||
if sampler_name:
|
||||
if scheduler_name:
|
||||
params.append(f"Sampler: {sampler_name} {scheduler_name}")
|
||||
else:
|
||||
params.append(f"Sampler: {sampler_name}")
|
||||
|
||||
# CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg)
|
||||
if 'guidance' in metadata_dict:
|
||||
add_param_if_not_none(params, "CFG scale", metadata_dict.get('guidance'))
|
||||
elif 'cfg_scale' in metadata_dict:
|
||||
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg_scale'))
|
||||
elif 'cfg' in metadata_dict:
|
||||
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg'))
|
||||
|
||||
# Seed
|
||||
if 'seed' in metadata_dict:
|
||||
add_param_if_not_none(params, "Seed", metadata_dict.get('seed'))
|
||||
|
||||
# Size
|
||||
if 'size' in metadata_dict:
|
||||
add_param_if_not_none(params, "Size", metadata_dict.get('size'))
|
||||
|
||||
# Model info
|
||||
if 'checkpoint' in metadata_dict:
|
||||
# Ensure checkpoint is a string before processing
|
||||
checkpoint = metadata_dict.get('checkpoint')
|
||||
if checkpoint is not None:
|
||||
# Get model hash
|
||||
model_hash = self.get_checkpoint_hash(checkpoint)
|
||||
|
||||
# Extract basename without path
|
||||
checkpoint_name = os.path.basename(checkpoint)
|
||||
# Remove extension if present
|
||||
checkpoint_name = os.path.splitext(checkpoint_name)[0]
|
||||
|
||||
# Add model hash if available
|
||||
if model_hash:
|
||||
params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}")
|
||||
else:
|
||||
params.append(f"Model: {checkpoint_name}")
|
||||
|
||||
# Add LoRA hashes if available
|
||||
if lora_hashes:
|
||||
lora_hash_parts = []
|
||||
for lora_name, hash_value in lora_hashes.items():
|
||||
lora_hash_parts.append(f"{lora_name}: {hash_value[:10]}")
|
||||
|
||||
if lora_hash_parts:
|
||||
params.append(f"Lora hashes: \"{', '.join(lora_hash_parts)}\"")
|
||||
|
||||
# Combine all parameters with commas
|
||||
metadata_parts.append(", ".join(params))
|
||||
|
||||
# Join all parts with a new line
|
||||
return "\n".join(metadata_parts)
|
||||
|
||||
# credit to nkchocoai
|
||||
# Add format_filename method to handle pattern substitution
|
||||
def format_filename(self, filename, metadata_dict):
|
||||
"""Format filename with metadata values"""
|
||||
if not metadata_dict:
|
||||
return filename
|
||||
|
||||
result = re.findall(self.pattern_format, filename)
|
||||
for segment in result:
|
||||
parts = segment.replace("%", "").split(":")
|
||||
key = parts[0]
|
||||
|
||||
if key == "seed" and 'seed' in metadata_dict:
|
||||
filename = filename.replace(segment, str(metadata_dict.get('seed', '')))
|
||||
elif key == "width" and 'size' in metadata_dict:
|
||||
size = metadata_dict.get('size', 'x')
|
||||
w = size.split('x')[0] if isinstance(size, str) else size[0]
|
||||
filename = filename.replace(segment, str(w))
|
||||
elif key == "height" and 'size' in metadata_dict:
|
||||
size = metadata_dict.get('size', 'x')
|
||||
h = size.split('x')[1] if isinstance(size, str) else size[1]
|
||||
filename = filename.replace(segment, str(h))
|
||||
elif key == "pprompt" and 'prompt' in metadata_dict:
|
||||
prompt = metadata_dict.get('prompt', '').replace("\n", " ")
|
||||
if len(parts) >= 2:
|
||||
length = int(parts[1])
|
||||
prompt = prompt[:length]
|
||||
filename = filename.replace(segment, prompt.strip())
|
||||
elif key == "nprompt" and 'negative_prompt' in metadata_dict:
|
||||
prompt = metadata_dict.get('negative_prompt', '').replace("\n", " ")
|
||||
if len(parts) >= 2:
|
||||
length = int(parts[1])
|
||||
prompt = prompt[:length]
|
||||
filename = filename.replace(segment, prompt.strip())
|
||||
elif key == "model" and 'checkpoint' in metadata_dict:
|
||||
model = metadata_dict.get('checkpoint', '')
|
||||
model = os.path.splitext(os.path.basename(model))[0]
|
||||
if len(parts) >= 2:
|
||||
length = int(parts[1])
|
||||
model = model[:length]
|
||||
filename = filename.replace(segment, model)
|
||||
elif key == "date":
|
||||
from datetime import datetime
|
||||
now = datetime.now()
|
||||
date_table = {
|
||||
"yyyy": f"{now.year:04d}",
|
||||
"yy": f"{now.year % 100:02d}",
|
||||
"MM": f"{now.month:02d}",
|
||||
"dd": f"{now.day:02d}",
|
||||
"hh": f"{now.hour:02d}",
|
||||
"mm": f"{now.minute:02d}",
|
||||
"ss": f"{now.second:02d}",
|
||||
}
|
||||
if len(parts) >= 2:
|
||||
date_format = parts[1]
|
||||
for k, v in date_table.items():
|
||||
date_format = date_format.replace(k, v)
|
||||
filename = filename.replace(segment, date_format)
|
||||
else:
|
||||
date_format = "yyyyMMddhhmmss"
|
||||
for k, v in date_table.items():
|
||||
date_format = date_format.replace(k, v)
|
||||
filename = filename.replace(segment, date_format)
|
||||
|
||||
return filename
|
||||
|
||||
def save_images(self, images, filename_prefix, file_format, id, prompt=None, extra_pnginfo=None,
|
||||
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
|
||||
"""Save images with metadata"""
|
||||
results = []
|
||||
|
||||
# Get metadata using the metadata collector
|
||||
raw_metadata = get_metadata()
|
||||
metadata_dict = MetadataProcessor.to_dict(raw_metadata, id)
|
||||
|
||||
metadata = self.format_metadata(metadata_dict)
|
||||
|
||||
# Process filename_prefix with pattern substitution
|
||||
filename_prefix = self.format_filename(filename_prefix, metadata_dict)
|
||||
|
||||
# Get initial save path info once for the batch
|
||||
full_output_folder, filename, counter, subfolder, processed_prefix = folder_paths.get_save_image_path(
|
||||
filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
if not os.path.exists(full_output_folder):
|
||||
os.makedirs(full_output_folder, exist_ok=True)
|
||||
|
||||
# Process each image with incrementing counter
|
||||
for i, image in enumerate(images):
|
||||
# Convert the tensor image to numpy array
|
||||
img = 255. * image.cpu().numpy()
|
||||
img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8))
|
||||
|
||||
# Generate filename with counter if needed
|
||||
base_filename = filename
|
||||
if add_counter_to_filename:
|
||||
# Use counter + i to ensure unique filenames for all images in batch
|
||||
current_counter = counter + i
|
||||
base_filename += f"_{current_counter:05}_"
|
||||
|
||||
# Set file extension and prepare saving parameters
|
||||
if file_format == "png":
|
||||
file = base_filename + ".png"
|
||||
file_extension = ".png"
|
||||
# Remove "optimize": True to match built-in node behavior
|
||||
save_kwargs = {"compress_level": self.compress_level}
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
elif file_format == "jpeg":
|
||||
file = base_filename + ".jpg"
|
||||
file_extension = ".jpg"
|
||||
save_kwargs = {"quality": quality, "optimize": True}
|
||||
elif file_format == "webp":
|
||||
file = base_filename + ".webp"
|
||||
file_extension = ".webp"
|
||||
# Add optimization param to control performance
|
||||
save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0}
|
||||
|
||||
# Full save path
|
||||
file_path = os.path.join(full_output_folder, file)
|
||||
|
||||
# Save the image with metadata
|
||||
try:
|
||||
if file_format == "png":
|
||||
if metadata:
|
||||
pnginfo.add_text("parameters", metadata)
|
||||
if embed_workflow and extra_pnginfo is not None:
|
||||
workflow_json = json.dumps(extra_pnginfo["workflow"])
|
||||
pnginfo.add_text("workflow", workflow_json)
|
||||
save_kwargs["pnginfo"] = pnginfo
|
||||
img.save(file_path, format="PNG", **save_kwargs)
|
||||
elif file_format == "jpeg":
|
||||
# For JPEG, use piexif
|
||||
if metadata:
|
||||
try:
|
||||
exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}}
|
||||
exif_bytes = piexif.dump(exif_dict)
|
||||
save_kwargs["exif"] = exif_bytes
|
||||
except Exception as e:
|
||||
print(f"Error adding EXIF data: {e}")
|
||||
img.save(file_path, format="JPEG", **save_kwargs)
|
||||
elif file_format == "webp":
|
||||
try:
|
||||
# For WebP, use piexif for metadata
|
||||
exif_dict = {}
|
||||
|
||||
if metadata:
|
||||
exif_dict['Exif'] = {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}
|
||||
|
||||
# Add workflow if needed
|
||||
if embed_workflow and extra_pnginfo is not None:
|
||||
workflow_json = json.dumps(extra_pnginfo["workflow"])
|
||||
exif_dict['0th'] = {piexif.ImageIFD.ImageDescription: "Workflow:" + workflow_json}
|
||||
|
||||
exif_bytes = piexif.dump(exif_dict)
|
||||
save_kwargs["exif"] = exif_bytes
|
||||
except Exception as e:
|
||||
print(f"Error adding EXIF data: {e}")
|
||||
|
||||
img.save(file_path, format="WEBP", **save_kwargs)
|
||||
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error saving image: {e}")
|
||||
|
||||
return results
|
||||
|
||||
def process_image(self, images, id, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None,
|
||||
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
|
||||
"""Process and save image with metadata"""
|
||||
# Make sure the output directory exists
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
|
||||
# If images is already a list or array of images, do nothing; otherwise, convert to list
|
||||
if isinstance(images, (list, np.ndarray)):
|
||||
pass
|
||||
else:
|
||||
# Ensure images is always a list of images
|
||||
if len(images.shape) == 3: # Single image (height, width, channels)
|
||||
images = [images]
|
||||
else: # Multiple images (batch, height, width, channels)
|
||||
images = [img for img in images]
|
||||
|
||||
# Save all images
|
||||
results = self.save_images(
|
||||
images,
|
||||
filename_prefix,
|
||||
file_format,
|
||||
id,
|
||||
prompt,
|
||||
extra_pnginfo,
|
||||
lossless_webp,
|
||||
quality,
|
||||
embed_workflow,
|
||||
add_counter_to_filename
|
||||
)
|
||||
|
||||
return (images,)
|
||||
@@ -2,6 +2,10 @@ import json
|
||||
import re
|
||||
from server import PromptServer # type: ignore
|
||||
from .utils import FlexibleOptionalInputType, any_type
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerWordToggle:
|
||||
NAME = "TriggerWord Toggle (LoraManager)"
|
||||
@@ -12,11 +16,18 @@ class TriggerWordToggle:
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"group_mode": ("BOOLEAN", {"default": True}),
|
||||
"group_mode": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "When enabled, treats each group of trigger words as a single toggleable unit."
|
||||
}),
|
||||
"default_active": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Sets the default initial state (active or inactive) when trigger words are added."
|
||||
}),
|
||||
},
|
||||
"optional": FlexibleOptionalInputType(any_type),
|
||||
"hidden": {
|
||||
"id": "UNIQUE_ID", # 会被 ComfyUI 自动替换为唯一ID
|
||||
"id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -24,22 +35,30 @@ class TriggerWordToggle:
|
||||
RETURN_NAMES = ("filtered_trigger_words",)
|
||||
FUNCTION = "process_trigger_words"
|
||||
|
||||
def process_trigger_words(self, id, group_mode, **kwargs):
|
||||
print("process_trigger_words kwargs: ", kwargs)
|
||||
trigger_words = kwargs.get("trigger_words", "")
|
||||
# Send trigger words to frontend
|
||||
PromptServer.instance.send_sync("trigger_word_update", {
|
||||
"id": id,
|
||||
"message": trigger_words
|
||||
})
|
||||
def _get_toggle_data(self, kwargs, key='toggle_trigger_words'):
|
||||
"""Helper to extract data from either old or new kwargs format"""
|
||||
if key not in kwargs:
|
||||
return None
|
||||
|
||||
data = kwargs[key]
|
||||
# Handle new format: {'key': {'__value__': ...}}
|
||||
if isinstance(data, dict) and '__value__' in data:
|
||||
return data['__value__']
|
||||
# Handle old format: {'key': ...}
|
||||
else:
|
||||
return data
|
||||
|
||||
def process_trigger_words(self, id, group_mode, default_active, **kwargs):
|
||||
# Handle both old and new formats for trigger_words
|
||||
trigger_words_data = self._get_toggle_data(kwargs, 'orinalMessage')
|
||||
trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else ""
|
||||
|
||||
filtered_triggers = trigger_words
|
||||
|
||||
if 'toggle_trigger_words' in kwargs:
|
||||
# Get toggle data with support for both formats
|
||||
trigger_data = self._get_toggle_data(kwargs, 'toggle_trigger_words')
|
||||
if trigger_data:
|
||||
try:
|
||||
# Get trigger word toggle data
|
||||
trigger_data = kwargs['toggle_trigger_words']
|
||||
|
||||
# Convert to list if it's a JSON string
|
||||
if isinstance(trigger_data, str):
|
||||
trigger_data = json.loads(trigger_data)
|
||||
@@ -73,6 +92,6 @@ class TriggerWordToggle:
|
||||
filtered_triggers = ""
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing trigger words: {e}")
|
||||
logger.error(f"Error processing trigger words: {e}")
|
||||
|
||||
return (filtered_triggers,)
|
||||
@@ -30,4 +30,101 @@ class FlexibleOptionalInputType(dict):
|
||||
return True
|
||||
|
||||
|
||||
any_type = AnyType("*")
|
||||
any_type = AnyType("*")
|
||||
|
||||
# Common methods extracted from lora_loader.py and lora_stacker.py
|
||||
import os
|
||||
import logging
|
||||
import copy
|
||||
import folder_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def extract_lora_name(lora_path):
|
||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||
# Get the basename without extension
|
||||
basename = os.path.basename(lora_path)
|
||||
return os.path.splitext(basename)[0]
|
||||
|
||||
def get_loras_list(kwargs):
|
||||
"""Helper to extract loras list from either old or new kwargs format"""
|
||||
if 'loras' not in kwargs:
|
||||
return []
|
||||
|
||||
loras_data = kwargs['loras']
|
||||
# Handle new format: {'loras': {'__value__': [...]}}
|
||||
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
||||
return loras_data['__value__']
|
||||
# Handle old format: {'loras': [...]}
|
||||
elif isinstance(loras_data, list):
|
||||
return loras_data
|
||||
# Unexpected format
|
||||
else:
|
||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||
return []
|
||||
|
||||
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
|
||||
"""Simplified version of load_state_dict_in_safetensors that just loads from a local path"""
|
||||
import safetensors.torch
|
||||
|
||||
state_dict = {}
|
||||
with safetensors.torch.safe_open(path, framework="pt", device=device) as f:
|
||||
for k in f.keys():
|
||||
if filter_prefix and not k.startswith(filter_prefix):
|
||||
continue
|
||||
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
|
||||
return state_dict
|
||||
|
||||
def to_diffusers(input_lora):
|
||||
"""Simplified version of to_diffusers for Flux LoRA conversion"""
|
||||
import torch
|
||||
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
|
||||
from diffusers.loaders import FluxLoraLoaderMixin
|
||||
|
||||
if isinstance(input_lora, str):
|
||||
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
|
||||
else:
|
||||
tensors = {k: v for k, v in input_lora.items()}
|
||||
|
||||
# Convert FP8 tensors to BF16
|
||||
for k, v in tensors.items():
|
||||
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
|
||||
tensors[k] = v.to(torch.bfloat16)
|
||||
|
||||
new_tensors = FluxLoraLoaderMixin.lora_state_dict(tensors)
|
||||
new_tensors = convert_unet_state_dict_to_peft(new_tensors)
|
||||
|
||||
return new_tensors
|
||||
|
||||
def nunchaku_load_lora(model, lora_name, lora_strength):
|
||||
"""Load a Flux LoRA for Nunchaku model"""
|
||||
model_wrapper = model.model.diffusion_model
|
||||
transformer = model_wrapper.model
|
||||
|
||||
# Save the transformer temporarily
|
||||
model_wrapper.model = None
|
||||
ret_model = copy.deepcopy(model) # copy everything except the model
|
||||
ret_model_wrapper = ret_model.model.diffusion_model
|
||||
|
||||
# Restore the model and set it for the copy
|
||||
model_wrapper.model = transformer
|
||||
ret_model_wrapper.model = transformer
|
||||
|
||||
# Get full path to the LoRA file
|
||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||
ret_model_wrapper.loras.append((lora_path, lora_strength))
|
||||
|
||||
# Convert the LoRA to diffusers format
|
||||
sd = to_diffusers(lora_path)
|
||||
|
||||
# Handle embedding adjustment if needed
|
||||
if "transformer.x_embedder.lora_A.weight" in sd:
|
||||
new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1]
|
||||
assert new_in_channels % 4 == 0
|
||||
new_in_channels = new_in_channels // 4
|
||||
|
||||
old_in_channels = ret_model.model.model_config.unet_config["in_channels"]
|
||||
if old_in_channels < new_in_channels:
|
||||
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
|
||||
|
||||
return ret_model
|
||||
98
py/nodes/wanvideo_lora_select.py
Normal file
98
py/nodes/wanvideo_lora_select.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from comfy.comfy_types import IO # type: ignore
|
||||
import folder_paths # type: ignore
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class WanVideoLoraSelect:
|
||||
NAME = "WanVideo Lora Select (LoraManager)"
|
||||
CATEGORY = "Lora Manager/stackers"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load LORA models with less VRAM usage, slower loading. This affects ALL LoRAs, not just the current ones. No effect if merge_loras is False"}),
|
||||
"merge_loras": ("BOOLEAN", {"default": True, "tooltip": "Merge LoRAs into the model, otherwise they are loaded on the fly. Always disabled for GGUF and scaled fp8 models. This affects ALL LoRAs, not just the current one"}),
|
||||
"text": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"pysssss.autocomplete": False,
|
||||
"dynamicPrompts": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
"placeholder": "LoRA syntax input: <lora:name:strength>"
|
||||
}),
|
||||
},
|
||||
"optional": FlexibleOptionalInputType(any_type),
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("WANVIDLORA", IO.STRING, IO.STRING)
|
||||
RETURN_NAMES = ("lora", "trigger_words", "active_loras")
|
||||
FUNCTION = "process_loras"
|
||||
|
||||
def process_loras(self, text, low_mem_load=False, merge_loras=True, **kwargs):
|
||||
loras_list = []
|
||||
all_trigger_words = []
|
||||
active_loras = []
|
||||
|
||||
# Process existing prev_lora if available
|
||||
prev_lora = kwargs.get('prev_lora', None)
|
||||
if prev_lora is not None:
|
||||
loras_list.extend(prev_lora)
|
||||
|
||||
if not merge_loras:
|
||||
low_mem_load = False # Unmerged LoRAs don't need low_mem_load
|
||||
|
||||
# Get blocks if available
|
||||
blocks = kwargs.get('blocks', {})
|
||||
selected_blocks = blocks.get("selected_blocks", {})
|
||||
layer_filter = blocks.get("layer_filter", "")
|
||||
|
||||
# Process loras from kwargs with support for both old and new formats
|
||||
loras_from_widget = get_loras_list(kwargs)
|
||||
for lora in loras_from_widget:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
|
||||
lora_name = lora['name']
|
||||
model_strength = float(lora['strength'])
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Create lora item for WanVideo format
|
||||
lora_item = {
|
||||
"path": folder_paths.get_full_path("loras", lora_path),
|
||||
"strength": model_strength,
|
||||
"name": lora_path.split(".")[0],
|
||||
"blocks": selected_blocks,
|
||||
"layer_filter": layer_filter,
|
||||
"low_mem_load": low_mem_load,
|
||||
"merge_loras": merge_loras,
|
||||
}
|
||||
|
||||
# Add to list and collect active loras
|
||||
loras_list.append(lora_item)
|
||||
active_loras.append((lora_name, model_strength, clip_strength))
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Format trigger_words for output
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Format active_loras for output
|
||||
formatted_loras = []
|
||||
for name, model_strength, clip_strength in active_loras:
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
# Different model and clip strengths
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")
|
||||
|
||||
active_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (loras_list, trigger_words_text, active_loras_text)
|
||||
127
py/nodes/wanvideo_lora_select_from_text.py
Normal file
127
py/nodes/wanvideo_lora_select_from_text.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from comfy.comfy_types import IO
|
||||
import folder_paths
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import any_type
|
||||
import logging
|
||||
|
||||
# 初始化日志记录器
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 定义新节点的类
|
||||
class WanVideoLoraSelectFromText:
|
||||
# 节点在UI中显示的名称
|
||||
NAME = "WanVideo Lora Select From Text (LoraManager)"
|
||||
# 节点所属的分类
|
||||
CATEGORY = "Lora Manager/stackers"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load LORA models with less VRAM usage, slower loading. This affects ALL LoRAs, not just the current ones. No effect if merge_loras is False"}),
|
||||
"merge_lora": ("BOOLEAN", {"default": True, "tooltip": "Merge LoRAs into the model, otherwise they are loaded on the fly. Always disabled for GGUF and scaled fp8 models. This affects ALL LoRAs, not just the current one"}),
|
||||
"lora_syntax": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"defaultInput": True,
|
||||
"forceInput": True,
|
||||
"tooltip": "Connect a TEXT output for LoRA syntax: <lora:name:strength>"
|
||||
}),
|
||||
},
|
||||
|
||||
"optional": {
|
||||
"prev_lora": ("WANVIDLORA",),
|
||||
"blocks": ("BLOCKS",)
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("WANVIDLORA", IO.STRING, IO.STRING)
|
||||
RETURN_NAMES = ("lora", "trigger_words", "active_loras")
|
||||
|
||||
FUNCTION = "process_loras_from_syntax"
|
||||
|
||||
def process_loras_from_syntax(self, lora_syntax, low_mem_load=False, merge_lora=True, **kwargs):
|
||||
text_to_process = lora_syntax
|
||||
|
||||
blocks = kwargs.get('blocks', {})
|
||||
selected_blocks = blocks.get("selected_blocks", {})
|
||||
layer_filter = blocks.get("layer_filter", "")
|
||||
|
||||
loras_list = []
|
||||
all_trigger_words = []
|
||||
active_loras = []
|
||||
|
||||
prev_lora = kwargs.get('prev_lora', None)
|
||||
if prev_lora is not None:
|
||||
loras_list.extend(prev_lora)
|
||||
|
||||
if not merge_lora:
|
||||
low_mem_load = False
|
||||
|
||||
parts = text_to_process.split('<lora:')
|
||||
for part in parts[1:]:
|
||||
end_index = part.find('>')
|
||||
if end_index == -1:
|
||||
continue
|
||||
|
||||
content = part[:end_index]
|
||||
lora_parts = content.split(':')
|
||||
|
||||
lora_name_raw = ""
|
||||
model_strength = 1.0
|
||||
clip_strength = 1.0
|
||||
|
||||
if len(lora_parts) == 2:
|
||||
lora_name_raw = lora_parts[0].strip()
|
||||
try:
|
||||
model_strength = float(lora_parts[1])
|
||||
clip_strength = model_strength
|
||||
except (ValueError, IndexError):
|
||||
logger.warning(f"Invalid strength for LoRA '{lora_name_raw}'. Skipping.")
|
||||
continue
|
||||
elif len(lora_parts) >= 3:
|
||||
lora_name_raw = lora_parts[0].strip()
|
||||
try:
|
||||
model_strength = float(lora_parts[1])
|
||||
clip_strength = float(lora_parts[2])
|
||||
except (ValueError, IndexError):
|
||||
logger.warning(f"Invalid strengths for LoRA '{lora_name_raw}'. Skipping.")
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
lora_path, trigger_words = get_lora_info(lora_name_raw)
|
||||
|
||||
lora_item = {
|
||||
"path": folder_paths.get_full_path("loras", lora_path),
|
||||
"strength": model_strength,
|
||||
"name": lora_path.split(".")[0],
|
||||
"blocks": selected_blocks,
|
||||
"layer_filter": layer_filter,
|
||||
"low_mem_load": low_mem_load,
|
||||
"merge_loras": merge_lora,
|
||||
}
|
||||
|
||||
loras_list.append(lora_item)
|
||||
active_loras.append((lora_name_raw, model_strength, clip_strength))
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
formatted_loras = []
|
||||
for name, model_strength, clip_strength in active_loras:
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
|
||||
else:
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")
|
||||
|
||||
active_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (loras_list, trigger_words_text, active_loras_text)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"WanVideoLoraSelectFromText": WanVideoLoraSelectFromText
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"WanVideoLoraSelectFromText": "WanVideo Lora Select From Text (LoraManager)"
|
||||
}
|
||||
24
py/recipes/__init__.py
Normal file
24
py/recipes/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Recipe metadata parser package for ComfyUI-Lora-Manager."""
|
||||
|
||||
from .base import RecipeMetadataParser
|
||||
from .factory import RecipeParserFactory
|
||||
from .constants import GEN_PARAM_KEYS, VALID_LORA_TYPES
|
||||
from .parsers import (
|
||||
RecipeFormatParser,
|
||||
ComfyMetadataParser,
|
||||
MetaFormatParser,
|
||||
AutomaticMetadataParser,
|
||||
CivitaiApiMetadataParser
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'RecipeMetadataParser',
|
||||
'RecipeParserFactory',
|
||||
'GEN_PARAM_KEYS',
|
||||
'VALID_LORA_TYPES',
|
||||
'RecipeFormatParser',
|
||||
'ComfyMetadataParser',
|
||||
'MetaFormatParser',
|
||||
'AutomaticMetadataParser',
|
||||
'CivitaiApiMetadataParser'
|
||||
]
|
||||
184
py/recipes/base.py
Normal file
184
py/recipes/base.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Base classes for recipe parsers."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
from ..config import config
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RecipeMetadataParser(ABC):
|
||||
"""Interface for parsing recipe metadata from image user comments"""
|
||||
|
||||
METADATA_MARKER = None
|
||||
|
||||
@abstractmethod
|
||||
def is_metadata_matching(self, user_comment: str) -> bool:
|
||||
"""Check if the user comment matches the metadata format"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse metadata from user comment and return structured recipe data
|
||||
|
||||
Args:
|
||||
user_comment: The EXIF UserComment string from the image
|
||||
recipe_scanner: Optional recipe scanner instance for local LoRA lookup
|
||||
civitai_client: Optional Civitai client for fetching model information
|
||||
|
||||
Returns:
|
||||
Dict containing parsed recipe data with standardized format
|
||||
"""
|
||||
pass
|
||||
|
||||
async def populate_lora_from_civitai(self, lora_entry: Dict[str, Any], civitai_info_tuple: Tuple[Dict[str, Any], Optional[str]],
|
||||
recipe_scanner=None, base_model_counts=None, hash_value=None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Populate a lora entry with information from Civitai API response
|
||||
|
||||
Args:
|
||||
lora_entry: The lora entry to populate
|
||||
civitai_info_tuple: The response tuple from Civitai API (data, error_msg)
|
||||
recipe_scanner: Optional recipe scanner for local file lookup
|
||||
base_model_counts: Optional dict to track base model counts
|
||||
hash_value: Optional hash value to use if not available in civitai_info
|
||||
|
||||
Returns:
|
||||
The populated lora_entry dict if type is valid, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Unpack the tuple to get the actual data
|
||||
civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
||||
|
||||
if not civitai_info or error_msg == "Model not found":
|
||||
# Model not found or deleted
|
||||
lora_entry['isDeleted'] = True
|
||||
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'
|
||||
return lora_entry
|
||||
|
||||
# Get model type and validate
|
||||
model_type = civitai_info.get('model', {}).get('type', '').lower()
|
||||
lora_entry['type'] = model_type
|
||||
if model_type not in VALID_LORA_TYPES:
|
||||
logger.debug(f"Skipping non-LoRA model type: {model_type}")
|
||||
return None
|
||||
|
||||
# Check if this is an early access lora
|
||||
if civitai_info.get('earlyAccessEndsAt'):
|
||||
# Convert earlyAccessEndsAt to a human-readable date
|
||||
early_access_date = civitai_info.get('earlyAccessEndsAt', '')
|
||||
lora_entry['isEarlyAccess'] = True
|
||||
lora_entry['earlyAccessEndsAt'] = early_access_date
|
||||
|
||||
# Update model name if available
|
||||
if 'model' in civitai_info and 'name' in civitai_info['model']:
|
||||
lora_entry['name'] = civitai_info['model']['name']
|
||||
|
||||
lora_entry['id'] = civitai_info.get('id')
|
||||
lora_entry['modelId'] = civitai_info.get('modelId')
|
||||
|
||||
# Update version if available
|
||||
if 'name' in civitai_info:
|
||||
lora_entry['version'] = civitai_info.get('name', '')
|
||||
|
||||
# Get thumbnail URL from first image
|
||||
if 'images' in civitai_info and civitai_info['images']:
|
||||
lora_entry['thumbnailUrl'] = civitai_info['images'][0].get('url', '')
|
||||
|
||||
# Get base model
|
||||
current_base_model = civitai_info.get('baseModel', '')
|
||||
lora_entry['baseModel'] = current_base_model
|
||||
|
||||
# Update base model counts if tracking them
|
||||
if base_model_counts is not None and current_base_model:
|
||||
base_model_counts[current_base_model] = base_model_counts.get(current_base_model, 0) + 1
|
||||
|
||||
# Get download URL
|
||||
lora_entry['downloadUrl'] = civitai_info.get('downloadUrl', '')
|
||||
|
||||
# Process file information if available
|
||||
if 'files' in civitai_info:
|
||||
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||
model_file = next((file for file in civitai_info.get('files', [])
|
||||
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||
|
||||
if model_file:
|
||||
# Get size
|
||||
lora_entry['size'] = model_file.get('sizeKB', 0) * 1024
|
||||
|
||||
# Get SHA256 hash
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256', hash_value)
|
||||
if sha256:
|
||||
lora_entry['hash'] = sha256.lower()
|
||||
|
||||
# Check if exists locally
|
||||
if recipe_scanner and lora_entry['hash']:
|
||||
lora_scanner = recipe_scanner._lora_scanner
|
||||
exists_locally = lora_scanner.has_hash(lora_entry['hash'])
|
||||
if exists_locally:
|
||||
try:
|
||||
local_path = lora_scanner.get_path_by_hash(lora_entry['hash'])
|
||||
lora_entry['existsLocally'] = True
|
||||
lora_entry['localPath'] = local_path
|
||||
lora_entry['file_name'] = os.path.splitext(os.path.basename(local_path))[0]
|
||||
|
||||
# Get thumbnail from local preview if available
|
||||
lora_cache = await lora_scanner.get_cached_data()
|
||||
lora_item = next((item for item in lora_cache.raw_data
|
||||
if item['sha256'].lower() == lora_entry['hash'].lower()), None)
|
||||
if lora_item and 'preview_url' in lora_item:
|
||||
lora_entry['thumbnailUrl'] = config.get_preview_static_url(lora_item['preview_url'])
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting local lora path: {e}")
|
||||
else:
|
||||
# For missing LoRAs, get file_name from model_file.name
|
||||
file_name = model_file.get('name', '')
|
||||
lora_entry['file_name'] = os.path.splitext(file_name)[0] if file_name else ''
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error populating lora from Civitai info: {e}")
|
||||
|
||||
return lora_entry
|
||||
|
||||
async def populate_checkpoint_from_civitai(self, checkpoint: Dict[str, Any], civitai_info: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Populate checkpoint information from Civitai API response
|
||||
|
||||
Args:
|
||||
checkpoint: The checkpoint entry to populate
|
||||
civitai_info: The response from Civitai API
|
||||
|
||||
Returns:
|
||||
The populated checkpoint dict
|
||||
"""
|
||||
try:
|
||||
if civitai_info and civitai_info.get("error") != "Model not found":
|
||||
# Update model name if available
|
||||
if 'model' in civitai_info and 'name' in civitai_info['model']:
|
||||
checkpoint['name'] = civitai_info['model']['name']
|
||||
|
||||
# Update version if available
|
||||
if 'name' in civitai_info:
|
||||
checkpoint['version'] = civitai_info.get('name', '')
|
||||
|
||||
# Get thumbnail URL from first image
|
||||
if 'images' in civitai_info and civitai_info['images']:
|
||||
checkpoint['thumbnailUrl'] = civitai_info['images'][0].get('url', '')
|
||||
|
||||
# Get base model
|
||||
checkpoint['baseModel'] = civitai_info.get('baseModel', '')
|
||||
|
||||
# Get download URL
|
||||
checkpoint['downloadUrl'] = civitai_info.get('downloadUrl', '')
|
||||
else:
|
||||
# Model not found or deleted
|
||||
checkpoint['isDeleted'] = True
|
||||
except Exception as e:
|
||||
logger.error(f"Error populating checkpoint from Civitai info: {e}")
|
||||
|
||||
return checkpoint
|
||||
16
py/recipes/constants.py
Normal file
16
py/recipes/constants.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Constants used across recipe parsers."""
|
||||
|
||||
# Import VALID_LORA_TYPES from utils.constants
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
|
||||
# Constants for generation parameters
|
||||
GEN_PARAM_KEYS = [
|
||||
'prompt',
|
||||
'negative_prompt',
|
||||
'steps',
|
||||
'sampler',
|
||||
'cfg_scale',
|
||||
'seed',
|
||||
'size',
|
||||
'clip_skip',
|
||||
]
|
||||
64
py/recipes/factory.py
Normal file
64
py/recipes/factory.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Factory for creating recipe metadata parsers."""
|
||||
|
||||
import logging
|
||||
from .parsers import (
|
||||
RecipeFormatParser,
|
||||
ComfyMetadataParser,
|
||||
MetaFormatParser,
|
||||
AutomaticMetadataParser,
|
||||
CivitaiApiMetadataParser
|
||||
)
|
||||
from .base import RecipeMetadataParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RecipeParserFactory:
|
||||
"""Factory for creating recipe metadata parsers"""
|
||||
|
||||
@staticmethod
|
||||
def create_parser(metadata) -> RecipeMetadataParser:
|
||||
"""
|
||||
Create appropriate parser based on the metadata content
|
||||
|
||||
Args:
|
||||
metadata: The metadata from the image (dict or str)
|
||||
|
||||
Returns:
|
||||
Appropriate RecipeMetadataParser implementation
|
||||
"""
|
||||
# First, try CivitaiApiMetadataParser for dict input
|
||||
if isinstance(metadata, dict):
|
||||
try:
|
||||
if CivitaiApiMetadataParser().is_metadata_matching(metadata):
|
||||
return CivitaiApiMetadataParser()
|
||||
except Exception as e:
|
||||
logger.debug(f"CivitaiApiMetadataParser check failed: {e}")
|
||||
pass
|
||||
|
||||
# Convert dict to string for other parsers that expect string input
|
||||
try:
|
||||
import json
|
||||
metadata_str = json.dumps(metadata)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to convert dict to JSON string: {e}")
|
||||
return None
|
||||
else:
|
||||
metadata_str = metadata
|
||||
|
||||
# Try ComfyMetadataParser which requires valid JSON
|
||||
try:
|
||||
if ComfyMetadataParser().is_metadata_matching(metadata_str):
|
||||
return ComfyMetadataParser()
|
||||
except Exception:
|
||||
# If JSON parsing fails, move on to other parsers
|
||||
pass
|
||||
|
||||
# Check other parsers that expect string input
|
||||
if RecipeFormatParser().is_metadata_matching(metadata_str):
|
||||
return RecipeFormatParser()
|
||||
elif AutomaticMetadataParser().is_metadata_matching(metadata_str):
|
||||
return AutomaticMetadataParser()
|
||||
elif MetaFormatParser().is_metadata_matching(metadata_str):
|
||||
return MetaFormatParser()
|
||||
else:
|
||||
return None
|
||||
15
py/recipes/parsers/__init__.py
Normal file
15
py/recipes/parsers/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Recipe parsers package."""
|
||||
|
||||
from .recipe_format import RecipeFormatParser
|
||||
from .comfy import ComfyMetadataParser
|
||||
from .meta_format import MetaFormatParser
|
||||
from .automatic import AutomaticMetadataParser
|
||||
from .civitai_image import CivitaiApiMetadataParser
|
||||
|
||||
__all__ = [
|
||||
'RecipeFormatParser',
|
||||
'ComfyMetadataParser',
|
||||
'MetaFormatParser',
|
||||
'AutomaticMetadataParser',
|
||||
'CivitaiApiMetadataParser',
|
||||
]
|
||||
325
py/recipes/parsers/automatic.py
Normal file
325
py/recipes/parsers/automatic.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""Parser for Automatic1111 metadata format."""
|
||||
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from ..base import RecipeMetadataParser
|
||||
from ..constants import GEN_PARAM_KEYS
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AutomaticMetadataParser(RecipeMetadataParser):
|
||||
"""Parser for Automatic1111 metadata format"""
|
||||
|
||||
METADATA_MARKER = r"Steps: \d+"
|
||||
|
||||
# Regular expressions for extracting specific metadata
|
||||
HASHES_REGEX = r', Hashes:\s*({[^}]+})'
|
||||
LORA_HASHES_REGEX = r', Lora hashes:\s*"([^"]+)"'
|
||||
CIVITAI_RESOURCES_REGEX = r', Civitai resources:\s*(\[\{.*?\}\])'
|
||||
CIVITAI_METADATA_REGEX = r', Civitai metadata:\s*(\{.*?\})'
|
||||
EXTRANETS_REGEX = r'<(lora|hypernet):([^:]+):(-?[0-9.]+)>'
|
||||
MODEL_HASH_PATTERN = r'Model hash: ([a-zA-Z0-9]+)'
|
||||
VAE_HASH_PATTERN = r'VAE hash: ([a-zA-Z0-9]+)'
|
||||
|
||||
def is_metadata_matching(self, user_comment: str) -> bool:
|
||||
"""Check if the user comment matches the Automatic1111 format"""
|
||||
return re.search(self.METADATA_MARKER, user_comment) is not None
|
||||
|
||||
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""Parse metadata from Automatic1111 format"""
|
||||
try:
|
||||
# Get metadata provider instead of using civitai_client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
# Split on Negative prompt if it exists
|
||||
if "Negative prompt:" in user_comment:
|
||||
parts = user_comment.split('Negative prompt:', 1)
|
||||
prompt = parts[0].strip()
|
||||
negative_and_params = parts[1] if len(parts) > 1 else ""
|
||||
else:
|
||||
# No negative prompt section
|
||||
param_start = re.search(self.METADATA_MARKER, user_comment)
|
||||
if param_start:
|
||||
prompt = user_comment[:param_start.start()].strip()
|
||||
negative_and_params = user_comment[param_start.start():]
|
||||
else:
|
||||
prompt = user_comment.strip()
|
||||
negative_and_params = ""
|
||||
|
||||
# Initialize metadata
|
||||
metadata = {
|
||||
"prompt": prompt,
|
||||
"loras": []
|
||||
}
|
||||
|
||||
# Extract negative prompt and parameters
|
||||
if negative_and_params:
|
||||
# If we split on "Negative prompt:", check for params section
|
||||
if "Negative prompt:" in user_comment:
|
||||
param_start = re.search(r'Steps: ', negative_and_params)
|
||||
if param_start:
|
||||
neg_prompt = negative_and_params[:param_start.start()].strip()
|
||||
metadata["negative_prompt"] = neg_prompt
|
||||
params_section = negative_and_params[param_start.start():]
|
||||
else:
|
||||
metadata["negative_prompt"] = negative_and_params.strip()
|
||||
params_section = ""
|
||||
else:
|
||||
# No negative prompt, entire section is params
|
||||
params_section = negative_and_params
|
||||
|
||||
# Extract generation parameters
|
||||
if params_section:
|
||||
# Extract Civitai resources
|
||||
civitai_resources_match = re.search(self.CIVITAI_RESOURCES_REGEX, params_section)
|
||||
if civitai_resources_match:
|
||||
try:
|
||||
civitai_resources = json.loads(civitai_resources_match.group(1))
|
||||
metadata["civitai_resources"] = civitai_resources
|
||||
params_section = params_section.replace(civitai_resources_match.group(0), '')
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Error parsing Civitai resources JSON")
|
||||
|
||||
# Extract Hashes
|
||||
hashes_match = re.search(self.HASHES_REGEX, params_section)
|
||||
if hashes_match:
|
||||
try:
|
||||
hashes = json.loads(hashes_match.group(1))
|
||||
# Process hash keys
|
||||
processed_hashes = {}
|
||||
for key, value in hashes.items():
|
||||
# Convert Model: or LORA: prefix to lowercase if present
|
||||
if ':' in key:
|
||||
prefix, name = key.split(':', 1)
|
||||
prefix = prefix.lower()
|
||||
else:
|
||||
prefix = ''
|
||||
name = key
|
||||
|
||||
# Clean up the name part
|
||||
if '/' in name:
|
||||
name = name.split('/')[-1] # Get last part after /
|
||||
if '.safetensors' in name:
|
||||
name = name.split('.safetensors')[0] # Remove .safetensors
|
||||
|
||||
# Reconstruct the key
|
||||
new_key = f"{prefix}:{name}" if prefix else name
|
||||
processed_hashes[new_key] = value
|
||||
|
||||
metadata["hashes"] = processed_hashes
|
||||
# Remove hashes from params section to not interfere with other parsing
|
||||
params_section = params_section.replace(hashes_match.group(0), '')
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Error parsing hashes JSON")
|
||||
|
||||
# Extract Lora hashes in alternative format
|
||||
lora_hashes_match = re.search(self.LORA_HASHES_REGEX, params_section)
|
||||
if not hashes_match and lora_hashes_match:
|
||||
try:
|
||||
lora_hashes_str = lora_hashes_match.group(1)
|
||||
lora_hash_entries = lora_hashes_str.split(', ')
|
||||
|
||||
# Initialize hashes dict if it doesn't exist
|
||||
if "hashes" not in metadata:
|
||||
metadata["hashes"] = {}
|
||||
|
||||
# Parse each lora hash entry (format: "name: hash")
|
||||
for entry in lora_hash_entries:
|
||||
if ': ' in entry:
|
||||
lora_name, lora_hash = entry.split(': ', 1)
|
||||
# Add as lora type in the same format as regular hashes
|
||||
metadata["hashes"][f"lora:{lora_name}"] = lora_hash.strip()
|
||||
|
||||
# Remove lora hashes from params section
|
||||
params_section = params_section.replace(lora_hashes_match.group(0), '')
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing Lora hashes: {e}")
|
||||
|
||||
# Extract basic parameters
|
||||
param_pattern = r'([A-Za-z\s]+): ([^,]+)'
|
||||
params = re.findall(param_pattern, params_section)
|
||||
gen_params = {}
|
||||
|
||||
for key, value in params:
|
||||
clean_key = key.strip().lower().replace(' ', '_')
|
||||
|
||||
# Skip if not in recognized gen param keys
|
||||
if clean_key not in GEN_PARAM_KEYS:
|
||||
continue
|
||||
|
||||
# Convert numeric values
|
||||
if clean_key in ['steps', 'seed']:
|
||||
try:
|
||||
gen_params[clean_key] = int(value.strip())
|
||||
except ValueError:
|
||||
gen_params[clean_key] = value.strip()
|
||||
elif clean_key in ['cfg_scale']:
|
||||
try:
|
||||
gen_params[clean_key] = float(value.strip())
|
||||
except ValueError:
|
||||
gen_params[clean_key] = value.strip()
|
||||
else:
|
||||
gen_params[clean_key] = value.strip()
|
||||
|
||||
# Extract size if available and add to gen_params if a recognized key
|
||||
size_match = re.search(r'Size: (\d+)x(\d+)', params_section)
|
||||
if size_match and 'size' in GEN_PARAM_KEYS:
|
||||
width, height = size_match.groups()
|
||||
gen_params['size'] = f"{width}x{height}"
|
||||
|
||||
# Add prompt and negative_prompt to gen_params if they're in GEN_PARAM_KEYS
|
||||
if 'prompt' in GEN_PARAM_KEYS and 'prompt' in metadata:
|
||||
gen_params['prompt'] = metadata['prompt']
|
||||
if 'negative_prompt' in GEN_PARAM_KEYS and 'negative_prompt' in metadata:
|
||||
gen_params['negative_prompt'] = metadata['negative_prompt']
|
||||
|
||||
metadata["gen_params"] = gen_params
|
||||
|
||||
# Extract LoRA information
|
||||
loras = []
|
||||
base_model_counts = {}
|
||||
|
||||
# First use Civitai resources if available (more reliable source)
|
||||
if metadata.get("civitai_resources"):
|
||||
for resource in metadata.get("civitai_resources", []):
|
||||
# --- Added: Parse 'air' field if present ---
|
||||
air = resource.get("air")
|
||||
if air:
|
||||
# Format: urn:air:sdxl:lora:civitai:1221007@1375651
|
||||
# Or: urn:air:sdxl:checkpoint:civitai:623891@2019115
|
||||
air_pattern = r"urn:air:[^:]+:(?P<type>[^:]+):civitai:(?P<modelId>\d+)@(?P<modelVersionId>\d+)"
|
||||
air_match = re.match(air_pattern, air)
|
||||
if air_match:
|
||||
air_type = air_match.group("type")
|
||||
air_modelId = int(air_match.group("modelId"))
|
||||
air_modelVersionId = int(air_match.group("modelVersionId"))
|
||||
# checkpoint/lycoris/lora/hypernet
|
||||
resource["type"] = air_type
|
||||
resource["modelId"] = air_modelId
|
||||
resource["modelVersionId"] = air_modelVersionId
|
||||
# --- End added ---
|
||||
|
||||
if resource.get("type") in ["lora", "lycoris", "hypernet"] and resource.get("modelVersionId"):
|
||||
# Initialize lora entry
|
||||
lora_entry = {
|
||||
'id': resource.get("modelVersionId", 0),
|
||||
'modelId': resource.get("modelId", 0),
|
||||
'name': resource.get("modelName", "Unknown LoRA"),
|
||||
'version': resource.get("modelVersionName", resource.get("versionName", "")),
|
||||
'type': resource.get("type", "lora"),
|
||||
'weight': round(float(resource.get("weight", 1.0)), 2),
|
||||
'existsLocally': False,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Get additional info from Civitai
|
||||
if metadata_provider:
|
||||
try:
|
||||
civitai_info = await metadata_provider.get_model_version_info(resource.get("modelVersionId"))
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA {lora_entry['name']}: {e}")
|
||||
|
||||
loras.append(lora_entry)
|
||||
|
||||
# If no LoRAs from Civitai resources or to supplement, extract from metadata["hashes"]
|
||||
if not loras or len(loras) == 0:
|
||||
# Extract lora weights from extranet tags in prompt (for later use)
|
||||
lora_weights = {}
|
||||
lora_matches = re.findall(self.EXTRANETS_REGEX, prompt)
|
||||
for lora_type, lora_name, lora_weight in lora_matches:
|
||||
key = f"{lora_type}:{lora_name}"
|
||||
lora_weights[key] = round(float(lora_weight), 2)
|
||||
|
||||
# Use hashes from metadata as the primary source
|
||||
if metadata.get("hashes"):
|
||||
for hash_key, lora_hash in metadata.get("hashes", {}).items():
|
||||
# Only process lora or hypernet types
|
||||
if not hash_key.startswith(("lora:", "hypernet:")):
|
||||
continue
|
||||
|
||||
lora_type, lora_name = hash_key.split(':', 1)
|
||||
|
||||
# Get weight from extranet tags if available, else default to 1.0
|
||||
weight = lora_weights.get(hash_key, 1.0)
|
||||
|
||||
# Initialize lora entry
|
||||
lora_entry = {
|
||||
'name': lora_name,
|
||||
'type': lora_type, # 'lora' or 'hypernet'
|
||||
'weight': weight,
|
||||
'hash': lora_hash,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': lora_name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Try to get info from Civitai
|
||||
if metadata_provider:
|
||||
try:
|
||||
if lora_hash:
|
||||
# If we have hash, use it for lookup
|
||||
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
|
||||
else:
|
||||
civitai_info = None
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts,
|
||||
lora_hash
|
||||
)
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA {lora_name}: {e}")
|
||||
|
||||
loras.append(lora_entry)
|
||||
|
||||
# Try to get base model from resources or make educated guess
|
||||
base_model = None
|
||||
if base_model_counts:
|
||||
# Use the most common base model from the loras
|
||||
base_model = max(base_model_counts.items(), key=lambda x: x[1])[0]
|
||||
|
||||
# Prepare final result structure
|
||||
# Make sure gen_params only contains recognized keys
|
||||
filtered_gen_params = {}
|
||||
for key in GEN_PARAM_KEYS:
|
||||
if key in metadata.get("gen_params", {}):
|
||||
filtered_gen_params[key] = metadata["gen_params"][key]
|
||||
|
||||
result = {
|
||||
'base_model': base_model,
|
||||
'loras': loras,
|
||||
'gen_params': filtered_gen_params,
|
||||
'from_automatic_metadata': True
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing Automatic1111 metadata: {e}", exc_info=True)
|
||||
return {"error": str(e), "loras": []}
|
||||
356
py/recipes/parsers/civitai_image.py
Normal file
356
py/recipes/parsers/civitai_image.py
Normal file
@@ -0,0 +1,356 @@
|
||||
"""Parser for Civitai image metadata format."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Union
|
||||
from ..base import RecipeMetadataParser
|
||||
from ..constants import GEN_PARAM_KEYS
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
"""Parser for Civitai image metadata format"""
|
||||
|
||||
def is_metadata_matching(self, metadata) -> bool:
|
||||
"""Check if the metadata matches the Civitai image metadata format
|
||||
|
||||
Args:
|
||||
metadata: The metadata from the image (dict)
|
||||
|
||||
Returns:
|
||||
bool: True if this parser can handle the metadata
|
||||
"""
|
||||
if not metadata or not isinstance(metadata, dict):
|
||||
return False
|
||||
|
||||
# Check for key markers specific to Civitai image metadata
|
||||
return any([
|
||||
"resources" in metadata,
|
||||
"civitaiResources" in metadata,
|
||||
"additionalResources" in metadata
|
||||
])
|
||||
|
||||
async def parse_metadata(self, metadata, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""Parse metadata from Civitai image format
|
||||
|
||||
Args:
|
||||
metadata: The metadata from the image (dict)
|
||||
recipe_scanner: Optional recipe scanner service
|
||||
civitai_client: Optional Civitai API client (deprecated, use metadata_provider instead)
|
||||
|
||||
Returns:
|
||||
Dict containing parsed recipe data
|
||||
"""
|
||||
try:
|
||||
# Get metadata provider instead of using civitai_client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
# Initialize result structure
|
||||
result = {
|
||||
'base_model': None,
|
||||
'loras': [],
|
||||
'gen_params': {},
|
||||
'from_civitai_image': True
|
||||
}
|
||||
|
||||
# Track already added LoRAs to prevent duplicates
|
||||
added_loras = {} # key: model_version_id or hash, value: index in result["loras"]
|
||||
|
||||
# Extract hash information from hashes field for LoRA matching
|
||||
lora_hashes = {}
|
||||
if "hashes" in metadata and isinstance(metadata["hashes"], dict):
|
||||
for key, hash_value in metadata["hashes"].items():
|
||||
if key.startswith("LORA:"):
|
||||
lora_name = key.replace("LORA:", "")
|
||||
lora_hashes[lora_name] = hash_value
|
||||
|
||||
# Extract prompt and negative prompt
|
||||
if "prompt" in metadata:
|
||||
result["gen_params"]["prompt"] = metadata["prompt"]
|
||||
|
||||
if "negativePrompt" in metadata:
|
||||
result["gen_params"]["negative_prompt"] = metadata["negativePrompt"]
|
||||
|
||||
# Extract other generation parameters
|
||||
param_mapping = {
|
||||
"steps": "steps",
|
||||
"sampler": "sampler",
|
||||
"cfgScale": "cfg_scale",
|
||||
"seed": "seed",
|
||||
"Size": "size",
|
||||
"clipSkip": "clip_skip",
|
||||
}
|
||||
|
||||
for civitai_key, our_key in param_mapping.items():
|
||||
if civitai_key in metadata and our_key in GEN_PARAM_KEYS:
|
||||
result["gen_params"][our_key] = metadata[civitai_key]
|
||||
|
||||
# Extract base model information - directly if available
|
||||
if "baseModel" in metadata:
|
||||
result["base_model"] = metadata["baseModel"]
|
||||
elif "Model hash" in metadata and metadata_provider:
|
||||
model_hash = metadata["Model hash"]
|
||||
model_info, error = await metadata_provider.get_model_by_hash(model_hash)
|
||||
if model_info:
|
||||
result["base_model"] = model_info.get("baseModel", "")
|
||||
elif "Model" in metadata and isinstance(metadata.get("resources"), list):
|
||||
# Try to find base model in resources
|
||||
for resource in metadata.get("resources", []):
|
||||
if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"):
|
||||
# This is likely the checkpoint model
|
||||
if metadata_provider and resource.get("hash"):
|
||||
model_info, error = await metadata_provider.get_model_by_hash(resource.get("hash"))
|
||||
if model_info:
|
||||
result["base_model"] = model_info.get("baseModel", "")
|
||||
|
||||
base_model_counts = {}
|
||||
|
||||
# Process standard resources array
|
||||
if "resources" in metadata and isinstance(metadata["resources"], list):
|
||||
for resource in metadata["resources"]:
|
||||
# Modified to process resources without a type field as potential LoRAs
|
||||
if resource.get("type", "lora") == "lora":
|
||||
lora_hash = resource.get("hash", "")
|
||||
|
||||
# Try to get hash from the hashes field if not present in resource
|
||||
if not lora_hash and resource.get("name"):
|
||||
lora_hash = lora_hashes.get(resource["name"], "")
|
||||
|
||||
# Skip LoRAs without proper identification (hash or modelVersionId)
|
||||
if not lora_hash and not resource.get("modelVersionId"):
|
||||
logger.debug(f"Skipping LoRA resource '{resource.get('name', 'Unknown')}' - no hash or modelVersionId")
|
||||
continue
|
||||
|
||||
# Skip if we've already added this LoRA by hash
|
||||
if lora_hash and lora_hash in added_loras:
|
||||
continue
|
||||
|
||||
lora_entry = {
|
||||
'name': resource.get("name", "Unknown LoRA"),
|
||||
'type': "lora",
|
||||
'weight': float(resource.get("weight", 1.0)),
|
||||
'hash': lora_hash,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': resource.get("name", "Unknown"),
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Try to get info from Civitai if hash is available
|
||||
if lora_entry['hash'] and metadata_provider:
|
||||
try:
|
||||
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts,
|
||||
lora_hash
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
# If we have a version ID from Civitai, track it for deduplication
|
||||
if 'id' in lora_entry and lora_entry['id']:
|
||||
added_loras[str(lora_entry['id'])] = len(result["loras"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
|
||||
|
||||
# Track by hash if we have it
|
||||
if lora_hash:
|
||||
added_loras[lora_hash] = len(result["loras"])
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# Process civitaiResources array
|
||||
if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list):
|
||||
for resource in metadata["civitaiResources"]:
|
||||
# Get unique identifier for deduplication
|
||||
version_id = str(resource.get("modelVersionId", ""))
|
||||
|
||||
# Skip if we've already added this LoRA
|
||||
if version_id and version_id in added_loras:
|
||||
continue
|
||||
|
||||
# Initialize lora entry
|
||||
lora_entry = {
|
||||
'id': resource.get("modelVersionId", 0),
|
||||
'modelId': resource.get("modelId", 0),
|
||||
'name': resource.get("modelName", "Unknown LoRA"),
|
||||
'version': resource.get("modelVersionName", ""),
|
||||
'type': resource.get("type", "lora"),
|
||||
'weight': round(float(resource.get("weight", 1.0)), 2),
|
||||
'existsLocally': False,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Try to get info from Civitai if modelVersionId is available
|
||||
if version_id and metadata_provider:
|
||||
try:
|
||||
# Use get_model_version_info instead of get_model_version
|
||||
civitai_info = await metadata_provider.get_model_version_info(version_id)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model version {version_id}: {e}")
|
||||
|
||||
# Track this LoRA in our deduplication dict
|
||||
if version_id:
|
||||
added_loras[version_id] = len(result["loras"])
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# Process additionalResources array
|
||||
if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list):
|
||||
for resource in metadata["additionalResources"]:
|
||||
# Skip resources that aren't LoRAs or LyCORIS
|
||||
if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource:
|
||||
continue
|
||||
|
||||
lora_type = resource.get("type", "lora")
|
||||
name = resource.get("name", "")
|
||||
|
||||
# Extract ID from URN format if available
|
||||
version_id = None
|
||||
if name and "civitai:" in name:
|
||||
parts = name.split("@")
|
||||
if len(parts) > 1:
|
||||
version_id = parts[1]
|
||||
|
||||
# Skip if we've already added this LoRA
|
||||
if version_id in added_loras:
|
||||
continue
|
||||
|
||||
lora_entry = {
|
||||
'name': name,
|
||||
'type': lora_type,
|
||||
'weight': float(resource.get("strength", 1.0)),
|
||||
'hash': "",
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# If we have a version ID and metadata provider, try to get more info
|
||||
if version_id and metadata_provider:
|
||||
try:
|
||||
# Use get_model_version_info with the version ID
|
||||
civitai_info = await metadata_provider.get_model_version_info(version_id)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
# Track this LoRA for deduplication
|
||||
if version_id:
|
||||
added_loras[version_id] = len(result["loras"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}")
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# Check for LoRA info in the format "Lora_0 Model hash", "Lora_0 Model name", etc.
|
||||
lora_index = 0
|
||||
while f"Lora_{lora_index} Model hash" in metadata and f"Lora_{lora_index} Model name" in metadata:
|
||||
lora_hash = metadata[f"Lora_{lora_index} Model hash"]
|
||||
lora_name = metadata[f"Lora_{lora_index} Model name"]
|
||||
lora_strength_model = float(metadata.get(f"Lora_{lora_index} Strength model", 1.0))
|
||||
|
||||
# Skip if we've already added this LoRA by hash
|
||||
if lora_hash and lora_hash in added_loras:
|
||||
lora_index += 1
|
||||
continue
|
||||
|
||||
lora_entry = {
|
||||
'name': lora_name,
|
||||
'type': "lora",
|
||||
'weight': lora_strength_model,
|
||||
'hash': lora_hash,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': lora_name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Try to get info from Civitai if hash is available
|
||||
if lora_entry['hash'] and metadata_provider:
|
||||
try:
|
||||
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts,
|
||||
lora_hash
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
lora_index += 1
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
# If we have a version ID from Civitai, track it for deduplication
|
||||
if 'id' in lora_entry and lora_entry['id']:
|
||||
added_loras[str(lora_entry['id'])] = len(result["loras"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
|
||||
|
||||
# Track by hash if we have it
|
||||
if lora_hash:
|
||||
added_loras[lora_hash] = len(result["loras"])
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
lora_index += 1
|
||||
|
||||
# If base model wasn't found earlier, use the most common one from LoRAs
|
||||
if not result["base_model"] and base_model_counts:
|
||||
result["base_model"] = max(base_model_counts.items(), key=lambda x: x[1])[0]
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing Civitai image metadata: {e}", exc_info=True)
|
||||
return {"error": str(e), "loras": []}
|
||||
220
py/recipes/parsers/comfy.py
Normal file
220
py/recipes/parsers/comfy.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Parser for ComfyUI metadata format."""
|
||||
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from ..base import RecipeMetadataParser
|
||||
from ..constants import GEN_PARAM_KEYS
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ComfyMetadataParser(RecipeMetadataParser):
|
||||
"""Parser for Civitai ComfyUI metadata JSON format"""
|
||||
|
||||
METADATA_MARKER = r"class_type"
|
||||
|
||||
def is_metadata_matching(self, user_comment: str) -> bool:
|
||||
"""Check if the user comment matches the ComfyUI metadata format"""
|
||||
try:
|
||||
data = json.loads(user_comment)
|
||||
# Check if it contains class_type nodes typical of ComfyUI workflow
|
||||
return isinstance(data, dict) and any(isinstance(v, dict) and 'class_type' in v for v in data.values())
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return False
|
||||
|
||||
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""Parse metadata from Civitai ComfyUI metadata format"""
|
||||
try:
|
||||
# Get metadata provider instead of using civitai_client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
data = json.loads(user_comment)
|
||||
loras = []
|
||||
|
||||
# Find all LoraLoader nodes
|
||||
lora_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and v.get('class_type') == 'LoraLoader'}
|
||||
|
||||
if not lora_nodes:
|
||||
return {"error": "No LoRA information found in this ComfyUI workflow", "loras": []}
|
||||
|
||||
# Process each LoraLoader node
|
||||
for node_id, node in lora_nodes.items():
|
||||
if 'inputs' not in node or 'lora_name' not in node['inputs']:
|
||||
continue
|
||||
|
||||
lora_name = node['inputs'].get('lora_name', '')
|
||||
|
||||
# Parse the URN to extract model ID and version ID
|
||||
# Format: "urn:air:sdxl:lora:civitai:1107767@1253442"
|
||||
lora_id_match = re.search(r'civitai:(\d+)@(\d+)', lora_name)
|
||||
if not lora_id_match:
|
||||
continue
|
||||
|
||||
model_id = lora_id_match.group(1)
|
||||
model_version_id = lora_id_match.group(2)
|
||||
|
||||
# Get strength from node inputs
|
||||
weight = node['inputs'].get('strength_model', 1.0)
|
||||
|
||||
# Initialize lora entry with default values
|
||||
lora_entry = {
|
||||
'id': model_version_id,
|
||||
'modelId': model_id,
|
||||
'name': f"Lora {model_id}", # Default name
|
||||
'version': '',
|
||||
'type': 'lora',
|
||||
'weight': weight,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': '',
|
||||
'hash': '',
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Get additional info from Civitai if metadata provider is available
|
||||
if metadata_provider:
|
||||
try:
|
||||
civitai_info_tuple = await metadata_provider.get_model_version_info(model_version_id)
|
||||
# Populate lora entry with Civitai info
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info_tuple,
|
||||
recipe_scanner
|
||||
)
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA: {e}")
|
||||
|
||||
loras.append(lora_entry)
|
||||
|
||||
# Find checkpoint info
|
||||
checkpoint_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and v.get('class_type') == 'CheckpointLoaderSimple'}
|
||||
checkpoint = None
|
||||
checkpoint_id = None
|
||||
checkpoint_version_id = None
|
||||
|
||||
if checkpoint_nodes:
|
||||
# Get the first checkpoint node
|
||||
checkpoint_node = next(iter(checkpoint_nodes.values()))
|
||||
if 'inputs' in checkpoint_node and 'ckpt_name' in checkpoint_node['inputs']:
|
||||
checkpoint_name = checkpoint_node['inputs']['ckpt_name']
|
||||
# Parse checkpoint URN
|
||||
checkpoint_match = re.search(r'civitai:(\d+)@(\d+)', checkpoint_name)
|
||||
if checkpoint_match:
|
||||
checkpoint_id = checkpoint_match.group(1)
|
||||
checkpoint_version_id = checkpoint_match.group(2)
|
||||
checkpoint = {
|
||||
'id': checkpoint_version_id,
|
||||
'modelId': checkpoint_id,
|
||||
'name': f"Checkpoint {checkpoint_id}",
|
||||
'version': '',
|
||||
'type': 'checkpoint'
|
||||
}
|
||||
|
||||
# Get additional checkpoint info from Civitai
|
||||
if metadata_provider:
|
||||
try:
|
||||
civitai_info_tuple = await metadata_provider.get_model_version_info(checkpoint_version_id)
|
||||
civitai_info, _ = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
||||
# Populate checkpoint with Civitai info
|
||||
checkpoint = await self.populate_checkpoint_from_civitai(checkpoint, civitai_info)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for checkpoint: {e}")
|
||||
|
||||
# Extract generation parameters
|
||||
gen_params = {}
|
||||
|
||||
# First try to get from extraMetadata
|
||||
if 'extraMetadata' in data:
|
||||
try:
|
||||
# extraMetadata is a JSON string that needs to be parsed
|
||||
extra_metadata = json.loads(data['extraMetadata'])
|
||||
|
||||
# Map fields from extraMetadata to our standard format
|
||||
mapping = {
|
||||
'prompt': 'prompt',
|
||||
'negativePrompt': 'negative_prompt',
|
||||
'steps': 'steps',
|
||||
'sampler': 'sampler',
|
||||
'cfgScale': 'cfg_scale',
|
||||
'seed': 'seed'
|
||||
}
|
||||
|
||||
for src_key, dest_key in mapping.items():
|
||||
if src_key in extra_metadata:
|
||||
gen_params[dest_key] = extra_metadata[src_key]
|
||||
|
||||
# If size info is available, format as "width x height"
|
||||
if 'width' in extra_metadata and 'height' in extra_metadata:
|
||||
gen_params['size'] = f"{extra_metadata['width']}x{extra_metadata['height']}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing extraMetadata: {e}")
|
||||
|
||||
# If extraMetadata doesn't have all the info, try to get from nodes
|
||||
if not gen_params or len(gen_params) < 3: # At least we want prompt, negative_prompt, and steps
|
||||
# Find positive prompt node
|
||||
positive_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and
|
||||
v.get('class_type', '').endswith('CLIPTextEncode') and
|
||||
v.get('_meta', {}).get('title') == 'Positive'}
|
||||
|
||||
if positive_nodes:
|
||||
positive_node = next(iter(positive_nodes.values()))
|
||||
if 'inputs' in positive_node and 'text' in positive_node['inputs']:
|
||||
gen_params['prompt'] = positive_node['inputs']['text']
|
||||
|
||||
# Find negative prompt node
|
||||
negative_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and
|
||||
v.get('class_type', '').endswith('CLIPTextEncode') and
|
||||
v.get('_meta', {}).get('title') == 'Negative'}
|
||||
|
||||
if negative_nodes:
|
||||
negative_node = next(iter(negative_nodes.values()))
|
||||
if 'inputs' in negative_node and 'text' in negative_node['inputs']:
|
||||
gen_params['negative_prompt'] = negative_node['inputs']['text']
|
||||
|
||||
# Find KSampler node for other parameters
|
||||
ksampler_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and v.get('class_type') == 'KSampler'}
|
||||
|
||||
if ksampler_nodes:
|
||||
ksampler_node = next(iter(ksampler_nodes.values()))
|
||||
if 'inputs' in ksampler_node:
|
||||
inputs = ksampler_node['inputs']
|
||||
if 'sampler_name' in inputs:
|
||||
gen_params['sampler'] = inputs['sampler_name']
|
||||
if 'steps' in inputs:
|
||||
gen_params['steps'] = inputs['steps']
|
||||
if 'cfg' in inputs:
|
||||
gen_params['cfg_scale'] = inputs['cfg']
|
||||
if 'seed' in inputs:
|
||||
gen_params['seed'] = inputs['seed']
|
||||
|
||||
# Determine base model from loras info
|
||||
base_model = None
|
||||
if loras:
|
||||
# Use the most common base model from loras
|
||||
base_models = [lora['baseModel'] for lora in loras if lora.get('baseModel')]
|
||||
if base_models:
|
||||
from collections import Counter
|
||||
base_model_counts = Counter(base_models)
|
||||
base_model = base_model_counts.most_common(1)[0][0]
|
||||
|
||||
return {
|
||||
'base_model': base_model,
|
||||
'loras': loras,
|
||||
'checkpoint': checkpoint,
|
||||
'gen_params': gen_params,
|
||||
'from_comfy_metadata': True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing ComfyUI metadata: {e}", exc_info=True)
|
||||
return {"error": str(e), "loras": []}
|
||||
178
py/recipes/parsers/meta_format.py
Normal file
178
py/recipes/parsers/meta_format.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Parser for meta format (Lora_N Model hash) metadata."""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from ..base import RecipeMetadataParser
|
||||
from ..constants import GEN_PARAM_KEYS
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MetaFormatParser(RecipeMetadataParser):
|
||||
"""Parser for images with meta format metadata (Lora_N Model hash format)"""
|
||||
|
||||
METADATA_MARKER = r'Lora_\d+ Model hash:'
|
||||
|
||||
def is_metadata_matching(self, user_comment: str) -> bool:
|
||||
"""Check if the user comment matches the metadata format"""
|
||||
return re.search(self.METADATA_MARKER, user_comment, re.IGNORECASE | re.DOTALL) is not None
|
||||
|
||||
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""Parse metadata from images with meta format metadata (Lora_N Model hash format)"""
|
||||
try:
|
||||
# Get metadata provider instead of using civitai_client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
# Extract prompt and negative prompt
|
||||
parts = user_comment.split('Negative prompt:', 1)
|
||||
prompt = parts[0].strip()
|
||||
|
||||
# Initialize metadata
|
||||
metadata = {"prompt": prompt, "loras": []}
|
||||
|
||||
# Extract negative prompt and parameters if available
|
||||
if len(parts) > 1:
|
||||
negative_and_params = parts[1]
|
||||
|
||||
# Extract negative prompt - everything until the first parameter (usually "Steps:")
|
||||
param_start = re.search(r'([A-Za-z]+): ', negative_and_params)
|
||||
if param_start:
|
||||
neg_prompt = negative_and_params[:param_start.start()].strip()
|
||||
metadata["negative_prompt"] = neg_prompt
|
||||
params_section = negative_and_params[param_start.start():]
|
||||
else:
|
||||
params_section = negative_and_params
|
||||
|
||||
# Extract key-value parameters (Steps, Sampler, Seed, etc.)
|
||||
param_pattern = r'([A-Za-z_0-9 ]+): ([^,]+)'
|
||||
params = re.findall(param_pattern, params_section)
|
||||
for key, value in params:
|
||||
clean_key = key.strip().lower().replace(' ', '_')
|
||||
metadata[clean_key] = value.strip()
|
||||
|
||||
# Extract LoRA information
|
||||
# Pattern to match lora entries: Lora_0 Model name: ArtVador I.safetensors, Lora_0 Model hash: 08f7133a58, etc.
|
||||
lora_pattern = r'Lora_(\d+) Model name: ([^,]+), Lora_\1 Model hash: ([^,]+), Lora_\1 Strength model: ([^,]+), Lora_\1 Strength clip: ([^,]+)'
|
||||
lora_matches = re.findall(lora_pattern, user_comment)
|
||||
|
||||
# If the regular pattern doesn't match, try a more flexible approach
|
||||
if not lora_matches:
|
||||
# First find all Lora indices
|
||||
lora_indices = set(re.findall(r'Lora_(\d+)', user_comment))
|
||||
|
||||
# For each index, extract the information
|
||||
for idx in lora_indices:
|
||||
lora_info = {}
|
||||
|
||||
# Extract model name
|
||||
name_match = re.search(f'Lora_{idx} Model name: ([^,]+)', user_comment)
|
||||
if name_match:
|
||||
lora_info['name'] = name_match.group(1).strip()
|
||||
|
||||
# Extract model hash
|
||||
hash_match = re.search(f'Lora_{idx} Model hash: ([^,]+)', user_comment)
|
||||
if hash_match:
|
||||
lora_info['hash'] = hash_match.group(1).strip()
|
||||
|
||||
# Extract strength model
|
||||
strength_model_match = re.search(f'Lora_{idx} Strength model: ([^,]+)', user_comment)
|
||||
if strength_model_match:
|
||||
lora_info['strength_model'] = float(strength_model_match.group(1).strip())
|
||||
|
||||
# Extract strength clip
|
||||
strength_clip_match = re.search(f'Lora_{idx} Strength clip: ([^,]+)', user_comment)
|
||||
if strength_clip_match:
|
||||
lora_info['strength_clip'] = float(strength_clip_match.group(1).strip())
|
||||
|
||||
# Only add if we have at least name and hash
|
||||
if 'name' in lora_info and 'hash' in lora_info:
|
||||
lora_matches.append((idx, lora_info['name'], lora_info['hash'],
|
||||
str(lora_info.get('strength_model', 1.0)),
|
||||
str(lora_info.get('strength_clip', 1.0))))
|
||||
|
||||
# Process LoRAs
|
||||
base_model_counts = {}
|
||||
loras = []
|
||||
|
||||
for match in lora_matches:
|
||||
if len(match) == 5: # Regular pattern match
|
||||
idx, name, hash_value, strength_model, strength_clip = match
|
||||
else: # Flexible approach match
|
||||
continue # Should not happen now
|
||||
|
||||
# Clean up the values
|
||||
name = name.strip()
|
||||
if name.endswith('.safetensors'):
|
||||
name = name[:-12] # Remove .safetensors extension
|
||||
|
||||
hash_value = hash_value.strip()
|
||||
weight = float(strength_model) # Use model strength as weight
|
||||
|
||||
# Initialize lora entry with default values
|
||||
lora_entry = {
|
||||
'name': name,
|
||||
'type': 'lora',
|
||||
'weight': weight,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': name,
|
||||
'hash': hash_value,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Get info from Civitai by hash if available
|
||||
if metadata_provider and hash_value:
|
||||
try:
|
||||
civitai_info = await metadata_provider.get_model_by_hash(hash_value)
|
||||
# Populate lora entry with Civitai info
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts,
|
||||
hash_value
|
||||
)
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA hash {hash_value}: {e}")
|
||||
|
||||
loras.append(lora_entry)
|
||||
|
||||
# Extract model information
|
||||
model = None
|
||||
if 'model' in metadata:
|
||||
model = metadata['model']
|
||||
|
||||
# Set base_model to the most common one from civitai_info
|
||||
base_model = None
|
||||
if base_model_counts:
|
||||
base_model = max(base_model_counts.items(), key=lambda x: x[1])[0]
|
||||
|
||||
# Extract generation parameters for recipe metadata
|
||||
gen_params = {}
|
||||
for key in GEN_PARAM_KEYS:
|
||||
if key in metadata:
|
||||
gen_params[key] = metadata.get(key, '')
|
||||
|
||||
# Try to extract size information if available
|
||||
if 'width' in metadata and 'height' in metadata:
|
||||
gen_params['size'] = f"{metadata['width']}x{metadata['height']}"
|
||||
|
||||
return {
|
||||
'base_model': base_model,
|
||||
'loras': loras,
|
||||
'gen_params': gen_params,
|
||||
'raw_metadata': metadata,
|
||||
'from_meta_format': True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing meta format metadata: {e}", exc_info=True)
|
||||
return {"error": str(e), "loras": []}
|
||||
118
py/recipes/parsers/recipe_format.py
Normal file
118
py/recipes/parsers/recipe_format.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Parser for dedicated recipe metadata format."""
|
||||
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from ...config import config
|
||||
from ..base import RecipeMetadataParser
|
||||
from ..constants import GEN_PARAM_KEYS
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RecipeFormatParser(RecipeMetadataParser):
|
||||
"""Parser for images with dedicated recipe metadata format"""
|
||||
|
||||
# Regular expression pattern for extracting recipe metadata
|
||||
METADATA_MARKER = r'Recipe metadata: (\{.*\})'
|
||||
|
||||
def is_metadata_matching(self, user_comment: str) -> bool:
|
||||
"""Check if the user comment matches the metadata format"""
|
||||
return re.search(self.METADATA_MARKER, user_comment, re.IGNORECASE | re.DOTALL) is not None
|
||||
|
||||
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""Parse metadata from images with dedicated recipe metadata format"""
|
||||
try:
|
||||
# Get metadata provider instead of using civitai_client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
# Extract recipe metadata from user comment
|
||||
try:
|
||||
# Look for recipe metadata section
|
||||
recipe_match = re.search(self.METADATA_MARKER, user_comment, re.IGNORECASE | re.DOTALL)
|
||||
if not recipe_match:
|
||||
recipe_metadata = None
|
||||
else:
|
||||
recipe_json = recipe_match.group(1)
|
||||
recipe_metadata = json.loads(recipe_json)
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting recipe metadata: {e}")
|
||||
recipe_metadata = None
|
||||
if not recipe_metadata:
|
||||
return {"error": "No recipe metadata found", "loras": []}
|
||||
|
||||
# Process the recipe metadata
|
||||
loras = []
|
||||
for lora in recipe_metadata.get('loras', []):
|
||||
# Convert recipe lora format to frontend format
|
||||
lora_entry = {
|
||||
'id': int(lora.get('modelVersionId', 0)),
|
||||
'name': lora.get('modelName', ''),
|
||||
'version': lora.get('modelVersionName', ''),
|
||||
'type': 'lora',
|
||||
'weight': lora.get('strength', 1.0),
|
||||
'file_name': lora.get('file_name', ''),
|
||||
'hash': lora.get('hash', '')
|
||||
}
|
||||
|
||||
# Check if this LoRA exists locally by SHA256 hash
|
||||
if lora.get('hash') and recipe_scanner:
|
||||
lora_scanner = recipe_scanner._lora_scanner
|
||||
exists_locally = lora_scanner.has_hash(lora['hash'])
|
||||
if exists_locally:
|
||||
lora_cache = await lora_scanner.get_cached_data()
|
||||
lora_item = next((item for item in lora_cache.raw_data if item['sha256'].lower() == lora['hash'].lower()), None)
|
||||
if lora_item:
|
||||
lora_entry['existsLocally'] = True
|
||||
lora_entry['localPath'] = lora_item['file_path']
|
||||
lora_entry['file_name'] = lora_item['file_name']
|
||||
lora_entry['size'] = lora_item['size']
|
||||
lora_entry['thumbnailUrl'] = config.get_preview_static_url(lora_item['preview_url'])
|
||||
|
||||
else:
|
||||
lora_entry['existsLocally'] = False
|
||||
lora_entry['localPath'] = None
|
||||
|
||||
# Try to get additional info from Civitai if we have a model version ID
|
||||
if lora.get('modelVersionId') and metadata_provider:
|
||||
try:
|
||||
civitai_info_tuple = await metadata_provider.get_model_version_info(lora['modelVersionId'])
|
||||
# Populate lora entry with Civitai info
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info_tuple,
|
||||
recipe_scanner,
|
||||
None, # No need to track base model counts
|
||||
lora['hash']
|
||||
)
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA: {e}")
|
||||
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'
|
||||
|
||||
loras.append(lora_entry)
|
||||
|
||||
logger.info(f"Found {len(loras)} loras in recipe metadata")
|
||||
|
||||
# Filter gen_params to only include recognized keys
|
||||
filtered_gen_params = {}
|
||||
if 'gen_params' in recipe_metadata:
|
||||
for key, value in recipe_metadata['gen_params'].items():
|
||||
if key in GEN_PARAM_KEYS:
|
||||
filtered_gen_params[key] = value
|
||||
|
||||
return {
|
||||
'base_model': recipe_metadata.get('base_model', ''),
|
||||
'loras': loras,
|
||||
'gen_params': filtered_gen_params,
|
||||
'tags': recipe_metadata.get('tags', []),
|
||||
'title': recipe_metadata.get('title', ''),
|
||||
'from_recipe_metadata': True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing recipe format metadata: {e}", exc_info=True)
|
||||
return {"error": str(e), "loras": []}
|
||||
@@ -1,818 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from typing import Dict, List
|
||||
|
||||
from ..services.file_monitor import LoraFileMonitor
|
||||
from ..services.download_manager import DownloadManager
|
||||
from ..services.civitai_client import CivitaiClient
|
||||
from ..config import config
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from operator import itemgetter
|
||||
from ..services.websocket_manager import ws_manager
|
||||
from ..services.settings_manager import settings
|
||||
import asyncio
|
||||
from .update_routes import UpdateRoutes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ApiRoutes:
|
||||
"""API route handlers for LoRA management"""
|
||||
|
||||
def __init__(self, file_monitor: LoraFileMonitor):
|
||||
self.scanner = LoraScanner()
|
||||
self.civitai_client = CivitaiClient()
|
||||
self.download_manager = DownloadManager(file_monitor)
|
||||
self._download_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
def setup_routes(cls, app: web.Application, monitor: LoraFileMonitor):
|
||||
"""Register API routes"""
|
||||
routes = cls(monitor)
|
||||
app.router.add_post('/api/delete_model', routes.delete_model)
|
||||
app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
|
||||
app.router.add_post('/api/replace_preview', routes.replace_preview)
|
||||
app.router.add_get('/api/loras', routes.get_loras)
|
||||
app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai)
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
app.router.add_get('/api/lora-roots', routes.get_lora_roots)
|
||||
app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions)
|
||||
app.router.add_post('/api/download-lora', routes.download_lora)
|
||||
app.router.add_post('/api/settings', routes.update_settings)
|
||||
app.router.add_post('/api/move_model', routes.move_model)
|
||||
app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route
|
||||
app.router.add_post('/loras/api/save-metadata', routes.save_metadata)
|
||||
app.router.add_get('/api/lora-preview-url', routes.get_lora_preview_url) # Add new route
|
||||
app.router.add_post('/api/move_models_bulk', routes.move_models_bulk)
|
||||
app.router.add_get('/api/top-tags', routes.get_top_tags) # Add new route for top tags
|
||||
|
||||
# Add update check routes
|
||||
UpdateRoutes.setup_routes(app)
|
||||
|
||||
async def delete_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model deletion request"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
if not file_path:
|
||||
return web.Response(text='Model path is required', status=400)
|
||||
|
||||
target_dir = os.path.dirname(file_path)
|
||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
deleted_files = await self._delete_model_files(target_dir, file_name)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'deleted_files': deleted_files
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting model: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata fetch request"""
|
||||
try:
|
||||
data = await request.json()
|
||||
metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json'
|
||||
|
||||
# Check if model is from CivitAI
|
||||
local_metadata = await self._load_local_metadata(metadata_path)
|
||||
|
||||
# Fetch and update metadata
|
||||
civitai_metadata = await self.civitai_client.get_model_by_hash(local_metadata["sha256"])
|
||||
if not civitai_metadata:
|
||||
return await self._handle_not_found_on_civitai(metadata_path, local_metadata)
|
||||
|
||||
await self._update_model_metadata(metadata_path, local_metadata, civitai_metadata, self.civitai_client)
|
||||
|
||||
return web.json_response({"success": True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching from CivitAI: {e}", exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(e)}, status=500)
|
||||
|
||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||
"""Handle preview image replacement request"""
|
||||
try:
|
||||
reader = await request.multipart()
|
||||
preview_data, content_type = await self._read_preview_file(reader)
|
||||
model_path = await self._read_model_path(reader)
|
||||
|
||||
preview_path = await self._save_preview_file(model_path, preview_data, content_type)
|
||||
await self._update_preview_metadata(model_path, preview_path)
|
||||
|
||||
# Update preview URL in scanner cache
|
||||
await self.scanner.update_preview_in_cache(model_path, preview_path)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"preview_url": config.get_preview_static_url(preview_path)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error replacing preview: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_loras(self, request: web.Request) -> web.Response:
|
||||
"""Handle paginated LoRA data request"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
page = int(request.query.get('page', '1'))
|
||||
page_size = int(request.query.get('page_size', '20'))
|
||||
sort_by = request.query.get('sort_by', 'name')
|
||||
folder = request.query.get('folder')
|
||||
search = request.query.get('search', '').lower()
|
||||
fuzzy = request.query.get('fuzzy', 'false').lower() == 'true'
|
||||
recursive = request.query.get('recursive', 'false').lower() == 'true'
|
||||
|
||||
# Parse base models filter parameter
|
||||
base_models = request.query.get('base_models', '').split(',')
|
||||
base_models = [model.strip() for model in base_models if model.strip()]
|
||||
|
||||
# Parse search options
|
||||
search_filename = request.query.get('search_filename', 'true').lower() == 'true'
|
||||
search_modelname = request.query.get('search_modelname', 'true').lower() == 'true'
|
||||
search_tags = request.query.get('search_tags', 'false').lower() == 'true'
|
||||
|
||||
# Validate parameters
|
||||
if page < 1 or page_size < 1 or page_size > 100:
|
||||
return web.json_response({
|
||||
'error': 'Invalid pagination parameters'
|
||||
}, status=400)
|
||||
|
||||
if sort_by not in ['date', 'name']:
|
||||
return web.json_response({
|
||||
'error': 'Invalid sort parameter'
|
||||
}, status=400)
|
||||
|
||||
# Parse tags filter parameter
|
||||
tags = request.query.get('tags', '').split(',')
|
||||
tags = [tag.strip() for tag in tags if tag.strip()]
|
||||
|
||||
# Get paginated data with search and filters
|
||||
result = await self.scanner.get_paginated_data(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
sort_by=sort_by,
|
||||
folder=folder,
|
||||
search=search,
|
||||
fuzzy=fuzzy,
|
||||
recursive=recursive,
|
||||
base_models=base_models, # Pass base models filter
|
||||
tags=tags, # Add tags parameter
|
||||
search_options={
|
||||
'filename': search_filename,
|
||||
'modelname': search_modelname,
|
||||
'tags': search_tags
|
||||
}
|
||||
)
|
||||
|
||||
# Format the response data
|
||||
formatted_items = [
|
||||
self._format_lora_response(item)
|
||||
for item in result['items']
|
||||
]
|
||||
|
||||
# Get all available folders from cache
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
return web.json_response({
|
||||
'items': formatted_items,
|
||||
'total': result['total'],
|
||||
'page': result['page'],
|
||||
'page_size': result['page_size'],
|
||||
'total_pages': result['total_pages'],
|
||||
'folders': cache.folders
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_loras: {str(e)}", exc_info=True)
|
||||
return web.json_response({
|
||||
'error': 'Internal server error'
|
||||
}, status=500)
|
||||
|
||||
def _format_lora_response(self, lora: Dict) -> Dict:
|
||||
"""Format LoRA data for API response"""
|
||||
return {
|
||||
"model_name": lora["model_name"],
|
||||
"file_name": lora["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora["preview_url"]),
|
||||
"base_model": lora["base_model"],
|
||||
"folder": lora["folder"],
|
||||
"sha256": lora["sha256"],
|
||||
"file_path": lora["file_path"].replace(os.sep, "/"),
|
||||
"file_size": lora["size"],
|
||||
"modified": lora["modified"],
|
||||
"tags": lora["tags"],
|
||||
"modelDescription": lora["modelDescription"],
|
||||
"from_civitai": lora.get("from_civitai", True),
|
||||
"usage_tips": lora.get("usage_tips", ""),
|
||||
"notes": lora.get("notes", ""),
|
||||
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
|
||||
}
|
||||
|
||||
def _filter_civitai_data(self, data: Dict) -> Dict:
|
||||
"""Filter relevant fields from CivitAI data"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
fields = [
|
||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||
"publishedAt", "trainedWords", "baseModel", "description",
|
||||
"model", "images"
|
||||
]
|
||||
return {k: data[k] for k in fields if k in data}
|
||||
|
||||
# Private helper methods
|
||||
async def _delete_model_files(self, target_dir: str, file_name: str) -> List[str]:
|
||||
"""Delete model and associated files"""
|
||||
patterns = [
|
||||
f"{file_name}.safetensors", # Required
|
||||
f"{file_name}.metadata.json",
|
||||
f"{file_name}.preview.png",
|
||||
f"{file_name}.preview.jpg",
|
||||
f"{file_name}.preview.jpeg",
|
||||
f"{file_name}.preview.webp",
|
||||
f"{file_name}.preview.mp4",
|
||||
f"{file_name}.png",
|
||||
f"{file_name}.jpg",
|
||||
f"{file_name}.jpeg",
|
||||
f"{file_name}.webp",
|
||||
f"{file_name}.mp4"
|
||||
]
|
||||
|
||||
deleted = []
|
||||
main_file = patterns[0]
|
||||
main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')
|
||||
|
||||
if os.path.exists(main_path):
|
||||
# Notify file monitor to ignore delete event
|
||||
self.download_manager.file_monitor.handler.add_ignore_path(main_path, 0)
|
||||
|
||||
# Delete file
|
||||
os.remove(main_path)
|
||||
deleted.append(main_path)
|
||||
else:
|
||||
logger.warning(f"Model file not found: {main_file}")
|
||||
|
||||
# Remove from cache
|
||||
cache = await self.scanner.get_cached_data()
|
||||
cache.raw_data = [item for item in cache.raw_data if item['file_path'] != main_path]
|
||||
await cache.resort()
|
||||
|
||||
# Delete optional files
|
||||
for pattern in patterns[1:]:
|
||||
path = os.path.join(target_dir, pattern)
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
os.remove(path)
|
||||
deleted.append(pattern)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete {pattern}: {e}")
|
||||
|
||||
return deleted
|
||||
|
||||
async def _read_preview_file(self, reader) -> tuple[bytes, str]:
|
||||
"""Read preview file and content type from multipart request"""
|
||||
field = await reader.next()
|
||||
if field.name != 'preview_file':
|
||||
raise ValueError("Expected 'preview_file' field")
|
||||
content_type = field.headers.get('Content-Type', 'image/png')
|
||||
return await field.read(), content_type
|
||||
|
||||
async def _read_model_path(self, reader) -> str:
|
||||
"""Read model path from multipart request"""
|
||||
field = await reader.next()
|
||||
if field.name != 'model_path':
|
||||
raise ValueError("Expected 'model_path' field")
|
||||
return (await field.read()).decode()
|
||||
|
||||
async def _save_preview_file(self, model_path: str, preview_data: bytes, content_type: str) -> str:
|
||||
"""Save preview file and return its path"""
|
||||
# Determine file extension based on content type
|
||||
if content_type.startswith('video/'):
|
||||
extension = '.preview.mp4'
|
||||
else:
|
||||
extension = '.preview.png'
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(model_path))[0]
|
||||
folder = os.path.dirname(model_path)
|
||||
preview_path = os.path.join(folder, base_name + extension).replace(os.sep, '/')
|
||||
|
||||
with open(preview_path, 'wb') as f:
|
||||
f.write(preview_data)
|
||||
|
||||
return preview_path
|
||||
|
||||
async def _update_preview_metadata(self, model_path: str, preview_path: str):
|
||||
"""Update preview path in metadata"""
|
||||
metadata_path = os.path.splitext(model_path)[0] + '.metadata.json'
|
||||
if os.path.exists(metadata_path):
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
# Update preview_url directly in the metadata dict
|
||||
metadata['preview_url'] = preview_path
|
||||
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating metadata: {e}")
|
||||
|
||||
async def _load_local_metadata(self, metadata_path: str) -> Dict:
|
||||
"""Load local metadata file"""
|
||||
if os.path.exists(metadata_path):
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading metadata from {metadata_path}: {e}")
|
||||
return {}
|
||||
|
||||
async def _handle_not_found_on_civitai(self, metadata_path: str, local_metadata: Dict) -> web.Response:
|
||||
"""Handle case when model is not found on CivitAI"""
|
||||
local_metadata['from_civitai'] = False
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Not found on CivitAI"},
|
||||
status=404
|
||||
)
|
||||
|
||||
async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict,
|
||||
civitai_metadata: Dict, client: CivitaiClient) -> None:
|
||||
"""Update local metadata with CivitAI data"""
|
||||
local_metadata['civitai'] = civitai_metadata
|
||||
|
||||
# Update model name if available
|
||||
if 'model' in civitai_metadata:
|
||||
local_metadata['model_name'] = civitai_metadata['model'].get('name',
|
||||
local_metadata.get('model_name'))
|
||||
|
||||
# Fetch additional model metadata (description and tags) if we have model ID
|
||||
model_id = civitai_metadata['modelId']
|
||||
if model_id:
|
||||
model_metadata = await client.get_model_metadata(str(model_id))
|
||||
if model_metadata:
|
||||
local_metadata['modelDescription'] = model_metadata.get('description', '')
|
||||
local_metadata['tags'] = model_metadata.get('tags', [])
|
||||
|
||||
# Update base model
|
||||
local_metadata['base_model'] = civitai_metadata.get('baseModel')
|
||||
|
||||
# Update preview if needed
|
||||
if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
|
||||
first_preview = next((img for img in civitai_metadata.get('images', [])), None)
|
||||
if first_preview:
|
||||
preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1]
|
||||
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
|
||||
preview_filename = base_name + preview_ext
|
||||
preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename)
|
||||
|
||||
if await client.download_preview_image(first_preview['url'], preview_path):
|
||||
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
|
||||
|
||||
# Save updated metadata
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
|
||||
|
||||
await self.scanner.update_single_lora_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata)
|
||||
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Fetch CivitAI metadata for all loras in the background"""
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data()
|
||||
total = len(cache.raw_data)
|
||||
processed = 0
|
||||
success = 0
|
||||
needs_resort = False
|
||||
|
||||
# 准备要处理的 loras
|
||||
to_process = [
|
||||
lora for lora in cache.raw_data
|
||||
if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai') # TODO: for lora not from CivitAI but added traineWords
|
||||
]
|
||||
total_to_process = len(to_process)
|
||||
|
||||
# 发送初始进度
|
||||
await ws_manager.broadcast({
|
||||
'status': 'started',
|
||||
'total': total_to_process,
|
||||
'processed': 0,
|
||||
'success': 0
|
||||
})
|
||||
|
||||
for lora in to_process:
|
||||
try:
|
||||
original_name = lora.get('model_name')
|
||||
if await self._fetch_and_update_single_lora(
|
||||
sha256=lora['sha256'],
|
||||
file_path=lora['file_path'],
|
||||
lora=lora
|
||||
):
|
||||
success += 1
|
||||
if original_name != lora.get('model_name'):
|
||||
needs_resort = True
|
||||
|
||||
processed += 1
|
||||
|
||||
# 每处理一个就发送进度更新
|
||||
await ws_manager.broadcast({
|
||||
'status': 'processing',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success,
|
||||
'current_name': lora.get('model_name', 'Unknown')
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivitAI data for {lora['file_path']}: {e}")
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort(name_only=True)
|
||||
|
||||
# 发送完成消息
|
||||
await ws_manager.broadcast({
|
||||
'status': 'completed',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Successfully updated {success} of {processed} processed loras (total: {total})"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
# 发送错误消息
|
||||
await ws_manager.broadcast({
|
||||
'status': 'error',
|
||||
'error': str(e)
|
||||
})
|
||||
logger.error(f"Error in fetch_all_civitai: {e}")
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def _fetch_and_update_single_lora(self, sha256: str, file_path: str, lora: dict) -> bool:
|
||||
"""Fetch and update metadata for a single lora without sorting
|
||||
|
||||
Args:
|
||||
sha256: SHA256 hash of the lora file
|
||||
file_path: Path to the lora file
|
||||
lora: The lora object in cache to update
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
client = CivitaiClient()
|
||||
try:
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
|
||||
# Check if model is from CivitAI
|
||||
local_metadata = await self._load_local_metadata(metadata_path)
|
||||
|
||||
# Fetch metadata
|
||||
civitai_metadata = await client.get_model_by_hash(sha256)
|
||||
if not civitai_metadata:
|
||||
# Mark as not from CivitAI if not found
|
||||
local_metadata['from_civitai'] = False
|
||||
lora['from_civitai'] = False
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
|
||||
return False
|
||||
|
||||
# Update metadata
|
||||
await self._update_model_metadata(
|
||||
metadata_path,
|
||||
local_metadata,
|
||||
civitai_metadata,
|
||||
client
|
||||
)
|
||||
|
||||
# Update cache object directly
|
||||
lora.update({
|
||||
'model_name': local_metadata.get('model_name'),
|
||||
'preview_url': local_metadata.get('preview_url'),
|
||||
'from_civitai': True,
|
||||
'civitai': civitai_metadata
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivitAI data: {e}")
|
||||
return False
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
async def get_lora_roots(self, request: web.Request) -> web.Response:
|
||||
"""Get all configured LoRA root directories"""
|
||||
return web.json_response({
|
||||
'roots': config.loras_roots
|
||||
})
|
||||
|
||||
async def get_civitai_versions(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai model with local availability info"""
|
||||
try:
|
||||
model_id = request.match_info['model_id']
|
||||
versions = await self.civitai_client.get_model_versions(model_id)
|
||||
if not versions:
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
for file in version.get('files', []):
|
||||
sha256 = file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
file['existsLocally'] = self.scanner.has_lora_hash(sha256)
|
||||
if file['existsLocally']:
|
||||
file['localPath'] = self.scanner.get_lora_path_by_hash(sha256)
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
async def download_lora(self, request: web.Request) -> web.Response:
|
||||
async with self._download_lock:
|
||||
try:
|
||||
data = await request.json()
|
||||
|
||||
# Create progress callback
|
||||
async def progress_callback(progress):
|
||||
await ws_manager.broadcast({
|
||||
'status': 'progress',
|
||||
'progress': progress
|
||||
})
|
||||
|
||||
result = await self.download_manager.download_from_civitai(
|
||||
download_url=data.get('download_url'),
|
||||
save_dir=data.get('lora_root'),
|
||||
relative_path=data.get('relative_path'),
|
||||
progress_callback=progress_callback # Add progress callback
|
||||
)
|
||||
|
||||
if not result.get('success', False):
|
||||
return web.Response(status=500, text=result.get('error', 'Unknown error'))
|
||||
|
||||
return web.json_response(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading LoRA: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
async def update_settings(self, request: web.Request) -> web.Response:
|
||||
"""Update application settings"""
|
||||
try:
|
||||
data = await request.json()
|
||||
|
||||
# Validate and update settings
|
||||
if 'civitai_api_key' in data:
|
||||
settings.set('civitai_api_key', data['civitai_api_key'])
|
||||
|
||||
return web.json_response({'success': True})
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating settings: {e}", exc_info=True) # 添加 exc_info=True 以获取完整堆栈
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
async def move_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model move request"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
target_path = data.get('target_path')
|
||||
|
||||
if not file_path or not target_path:
|
||||
return web.Response(text='File path and target path are required', status=400)
|
||||
|
||||
# Call scanner to handle the move operation
|
||||
success = await self.scanner.move_model(file_path, target_path)
|
||||
|
||||
if success:
|
||||
return web.json_response({'success': True})
|
||||
else:
|
||||
return web.Response(text='Failed to move model', status=500)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving model: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
@classmethod
|
||||
async def cleanup(cls):
|
||||
"""Add cleanup method for application shutdown"""
|
||||
if hasattr(cls, '_instance'):
|
||||
await cls._instance.civitai_client.close()
|
||||
|
||||
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||
"""Handle saving metadata updates"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
if not file_path:
|
||||
return web.Response(text='File path is required', status=400)
|
||||
|
||||
# Remove file path from data to avoid saving it
|
||||
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
|
||||
|
||||
# Get metadata file path
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
|
||||
# Load existing metadata
|
||||
if os.path.exists(metadata_path):
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
metadata = {}
|
||||
|
||||
# Handle nested updates (for civitai.trainedWords)
|
||||
for key, value in metadata_updates.items():
|
||||
if isinstance(value, dict) and key in metadata and isinstance(metadata[key], dict):
|
||||
# Deep update for nested dictionaries
|
||||
for nested_key, nested_value in value.items():
|
||||
metadata[key][nested_key] = nested_value
|
||||
else:
|
||||
# Regular update for top-level keys
|
||||
metadata[key] = value
|
||||
|
||||
# Save updated metadata
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Update cache
|
||||
await self.scanner.update_single_lora_cache(file_path, file_path, metadata)
|
||||
|
||||
# If model_name was updated, resort the cache
|
||||
if 'model_name' in metadata_updates:
|
||||
cache = await self.scanner.get_cached_data()
|
||||
await cache.resort(name_only=True)
|
||||
|
||||
return web.json_response({'success': True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving metadata: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the static preview URL for a LoRA file"""
|
||||
try:
|
||||
# Get lora file name from query parameters
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
# Get cache data
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Search for the lora in cache data
|
||||
for lora in cache.raw_data:
|
||||
file_name = lora['file_name']
|
||||
if file_name == lora_name:
|
||||
if preview_url := lora.get('preview_url'):
|
||||
# Convert preview path to static URL
|
||||
static_url = config.get_preview_static_url(preview_url)
|
||||
if static_url:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'preview_url': static_url
|
||||
})
|
||||
break
|
||||
|
||||
# If no preview URL found
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No preview URL found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def move_models_bulk(self, request: web.Request) -> web.Response:
|
||||
"""Handle bulk model move request"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_paths = data.get('file_paths', [])
|
||||
target_path = data.get('target_path')
|
||||
|
||||
if not file_paths or not target_path:
|
||||
return web.Response(text='File paths and target path are required', status=400)
|
||||
|
||||
results = []
|
||||
for file_path in file_paths:
|
||||
success = await self.scanner.move_model(file_path, target_path)
|
||||
results.append({"path": file_path, "success": success})
|
||||
|
||||
# Count successes
|
||||
success_count = sum(1 for r in results if r["success"])
|
||||
|
||||
if success_count == len(file_paths):
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f'Successfully moved {success_count} models'
|
||||
})
|
||||
elif success_count > 0:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f'Moved {success_count} of {len(file_paths)} models',
|
||||
'results': results
|
||||
})
|
||||
else:
|
||||
return web.Response(text='Failed to move any models', status=500)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving models in bulk: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_lora_model_description(self, request: web.Request) -> web.Response:
|
||||
"""Get model description for a Lora model"""
|
||||
try:
|
||||
# Get parameters
|
||||
model_id = request.query.get('model_id')
|
||||
file_path = request.query.get('file_path')
|
||||
|
||||
if not model_id:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Model ID is required'
|
||||
}, status=400)
|
||||
|
||||
# Check if we already have the description stored in metadata
|
||||
description = None
|
||||
tags = []
|
||||
if file_path:
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
if os.path.exists(metadata_path):
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
description = metadata.get('modelDescription')
|
||||
tags = metadata.get('tags', [])
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading metadata from {metadata_path}: {e}")
|
||||
|
||||
# If description is not in metadata, fetch from CivitAI
|
||||
if not description:
|
||||
logger.info(f"Fetching model metadata for model ID: {model_id}")
|
||||
model_metadata = await self.civitai_client.get_model_metadata(model_id)
|
||||
|
||||
if model_metadata:
|
||||
description = model_metadata.get('description')
|
||||
tags = model_metadata.get('tags', [])
|
||||
|
||||
# Save the metadata to file if we have a file path and got metadata
|
||||
if file_path:
|
||||
try:
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
if os.path.exists(metadata_path):
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
metadata['modelDescription'] = description
|
||||
metadata['tags'] = tags
|
||||
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved model metadata to file for {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving model metadata: {e}")
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'description': description or "<p>No model description available.</p>",
|
||||
'tags': tags
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model metadata: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
"""Handle request for top tags sorted by frequency"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
# Validate limit
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20 # Default to a reasonable limit
|
||||
|
||||
# Get top tags
|
||||
top_tags = await self.scanner.get_top_tags(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'tags': top_tags
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Internal server error'
|
||||
}, status=500)
|
||||
275
py/routes/base_model_routes.py
Normal file
275
py/routes/base_model_routes.py
Normal file
@@ -0,0 +1,275 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Dict, Mapping
|
||||
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
|
||||
from ..config import config
|
||||
from ..services.download_coordinator import DownloadCoordinator
|
||||
from ..services.downloader import get_downloader
|
||||
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
||||
from ..services.metadata_sync_service import MetadataSyncService
|
||||
from ..services.model_file_service import ModelFileService, ModelMoveService
|
||||
from ..services.model_lifecycle_service import ModelLifecycleService
|
||||
from ..services.preview_asset_service import PreviewAssetService
|
||||
from ..services.server_i18n import server_i18n as default_server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..services.settings_manager import settings as default_settings
|
||||
from ..services.tag_update_service import TagUpdateService
|
||||
from ..services.websocket_manager import ws_manager as default_ws_manager
|
||||
from ..services.use_cases import (
|
||||
AutoOrganizeUseCase,
|
||||
BulkMetadataRefreshUseCase,
|
||||
DownloadModelUseCase,
|
||||
)
|
||||
from ..services.websocket_progress_callback import (
|
||||
WebSocketBroadcastCallback,
|
||||
WebSocketProgressCallback,
|
||||
)
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar
|
||||
from .handlers.model_handlers import (
|
||||
ModelAutoOrganizeHandler,
|
||||
ModelCivitaiHandler,
|
||||
ModelDownloadHandler,
|
||||
ModelHandlerSet,
|
||||
ModelListingHandler,
|
||||
ModelManagementHandler,
|
||||
ModelMoveHandler,
|
||||
ModelPageView,
|
||||
ModelQueryHandler,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseModelRoutes(ABC):
|
||||
"""Base route controller for all model types."""
|
||||
|
||||
template_name: str | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
service=None,
|
||||
*,
|
||||
settings_service=default_settings,
|
||||
ws_manager=default_ws_manager,
|
||||
server_i18n=default_server_i18n,
|
||||
metadata_provider_factory=get_default_metadata_provider,
|
||||
) -> None:
|
||||
self.service = None
|
||||
self.model_type = ""
|
||||
self._settings = settings_service
|
||||
self._ws_manager = ws_manager
|
||||
self._server_i18n = server_i18n
|
||||
self._metadata_provider_factory = metadata_provider_factory
|
||||
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True,
|
||||
)
|
||||
|
||||
self.model_file_service: ModelFileService | None = None
|
||||
self.model_move_service: ModelMoveService | None = None
|
||||
self.model_lifecycle_service: ModelLifecycleService | None = None
|
||||
self.websocket_progress_callback = WebSocketProgressCallback()
|
||||
self.metadata_progress_callback = WebSocketBroadcastCallback()
|
||||
|
||||
self._handler_set: ModelHandlerSet | None = None
|
||||
self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None
|
||||
|
||||
self._preview_service = PreviewAssetService(
|
||||
metadata_manager=MetadataManager,
|
||||
downloader_factory=get_downloader,
|
||||
exif_utils=ExifUtils,
|
||||
)
|
||||
self._metadata_sync_service = MetadataSyncService(
|
||||
metadata_manager=MetadataManager,
|
||||
preview_service=self._preview_service,
|
||||
settings=settings_service,
|
||||
default_metadata_provider_factory=metadata_provider_factory,
|
||||
metadata_provider_selector=get_metadata_provider,
|
||||
)
|
||||
self._tag_update_service = TagUpdateService(metadata_manager=MetadataManager)
|
||||
self._download_coordinator = DownloadCoordinator(
|
||||
ws_manager=self._ws_manager,
|
||||
download_manager_factory=ServiceRegistry.get_download_manager,
|
||||
)
|
||||
|
||||
if service is not None:
|
||||
self.attach_service(service)
|
||||
|
||||
def attach_service(self, service) -> None:
|
||||
"""Attach a model service and rebuild handler dependencies."""
|
||||
self.service = service
|
||||
self.model_type = service.model_type
|
||||
self.model_file_service = ModelFileService(service.scanner, service.model_type)
|
||||
self.model_move_service = ModelMoveService(service.scanner)
|
||||
self.model_lifecycle_service = ModelLifecycleService(
|
||||
scanner=service.scanner,
|
||||
metadata_manager=MetadataManager,
|
||||
metadata_loader=self._metadata_sync_service.load_local_metadata,
|
||||
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
|
||||
)
|
||||
self._handler_set = None
|
||||
self._handler_mapping = None
|
||||
|
||||
def _ensure_handler_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
|
||||
if self._handler_mapping is None:
|
||||
handler_set = self._create_handler_set()
|
||||
self._handler_set = handler_set
|
||||
self._handler_mapping = handler_set.to_route_mapping()
|
||||
return self._handler_mapping
|
||||
|
||||
def _create_handler_set(self) -> ModelHandlerSet:
|
||||
service = self._ensure_service()
|
||||
page_view = ModelPageView(
|
||||
template_env=self.template_env,
|
||||
template_name=self.template_name or "",
|
||||
service=service,
|
||||
settings_service=self._settings,
|
||||
server_i18n=self._server_i18n,
|
||||
logger=logger,
|
||||
)
|
||||
listing = ModelListingHandler(
|
||||
service=service,
|
||||
parse_specific_params=self._parse_specific_params,
|
||||
logger=logger,
|
||||
)
|
||||
management = ModelManagementHandler(
|
||||
service=service,
|
||||
logger=logger,
|
||||
metadata_sync=self._metadata_sync_service,
|
||||
preview_service=self._preview_service,
|
||||
tag_update_service=self._tag_update_service,
|
||||
lifecycle_service=self._ensure_lifecycle_service(),
|
||||
)
|
||||
query = ModelQueryHandler(service=service, logger=logger)
|
||||
download_use_case = DownloadModelUseCase(download_coordinator=self._download_coordinator)
|
||||
download = ModelDownloadHandler(
|
||||
ws_manager=self._ws_manager,
|
||||
logger=logger,
|
||||
download_use_case=download_use_case,
|
||||
download_coordinator=self._download_coordinator,
|
||||
)
|
||||
metadata_refresh_use_case = BulkMetadataRefreshUseCase(
|
||||
service=service,
|
||||
metadata_sync=self._metadata_sync_service,
|
||||
settings_service=self._settings,
|
||||
logger=logger,
|
||||
)
|
||||
civitai = ModelCivitaiHandler(
|
||||
service=service,
|
||||
settings_service=self._settings,
|
||||
ws_manager=self._ws_manager,
|
||||
logger=logger,
|
||||
metadata_provider_factory=self._metadata_provider_factory,
|
||||
validate_model_type=self._validate_civitai_model_type,
|
||||
expected_model_types=self._get_expected_model_types,
|
||||
find_model_file=self._find_model_file,
|
||||
metadata_sync=self._metadata_sync_service,
|
||||
metadata_refresh_use_case=metadata_refresh_use_case,
|
||||
metadata_progress_callback=self.metadata_progress_callback,
|
||||
)
|
||||
move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger)
|
||||
auto_organize_use_case = AutoOrganizeUseCase(
|
||||
file_service=self._ensure_file_service(),
|
||||
lock_provider=self._ws_manager,
|
||||
)
|
||||
auto_organize = ModelAutoOrganizeHandler(
|
||||
use_case=auto_organize_use_case,
|
||||
progress_callback=self.websocket_progress_callback,
|
||||
ws_manager=self._ws_manager,
|
||||
logger=logger,
|
||||
)
|
||||
return ModelHandlerSet(
|
||||
page_view=page_view,
|
||||
listing=listing,
|
||||
management=management,
|
||||
query=query,
|
||||
download=download,
|
||||
civitai=civitai,
|
||||
move=move,
|
||||
auto_organize=auto_organize,
|
||||
)
|
||||
|
||||
@property
|
||||
def route_handlers(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
|
||||
return self._ensure_handler_mapping()
|
||||
|
||||
def setup_routes(self, app: web.Application, prefix: str) -> None:
|
||||
registrar = ModelRouteRegistrar(app)
|
||||
handler_lookup = {
|
||||
definition.handler_name: self._make_handler_proxy(definition.handler_name)
|
||||
for definition in COMMON_ROUTE_DEFINITIONS
|
||||
}
|
||||
registrar.register_common_routes(prefix, handler_lookup)
|
||||
self.setup_specific_routes(registrar, prefix)
|
||||
|
||||
@abstractmethod
|
||||
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str) -> None:
|
||||
"""Setup model-specific routes."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||
"""Parse model-specific parameters - to be overridden by subclasses."""
|
||||
return {}
|
||||
|
||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||
"""Validate CivitAI model type - to be overridden by subclasses."""
|
||||
return True
|
||||
|
||||
def _get_expected_model_types(self) -> str:
|
||||
"""Get expected model types string for error messages - to be overridden by subclasses."""
|
||||
return "any model type"
|
||||
|
||||
def _find_model_file(self, files):
|
||||
"""Find the appropriate model file from the files list - can be overridden by subclasses."""
|
||||
return next((file for file in files if file.get("type") == "Model" and file.get("primary") is True), None)
|
||||
|
||||
def get_handler(self, name: str) -> Callable[[web.Request], web.StreamResponse]:
|
||||
"""Expose handlers for subclasses or tests."""
|
||||
return self._ensure_handler_mapping()[name]
|
||||
|
||||
def _ensure_service(self):
|
||||
if self.service is None:
|
||||
raise RuntimeError("Model service has not been attached")
|
||||
return self.service
|
||||
|
||||
def _ensure_file_service(self) -> ModelFileService:
|
||||
if self.model_file_service is None:
|
||||
service = self._ensure_service()
|
||||
self.model_file_service = ModelFileService(service.scanner, service.model_type)
|
||||
return self.model_file_service
|
||||
|
||||
def _ensure_move_service(self) -> ModelMoveService:
|
||||
if self.model_move_service is None:
|
||||
service = self._ensure_service()
|
||||
self.model_move_service = ModelMoveService(service.scanner)
|
||||
return self.model_move_service
|
||||
|
||||
def _ensure_lifecycle_service(self) -> ModelLifecycleService:
|
||||
if self.model_lifecycle_service is None:
|
||||
service = self._ensure_service()
|
||||
self.model_lifecycle_service = ModelLifecycleService(
|
||||
scanner=service.scanner,
|
||||
metadata_manager=MetadataManager,
|
||||
metadata_loader=self._metadata_sync_service.load_local_metadata,
|
||||
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
|
||||
)
|
||||
return self.model_lifecycle_service
|
||||
|
||||
def _make_handler_proxy(self, name: str) -> Callable[[web.Request], web.StreamResponse]:
|
||||
async def proxy(request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
handler = self.get_handler(name)
|
||||
except RuntimeError:
|
||||
return web.json_response({"success": False, "error": "Service not ready"}, status=503)
|
||||
return await handler(request)
|
||||
|
||||
return proxy
|
||||
|
||||
217
py/routes/base_recipe_routes.py
Normal file
217
py/routes/base_recipe_routes.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""Base infrastructure shared across recipe routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, Mapping
|
||||
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
|
||||
from ..config import config
|
||||
from ..recipes import RecipeParserFactory
|
||||
from ..services.downloader import get_downloader
|
||||
from ..services.recipes import (
|
||||
RecipeAnalysisService,
|
||||
RecipePersistenceService,
|
||||
RecipeSharingService,
|
||||
)
|
||||
from ..services.server_i18n import server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..services.settings_manager import settings
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from .handlers.recipe_handlers import (
|
||||
RecipeAnalysisHandler,
|
||||
RecipeHandlerSet,
|
||||
RecipeListingHandler,
|
||||
RecipeManagementHandler,
|
||||
RecipePageView,
|
||||
RecipeQueryHandler,
|
||||
RecipeSharingHandler,
|
||||
)
|
||||
from .recipe_route_registrar import ROUTE_DEFINITIONS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseRecipeRoutes:
|
||||
"""Common dependency and startup wiring for recipe routes."""
|
||||
|
||||
_HANDLER_NAMES: tuple[str, ...] = tuple(
|
||||
definition.handler_name for definition in ROUTE_DEFINITIONS
|
||||
)
|
||||
|
||||
template_name: str = "recipes.html"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.recipe_scanner = None
|
||||
self.lora_scanner = None
|
||||
self.civitai_client = None
|
||||
self.settings = settings
|
||||
self.server_i18n = server_i18n
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True,
|
||||
)
|
||||
|
||||
self._i18n_registered = False
|
||||
self._startup_hooks_registered = False
|
||||
self._handler_set: RecipeHandlerSet | None = None
|
||||
self._handler_mapping: dict[str, Callable] | None = None
|
||||
|
||||
async def attach_dependencies(self, app: web.Application | None = None) -> None:
|
||||
"""Resolve shared services from the registry."""
|
||||
|
||||
await self._ensure_services()
|
||||
self._ensure_i18n_filter()
|
||||
|
||||
async def ensure_dependencies_ready(self) -> None:
|
||||
"""Ensure dependencies are available for request handlers."""
|
||||
|
||||
if self.recipe_scanner is None or self.civitai_client is None:
|
||||
await self.attach_dependencies()
|
||||
|
||||
def register_startup_hooks(self, app: web.Application) -> None:
|
||||
"""Register startup hooks once for dependency wiring."""
|
||||
|
||||
if self._startup_hooks_registered:
|
||||
return
|
||||
|
||||
app.on_startup.append(self.attach_dependencies)
|
||||
app.on_startup.append(self.prewarm_cache)
|
||||
self._startup_hooks_registered = True
|
||||
|
||||
async def prewarm_cache(self, app: web.Application | None = None) -> None:
|
||||
"""Pre-load recipe and LoRA caches on startup."""
|
||||
|
||||
try:
|
||||
await self.attach_dependencies(app)
|
||||
|
||||
if self.lora_scanner is not None:
|
||||
await self.lora_scanner.get_cached_data()
|
||||
hash_index = getattr(self.lora_scanner, "_hash_index", None)
|
||||
if hash_index is not None and hasattr(hash_index, "_hash_to_path"):
|
||||
_ = len(hash_index._hash_to_path)
|
||||
|
||||
if self.recipe_scanner is not None:
|
||||
await self.recipe_scanner.get_cached_data(force_refresh=True)
|
||||
except Exception as exc:
|
||||
logger.error("Error pre-warming recipe cache: %s", exc, exc_info=True)
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable]:
|
||||
"""Return a mapping of handler name to coroutine for registrar binding."""
|
||||
|
||||
if self._handler_mapping is None:
|
||||
handler_set = self._create_handler_set()
|
||||
self._handler_set = handler_set
|
||||
self._handler_mapping = handler_set.to_route_mapping()
|
||||
return self._handler_mapping
|
||||
|
||||
# Internal helpers -------------------------------------------------
|
||||
|
||||
async def _ensure_services(self) -> None:
|
||||
if self.recipe_scanner is None:
|
||||
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
self.lora_scanner = getattr(self.recipe_scanner, "_lora_scanner", None)
|
||||
|
||||
if self.civitai_client is None:
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
def _ensure_i18n_filter(self) -> None:
|
||||
if not self._i18n_registered:
|
||||
self.template_env.filters["t"] = self.server_i18n.create_template_filter()
|
||||
self._i18n_registered = True
|
||||
|
||||
def get_handler_owner(self):
|
||||
"""Return the object supplying bound handler coroutines."""
|
||||
|
||||
if self._handler_set is None:
|
||||
self._handler_set = self._create_handler_set()
|
||||
return self._handler_set
|
||||
|
||||
def _create_handler_set(self) -> RecipeHandlerSet:
|
||||
recipe_scanner_getter = lambda: self.recipe_scanner
|
||||
civitai_client_getter = lambda: self.civitai_client
|
||||
|
||||
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
if not standalone_mode:
|
||||
from ..metadata_collector import get_metadata # type: ignore[import-not-found]
|
||||
from ..metadata_collector.metadata_processor import ( # type: ignore[import-not-found]
|
||||
MetadataProcessor,
|
||||
)
|
||||
from ..metadata_collector.metadata_registry import ( # type: ignore[import-not-found]
|
||||
MetadataRegistry,
|
||||
)
|
||||
else: # pragma: no cover - optional dependency path
|
||||
get_metadata = None # type: ignore[assignment]
|
||||
MetadataProcessor = None # type: ignore[assignment]
|
||||
MetadataRegistry = None # type: ignore[assignment]
|
||||
|
||||
analysis_service = RecipeAnalysisService(
|
||||
exif_utils=ExifUtils,
|
||||
recipe_parser_factory=RecipeParserFactory,
|
||||
downloader_factory=get_downloader,
|
||||
metadata_collector=get_metadata,
|
||||
metadata_processor_cls=MetadataProcessor,
|
||||
metadata_registry_cls=MetadataRegistry,
|
||||
standalone_mode=standalone_mode,
|
||||
logger=logger,
|
||||
)
|
||||
persistence_service = RecipePersistenceService(
|
||||
exif_utils=ExifUtils,
|
||||
card_preview_width=CARD_PREVIEW_WIDTH,
|
||||
logger=logger,
|
||||
)
|
||||
sharing_service = RecipeSharingService(logger=logger)
|
||||
|
||||
page_view = RecipePageView(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
settings_service=self.settings,
|
||||
server_i18n=self.server_i18n,
|
||||
template_env=self.template_env,
|
||||
template_name=self.template_name,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
logger=logger,
|
||||
)
|
||||
listing = RecipeListingHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
logger=logger,
|
||||
)
|
||||
query = RecipeQueryHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
format_recipe_file_url=listing.format_recipe_file_url,
|
||||
logger=logger,
|
||||
)
|
||||
management = RecipeManagementHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
logger=logger,
|
||||
persistence_service=persistence_service,
|
||||
analysis_service=analysis_service,
|
||||
)
|
||||
analysis = RecipeAnalysisHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
logger=logger,
|
||||
analysis_service=analysis_service,
|
||||
)
|
||||
sharing = RecipeSharingHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
logger=logger,
|
||||
sharing_service=sharing_service,
|
||||
)
|
||||
|
||||
return RecipeHandlerSet(
|
||||
page_view=page_view,
|
||||
listing=listing,
|
||||
query=query,
|
||||
management=management,
|
||||
analysis=analysis,
|
||||
sharing=sharing,
|
||||
)
|
||||
|
||||
96
py/routes/checkpoint_routes.py
Normal file
96
py/routes/checkpoint_routes.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import logging
|
||||
from aiohttp import web
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from .model_route_registrar import ModelRouteRegistrar
|
||||
from ..services.checkpoint_service import CheckpointService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointRoutes(BaseModelRoutes):
|
||||
"""Checkpoint-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Checkpoint routes with Checkpoint service"""
|
||||
super().__init__()
|
||||
self.template_name = "checkpoints.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.service = CheckpointService(checkpoint_scanner)
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Setup Checkpoint routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Setup common routes with 'checkpoints' prefix (includes page route)
|
||||
super().setup_routes(app, 'checkpoints')
|
||||
|
||||
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
|
||||
"""Setup Checkpoint-specific routes"""
|
||||
# Checkpoint info by name
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/info/{name}', prefix, self.get_checkpoint_info)
|
||||
|
||||
# Checkpoint roots and Unet roots
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/checkpoints_roots', prefix, self.get_checkpoints_roots)
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/unet_roots', prefix, self.get_unet_roots)
|
||||
|
||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||
"""Validate CivitAI model type for Checkpoint"""
|
||||
return model_type.lower() == 'checkpoint'
|
||||
|
||||
def _get_expected_model_types(self) -> str:
|
||||
"""Get expected model types string for error messages"""
|
||||
return "Checkpoint"
|
||||
|
||||
async def get_checkpoint_info(self, request: web.Request) -> web.Response:
|
||||
"""Get detailed information for a specific checkpoint by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
checkpoint_info = await self.service.get_model_info_by_name(name)
|
||||
|
||||
if checkpoint_info:
|
||||
return web.json_response(checkpoint_info)
|
||||
else:
|
||||
return web.json_response({"error": "Checkpoint not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_checkpoints_roots(self, request: web.Request) -> web.Response:
|
||||
"""Return the list of checkpoint roots from config"""
|
||||
try:
|
||||
roots = config.checkpoints_roots
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_unet_roots(self, request: web.Request) -> web.Response:
|
||||
"""Return the list of unet roots from config"""
|
||||
try:
|
||||
roots = config.unet_roots
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting unet roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
61
py/routes/embedding_routes.py
Normal file
61
py/routes/embedding_routes.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import logging
|
||||
from aiohttp import web
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from .model_route_registrar import ModelRouteRegistrar
|
||||
from ..services.embedding_service import EmbeddingService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EmbeddingRoutes(BaseModelRoutes):
|
||||
"""Embedding-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Embedding routes with Embedding service"""
|
||||
super().__init__()
|
||||
self.template_name = "embeddings.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
self.service = EmbeddingService(embedding_scanner)
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Setup Embedding routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Setup common routes with 'embeddings' prefix (includes page route)
|
||||
super().setup_routes(app, 'embeddings')
|
||||
|
||||
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
|
||||
"""Setup Embedding-specific routes"""
|
||||
# Embedding info by name
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/info/{name}', prefix, self.get_embedding_info)
|
||||
|
||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||
"""Validate CivitAI model type for Embedding"""
|
||||
return model_type.lower() == 'textualinversion'
|
||||
|
||||
def _get_expected_model_types(self) -> str:
|
||||
"""Get expected model types string for error messages"""
|
||||
return "TextualInversion"
|
||||
|
||||
async def get_embedding_info(self, request: web.Request) -> web.Response:
|
||||
"""Get detailed information for a specific embedding by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
embedding_info = await self.service.get_model_info_by_name(name)
|
||||
|
||||
if embedding_info:
|
||||
return web.json_response(embedding_info)
|
||||
else:
|
||||
return web.json_response({"error": "Embedding not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_embedding_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
62
py/routes/example_images_route_registrar.py
Normal file
62
py/routes/example_images_route_registrar.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Route registrar for example image endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RouteDefinition:
|
||||
"""Declarative configuration for a HTTP route."""
|
||||
|
||||
method: str
|
||||
path: str
|
||||
handler_name: str
|
||||
|
||||
|
||||
ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("POST", "/api/lm/download-example-images", "download_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/import-example-images", "import_example_images"),
|
||||
RouteDefinition("GET", "/api/lm/example-images-status", "get_example_images_status"),
|
||||
RouteDefinition("POST", "/api/lm/pause-example-images", "pause_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/resume-example-images", "resume_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/open-example-images-folder", "open_example_images_folder"),
|
||||
RouteDefinition("GET", "/api/lm/example-image-files", "get_example_image_files"),
|
||||
RouteDefinition("GET", "/api/lm/has-example-images", "has_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/delete-example-image", "delete_example_image"),
|
||||
RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/cleanup-example-image-folders", "cleanup_example_image_folders"),
|
||||
)
|
||||
|
||||
|
||||
class ExampleImagesRouteRegistrar:
|
||||
"""Bind declarative example image routes to an aiohttp router."""
|
||||
|
||||
_METHOD_MAP = {
|
||||
"GET": "add_get",
|
||||
"POST": "add_post",
|
||||
"PUT": "add_put",
|
||||
"DELETE": "add_delete",
|
||||
}
|
||||
|
||||
def __init__(self, app: web.Application) -> None:
|
||||
self._app = app
|
||||
|
||||
def register_routes(
|
||||
self,
|
||||
handler_lookup: Mapping[str, Callable[[web.Request], object]],
|
||||
*,
|
||||
definitions: Iterable[RouteDefinition] = ROUTE_DEFINITIONS,
|
||||
) -> None:
|
||||
"""Register each route definition using the supplied handlers."""
|
||||
|
||||
for definition in definitions:
|
||||
handler = handler_lookup[definition.handler_name]
|
||||
self._bind_route(definition.method, definition.path, handler)
|
||||
|
||||
def _bind_route(self, method: str, path: str, handler: Callable[[web.Request], object]) -> None:
|
||||
add_method_name = self._METHOD_MAP[method.upper()]
|
||||
add_method = getattr(self._app.router, add_method_name)
|
||||
add_method(path, handler)
|
||||
88
py/routes/example_images_routes.py
Normal file
88
py/routes/example_images_routes.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from .example_images_route_registrar import ExampleImagesRouteRegistrar
|
||||
from .handlers.example_images_handlers import (
|
||||
ExampleImagesDownloadHandler,
|
||||
ExampleImagesFileHandler,
|
||||
ExampleImagesHandlerSet,
|
||||
ExampleImagesManagementHandler,
|
||||
)
|
||||
from ..services.use_cases.example_images import (
|
||||
DownloadExampleImagesUseCase,
|
||||
ImportExampleImagesUseCase,
|
||||
)
|
||||
from ..utils.example_images_download_manager import (
|
||||
DownloadManager,
|
||||
get_default_download_manager,
|
||||
)
|
||||
from ..utils.example_images_file_manager import ExampleImagesFileManager
|
||||
from ..utils.example_images_processor import ExampleImagesProcessor
|
||||
from ..services.example_images_cleanup_service import ExampleImagesCleanupService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExampleImagesRoutes:
|
||||
"""Route controller for example image endpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ws_manager,
|
||||
download_manager: DownloadManager | None = None,
|
||||
processor=ExampleImagesProcessor,
|
||||
file_manager=ExampleImagesFileManager,
|
||||
cleanup_service: ExampleImagesCleanupService | None = None,
|
||||
) -> None:
|
||||
if ws_manager is None:
|
||||
raise ValueError("ws_manager is required")
|
||||
self._download_manager = download_manager or get_default_download_manager(ws_manager)
|
||||
self._processor = processor
|
||||
self._file_manager = file_manager
|
||||
self._cleanup_service = cleanup_service or ExampleImagesCleanupService()
|
||||
self._handler_set: ExampleImagesHandlerSet | None = None
|
||||
self._handler_mapping: Mapping[str, Callable[[web.Request], web.StreamResponse]] | None = None
|
||||
|
||||
@classmethod
|
||||
def setup_routes(cls, app: web.Application, *, ws_manager) -> None:
|
||||
"""Register routes on the given aiohttp application using default wiring."""
|
||||
|
||||
controller = cls(ws_manager=ws_manager)
|
||||
controller.register(app)
|
||||
|
||||
def register(self, app: web.Application) -> None:
|
||||
"""Bind the controller's handlers to the aiohttp router."""
|
||||
|
||||
registrar = ExampleImagesRouteRegistrar(app)
|
||||
registrar.register_routes(self.to_route_mapping())
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
|
||||
"""Return the registrar-compatible mapping of handler names to callables."""
|
||||
|
||||
if self._handler_mapping is None:
|
||||
handler_set = self._build_handler_set()
|
||||
self._handler_set = handler_set
|
||||
self._handler_mapping = handler_set.to_route_mapping()
|
||||
return self._handler_mapping
|
||||
|
||||
def _build_handler_set(self) -> ExampleImagesHandlerSet:
|
||||
logger.debug("Building ExampleImagesHandlerSet with %s, %s, %s", self._download_manager, self._processor, self._file_manager)
|
||||
download_use_case = DownloadExampleImagesUseCase(download_manager=self._download_manager)
|
||||
download_handler = ExampleImagesDownloadHandler(download_use_case, self._download_manager)
|
||||
import_use_case = ImportExampleImagesUseCase(processor=self._processor)
|
||||
management_handler = ExampleImagesManagementHandler(
|
||||
import_use_case,
|
||||
self._processor,
|
||||
self._cleanup_service,
|
||||
)
|
||||
file_handler = ExampleImagesFileHandler(self._file_manager)
|
||||
return ExampleImagesHandlerSet(
|
||||
download=download_handler,
|
||||
management=management_handler,
|
||||
files=file_handler,
|
||||
)
|
||||
159
py/routes/handlers/example_images_handlers.py
Normal file
159
py/routes/handlers/example_images_handlers.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""Handler set for example image routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ...services.use_cases.example_images import (
|
||||
DownloadExampleImagesConfigurationError,
|
||||
DownloadExampleImagesInProgressError,
|
||||
DownloadExampleImagesUseCase,
|
||||
ImportExampleImagesUseCase,
|
||||
ImportExampleImagesValidationError,
|
||||
)
|
||||
from ...utils.example_images_download_manager import (
|
||||
DownloadConfigurationError,
|
||||
DownloadInProgressError,
|
||||
DownloadNotRunningError,
|
||||
ExampleImagesDownloadError,
|
||||
)
|
||||
from ...utils.example_images_processor import ExampleImagesImportError
|
||||
|
||||
|
||||
class ExampleImagesDownloadHandler:
|
||||
"""HTTP adapters for download-related example image endpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
download_use_case: DownloadExampleImagesUseCase,
|
||||
download_manager,
|
||||
) -> None:
|
||||
self._download_use_case = download_use_case
|
||||
self._download_manager = download_manager
|
||||
|
||||
async def download_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
payload = await request.json()
|
||||
result = await self._download_use_case.execute(payload)
|
||||
return web.json_response(result)
|
||||
except DownloadExampleImagesInProgressError as exc:
|
||||
response = {
|
||||
'success': False,
|
||||
'error': str(exc),
|
||||
'status': exc.progress,
|
||||
}
|
||||
return web.json_response(response, status=400)
|
||||
except DownloadExampleImagesConfigurationError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
except ExampleImagesDownloadError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
async def get_example_images_status(self, request: web.Request) -> web.StreamResponse:
|
||||
result = await self._download_manager.get_status(request)
|
||||
return web.json_response(result)
|
||||
|
||||
async def pause_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
result = await self._download_manager.pause_download(request)
|
||||
return web.json_response(result)
|
||||
except DownloadNotRunningError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
|
||||
async def resume_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
result = await self._download_manager.resume_download(request)
|
||||
return web.json_response(result)
|
||||
except DownloadNotRunningError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
|
||||
async def force_download_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
payload = await request.json()
|
||||
result = await self._download_manager.start_force_download(payload)
|
||||
return web.json_response(result)
|
||||
except DownloadInProgressError as exc:
|
||||
response = {
|
||||
'success': False,
|
||||
'error': str(exc),
|
||||
'status': exc.progress_snapshot,
|
||||
}
|
||||
return web.json_response(response, status=400)
|
||||
except DownloadConfigurationError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
except ExampleImagesDownloadError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
|
||||
class ExampleImagesManagementHandler:
|
||||
"""HTTP adapters for import/delete endpoints."""
|
||||
|
||||
def __init__(self, import_use_case: ImportExampleImagesUseCase, processor, cleanup_service) -> None:
|
||||
self._import_use_case = import_use_case
|
||||
self._processor = processor
|
||||
self._cleanup_service = cleanup_service
|
||||
|
||||
async def import_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
result = await self._import_use_case.execute(request)
|
||||
return web.json_response(result)
|
||||
except ImportExampleImagesValidationError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
except ExampleImagesImportError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
async def delete_example_image(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._processor.delete_custom_image(request)
|
||||
|
||||
async def cleanup_example_image_folders(self, request: web.Request) -> web.StreamResponse:
|
||||
result = await self._cleanup_service.cleanup_example_image_folders()
|
||||
|
||||
if result.get('success') or result.get('partial_success'):
|
||||
return web.json_response(result, status=200)
|
||||
|
||||
error_code = result.get('error_code')
|
||||
status = 400 if error_code in {'path_not_configured', 'path_not_found'} else 500
|
||||
return web.json_response(result, status=status)
|
||||
|
||||
|
||||
class ExampleImagesFileHandler:
|
||||
"""HTTP adapters for filesystem-centric endpoints."""
|
||||
|
||||
def __init__(self, file_manager) -> None:
|
||||
self._file_manager = file_manager
|
||||
|
||||
async def open_example_images_folder(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._file_manager.open_folder(request)
|
||||
|
||||
async def get_example_image_files(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._file_manager.get_files(request)
|
||||
|
||||
async def has_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._file_manager.has_images(request)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExampleImagesHandlerSet:
|
||||
"""Aggregate of handlers exposed to the registrar."""
|
||||
|
||||
download: ExampleImagesDownloadHandler
|
||||
management: ExampleImagesManagementHandler
|
||||
files: ExampleImagesFileHandler
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
|
||||
"""Flatten handler methods into the registrar mapping."""
|
||||
|
||||
return {
|
||||
"download_example_images": self.download.download_example_images,
|
||||
"get_example_images_status": self.download.get_example_images_status,
|
||||
"pause_example_images": self.download.pause_example_images,
|
||||
"resume_example_images": self.download.resume_example_images,
|
||||
"force_download_example_images": self.download.force_download_example_images,
|
||||
"import_example_images": self.management.import_example_images,
|
||||
"delete_example_image": self.management.delete_example_image,
|
||||
"cleanup_example_image_folders": self.management.cleanup_example_image_folders,
|
||||
"open_example_images_folder": self.files.open_example_images_folder,
|
||||
"get_example_image_files": self.files.get_example_image_files,
|
||||
"has_example_images": self.files.has_example_images,
|
||||
}
|
||||
1020
py/routes/handlers/model_handlers.py
Normal file
1020
py/routes/handlers/model_handlers.py
Normal file
File diff suppressed because it is too large
Load Diff
725
py/routes/handlers/recipe_handlers.py
Normal file
725
py/routes/handlers/recipe_handlers.py
Normal file
@@ -0,0 +1,725 @@
|
||||
"""Dedicated handler objects for recipe-related routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ...config import config
|
||||
from ...services.server_i18n import server_i18n as default_server_i18n
|
||||
from ...services.settings_manager import SettingsManager
|
||||
from ...services.recipes import (
|
||||
RecipeAnalysisService,
|
||||
RecipeDownloadError,
|
||||
RecipeNotFoundError,
|
||||
RecipePersistenceService,
|
||||
RecipeSharingService,
|
||||
RecipeValidationError,
|
||||
)
|
||||
|
||||
Logger = logging.Logger
|
||||
EnsureDependenciesCallable = Callable[[], Awaitable[None]]
|
||||
RecipeScannerGetter = Callable[[], Any]
|
||||
CivitaiClientGetter = Callable[[], Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RecipeHandlerSet:
|
||||
"""Group of handlers providing recipe route implementations."""
|
||||
|
||||
page_view: "RecipePageView"
|
||||
listing: "RecipeListingHandler"
|
||||
query: "RecipeQueryHandler"
|
||||
management: "RecipeManagementHandler"
|
||||
analysis: "RecipeAnalysisHandler"
|
||||
sharing: "RecipeSharingHandler"
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
|
||||
"""Expose handler coroutines keyed by registrar handler names."""
|
||||
|
||||
return {
|
||||
"render_page": self.page_view.render_page,
|
||||
"list_recipes": self.listing.list_recipes,
|
||||
"get_recipe": self.listing.get_recipe,
|
||||
"analyze_uploaded_image": self.analysis.analyze_uploaded_image,
|
||||
"analyze_local_image": self.analysis.analyze_local_image,
|
||||
"save_recipe": self.management.save_recipe,
|
||||
"delete_recipe": self.management.delete_recipe,
|
||||
"get_top_tags": self.query.get_top_tags,
|
||||
"get_base_models": self.query.get_base_models,
|
||||
"share_recipe": self.sharing.share_recipe,
|
||||
"download_shared_recipe": self.sharing.download_shared_recipe,
|
||||
"get_recipe_syntax": self.query.get_recipe_syntax,
|
||||
"update_recipe": self.management.update_recipe,
|
||||
"reconnect_lora": self.management.reconnect_lora,
|
||||
"find_duplicates": self.query.find_duplicates,
|
||||
"bulk_delete": self.management.bulk_delete,
|
||||
"save_recipe_from_widget": self.management.save_recipe_from_widget,
|
||||
"get_recipes_for_lora": self.query.get_recipes_for_lora,
|
||||
"scan_recipes": self.query.scan_recipes,
|
||||
}
|
||||
|
||||
|
||||
class RecipePageView:
|
||||
"""Render the recipe shell page."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
settings_service: SettingsManager,
|
||||
server_i18n=default_server_i18n,
|
||||
template_env,
|
||||
template_name: str,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
logger: Logger,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._settings = settings_service
|
||||
self._server_i18n = server_i18n
|
||||
self._template_env = template_env
|
||||
self._template_name = template_name
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
|
||||
async def render_page(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None: # pragma: no cover - defensive guard
|
||||
raise RuntimeError("Recipe scanner not available")
|
||||
|
||||
user_language = self._settings.get("language", "en")
|
||||
self._server_i18n.set_locale(user_language)
|
||||
|
||||
try:
|
||||
await recipe_scanner.get_cached_data(force_refresh=False)
|
||||
rendered = self._template_env.get_template(self._template_name).render(
|
||||
recipes=[],
|
||||
is_initializing=False,
|
||||
settings=self._settings,
|
||||
request=request,
|
||||
t=self._server_i18n.get_translation,
|
||||
)
|
||||
except Exception as cache_error: # pragma: no cover - logging path
|
||||
self._logger.error("Error loading recipe cache data: %s", cache_error)
|
||||
rendered = self._template_env.get_template(self._template_name).render(
|
||||
is_initializing=True,
|
||||
settings=self._settings,
|
||||
request=request,
|
||||
t=self._server_i18n.get_translation,
|
||||
)
|
||||
return web.Response(text=rendered, content_type="text/html")
|
||||
except Exception as exc: # pragma: no cover - logging path
|
||||
self._logger.error("Error handling recipes request: %s", exc, exc_info=True)
|
||||
return web.Response(text="Error loading recipes page", status=500)
|
||||
|
||||
|
||||
class RecipeListingHandler:
|
||||
"""Provide listing and detail APIs for recipes."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
logger: Logger,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
|
||||
async def list_recipes(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
page = int(request.query.get("page", "1"))
|
||||
page_size = int(request.query.get("page_size", "20"))
|
||||
sort_by = request.query.get("sort_by", "date")
|
||||
search = request.query.get("search")
|
||||
|
||||
search_options = {
|
||||
"title": request.query.get("search_title", "true").lower() == "true",
|
||||
"tags": request.query.get("search_tags", "true").lower() == "true",
|
||||
"lora_name": request.query.get("search_lora_name", "true").lower() == "true",
|
||||
"lora_model": request.query.get("search_lora_model", "true").lower() == "true",
|
||||
}
|
||||
|
||||
filters: Dict[str, list[str]] = {}
|
||||
base_models = request.query.get("base_models")
|
||||
if base_models:
|
||||
filters["base_model"] = base_models.split(",")
|
||||
|
||||
tags = request.query.get("tags")
|
||||
if tags:
|
||||
filters["tags"] = tags.split(",")
|
||||
|
||||
lora_hash = request.query.get("lora_hash")
|
||||
|
||||
result = await recipe_scanner.get_paginated_data(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
sort_by=sort_by,
|
||||
search=search,
|
||||
filters=filters,
|
||||
search_options=search_options,
|
||||
lora_hash=lora_hash,
|
||||
)
|
||||
|
||||
for item in result.get("items", []):
|
||||
file_path = item.get("file_path")
|
||||
if file_path:
|
||||
item["file_url"] = self.format_recipe_file_url(file_path)
|
||||
else:
|
||||
item.setdefault("file_url", "/loras_static/images/no-preview.png")
|
||||
item.setdefault("loras", [])
|
||||
item.setdefault("base_model", "")
|
||||
|
||||
return web.json_response(result)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving recipes: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def get_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
|
||||
|
||||
if not recipe:
|
||||
return web.json_response({"error": "Recipe not found"}, status=404)
|
||||
return web.json_response(recipe)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving recipe details: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
def format_recipe_file_url(self, file_path: str) -> str:
|
||||
try:
|
||||
recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, "/")
|
||||
normalized_path = file_path.replace(os.sep, "/")
|
||||
if normalized_path.startswith(recipes_dir):
|
||||
relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, "/")
|
||||
return f"/loras_static/root1/preview/{relative_path}"
|
||||
|
||||
file_name = os.path.basename(file_path)
|
||||
return f"/loras_static/root1/preview/recipes/{file_name}"
|
||||
except Exception as exc: # pragma: no cover - logging path
|
||||
self._logger.error("Error formatting recipe file URL: %s", exc, exc_info=True)
|
||||
return "/loras_static/images/no-preview.png"
|
||||
|
||||
|
||||
class RecipeQueryHandler:
|
||||
"""Provide read-only insights on recipe data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
format_recipe_file_url: Callable[[str], str],
|
||||
logger: Logger,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._format_recipe_file_url = format_recipe_file_url
|
||||
self._logger = logger
|
||||
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
limit = int(request.query.get("limit", "20"))
|
||||
cache = await recipe_scanner.get_cached_data()
|
||||
|
||||
tag_counts: Dict[str, int] = {}
|
||||
for recipe in getattr(cache, "raw_data", []):
|
||||
for tag in recipe.get("tags", []) or []:
|
||||
tag_counts[tag] = tag_counts.get(tag, 0) + 1
|
||||
|
||||
sorted_tags = [{"tag": tag, "count": count} for tag, count in tag_counts.items()]
|
||||
sorted_tags.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
return web.json_response({"success": True, "tags": sorted_tags[:limit]})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving top tags: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
cache = await recipe_scanner.get_cached_data()
|
||||
|
||||
base_model_counts: Dict[str, int] = {}
|
||||
for recipe in getattr(cache, "raw_data", []):
|
||||
base_model = recipe.get("base_model")
|
||||
if base_model:
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
|
||||
sorted_models = [{"name": model, "count": count} for model, count in base_model_counts.items()]
|
||||
sorted_models.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
return web.json_response({"success": True, "base_models": sorted_models})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving base models: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_recipes_for_lora(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
lora_hash = request.query.get("hash")
|
||||
if not lora_hash:
|
||||
return web.json_response({"success": False, "error": "Lora hash is required"}, status=400)
|
||||
|
||||
matching_recipes = await recipe_scanner.get_recipes_for_lora(lora_hash)
|
||||
return web.json_response({"success": True, "recipes": matching_recipes})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error getting recipes for Lora: %s", exc)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def scan_recipes(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
self._logger.info("Manually triggering recipe cache rebuild")
|
||||
await recipe_scanner.get_cached_data(force_refresh=True)
|
||||
return web.json_response({"success": True, "message": "Recipe cache refreshed successfully"})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error refreshing recipe cache: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def find_duplicates(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
duplicate_groups = await recipe_scanner.find_all_duplicate_recipes()
|
||||
response_data = []
|
||||
|
||||
for fingerprint, recipe_ids in duplicate_groups.items():
|
||||
if len(recipe_ids) <= 1:
|
||||
continue
|
||||
|
||||
recipes = []
|
||||
for recipe_id in recipe_ids:
|
||||
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
|
||||
if recipe:
|
||||
recipes.append(
|
||||
{
|
||||
"id": recipe.get("id"),
|
||||
"title": recipe.get("title"),
|
||||
"file_url": recipe.get("file_url")
|
||||
or self._format_recipe_file_url(recipe.get("file_path", "")),
|
||||
"modified": recipe.get("modified"),
|
||||
"created_date": recipe.get("created_date"),
|
||||
"lora_count": len(recipe.get("loras", [])),
|
||||
}
|
||||
)
|
||||
|
||||
if len(recipes) >= 2:
|
||||
recipes.sort(key=lambda entry: entry.get("modified", 0), reverse=True)
|
||||
response_data.append(
|
||||
{
|
||||
"fingerprint": fingerprint,
|
||||
"count": len(recipes),
|
||||
"recipes": recipes,
|
||||
}
|
||||
)
|
||||
|
||||
response_data.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
return web.json_response({"success": True, "duplicate_groups": response_data})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error finding duplicate recipes: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_recipe_syntax(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
try:
|
||||
syntax_parts = await recipe_scanner.get_recipe_syntax_tokens(recipe_id)
|
||||
except RecipeNotFoundError:
|
||||
return web.json_response({"error": "Recipe not found"}, status=404)
|
||||
|
||||
if not syntax_parts:
|
||||
return web.json_response({"error": "No LoRAs found in this recipe"}, status=400)
|
||||
|
||||
return web.json_response({"success": True, "syntax": " ".join(syntax_parts)})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error generating recipe syntax: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class RecipeManagementHandler:
|
||||
"""Handle create/update/delete style recipe operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
logger: Logger,
|
||||
persistence_service: RecipePersistenceService,
|
||||
analysis_service: RecipeAnalysisService,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
self._persistence_service = persistence_service
|
||||
self._analysis_service = analysis_service
|
||||
|
||||
async def save_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
reader = await request.multipart()
|
||||
payload = await self._parse_save_payload(reader)
|
||||
|
||||
result = await self._persistence_service.save_recipe(
|
||||
recipe_scanner=recipe_scanner,
|
||||
image_bytes=payload["image_bytes"],
|
||||
image_base64=payload["image_base64"],
|
||||
name=payload["name"],
|
||||
tags=payload["tags"],
|
||||
metadata=payload["metadata"],
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error saving recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def delete_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
result = await self._persistence_service.delete_recipe(
|
||||
recipe_scanner=recipe_scanner, recipe_id=recipe_id
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error deleting recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def update_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
data = await request.json()
|
||||
result = await self._persistence_service.update_recipe(
|
||||
recipe_scanner=recipe_scanner, recipe_id=recipe_id, updates=data
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error updating recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def reconnect_lora(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
data = await request.json()
|
||||
for field in ("recipe_id", "lora_index", "target_name"):
|
||||
if field not in data:
|
||||
raise RecipeValidationError(f"Missing required field: {field}")
|
||||
|
||||
result = await self._persistence_service.reconnect_lora(
|
||||
recipe_scanner=recipe_scanner,
|
||||
recipe_id=data["recipe_id"],
|
||||
lora_index=int(data["lora_index"]),
|
||||
target_name=data["target_name"],
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error reconnecting LoRA: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def bulk_delete(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
data = await request.json()
|
||||
recipe_ids = data.get("recipe_ids", [])
|
||||
result = await self._persistence_service.bulk_delete(
|
||||
recipe_scanner=recipe_scanner, recipe_ids=recipe_ids
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error performing bulk delete: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def save_recipe_from_widget(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
analysis = await self._analysis_service.analyze_widget_metadata(
|
||||
recipe_scanner=recipe_scanner
|
||||
)
|
||||
metadata = analysis.payload.get("metadata")
|
||||
image_bytes = analysis.payload.get("image_bytes")
|
||||
if not metadata or image_bytes is None:
|
||||
raise RecipeValidationError("Unable to extract metadata from widget")
|
||||
|
||||
result = await self._persistence_service.save_recipe_from_widget(
|
||||
recipe_scanner=recipe_scanner,
|
||||
metadata=metadata,
|
||||
image_bytes=image_bytes,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error saving recipe from widget: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def _parse_save_payload(self, reader) -> dict[str, Any]:
|
||||
image_bytes: Optional[bytes] = None
|
||||
image_base64: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tags: list[str] = []
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
if field.name == "image":
|
||||
image_chunks = bytearray()
|
||||
while True:
|
||||
chunk = await field.read_chunk()
|
||||
if not chunk:
|
||||
break
|
||||
image_chunks.extend(chunk)
|
||||
image_bytes = bytes(image_chunks)
|
||||
elif field.name == "image_base64":
|
||||
image_base64 = await field.text()
|
||||
elif field.name == "name":
|
||||
name = await field.text()
|
||||
elif field.name == "tags":
|
||||
tags_text = await field.text()
|
||||
try:
|
||||
parsed_tags = json.loads(tags_text)
|
||||
tags = parsed_tags if isinstance(parsed_tags, list) else []
|
||||
except Exception:
|
||||
tags = []
|
||||
elif field.name == "metadata":
|
||||
metadata_text = await field.text()
|
||||
try:
|
||||
metadata = json.loads(metadata_text)
|
||||
except Exception:
|
||||
metadata = {}
|
||||
|
||||
return {
|
||||
"image_bytes": image_bytes,
|
||||
"image_base64": image_base64,
|
||||
"name": name,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
|
||||
class RecipeAnalysisHandler:
|
||||
"""Analyze images to extract recipe metadata."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
civitai_client_getter: CivitaiClientGetter,
|
||||
logger: Logger,
|
||||
analysis_service: RecipeAnalysisService,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._civitai_client_getter = civitai_client_getter
|
||||
self._logger = logger
|
||||
self._analysis_service = analysis_service
|
||||
|
||||
async def analyze_uploaded_image(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
civitai_client = self._civitai_client_getter()
|
||||
if recipe_scanner is None or civitai_client is None:
|
||||
raise RuntimeError("Required services unavailable")
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
if "multipart/form-data" in content_type:
|
||||
reader = await request.multipart()
|
||||
field = await reader.next()
|
||||
if field is None or field.name != "image":
|
||||
raise RecipeValidationError("No image field found")
|
||||
image_chunks = bytearray()
|
||||
while True:
|
||||
chunk = await field.read_chunk()
|
||||
if not chunk:
|
||||
break
|
||||
image_chunks.extend(chunk)
|
||||
result = await self._analysis_service.analyze_uploaded_image(
|
||||
image_bytes=bytes(image_chunks),
|
||||
recipe_scanner=recipe_scanner,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
|
||||
if "application/json" in content_type:
|
||||
data = await request.json()
|
||||
result = await self._analysis_service.analyze_remote_image(
|
||||
url=data.get("url"),
|
||||
recipe_scanner=recipe_scanner,
|
||||
civitai_client=civitai_client,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
|
||||
raise RecipeValidationError("Unsupported content type")
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=400)
|
||||
except RecipeDownloadError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error analyzing recipe image: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=500)
|
||||
|
||||
async def analyze_local_image(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
data = await request.json()
|
||||
result = await self._analysis_service.analyze_local_image(
|
||||
file_path=data.get("path"),
|
||||
recipe_scanner=recipe_scanner,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error analyzing local image: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=500)
|
||||
|
||||
|
||||
class RecipeSharingHandler:
|
||||
"""Serve endpoints related to recipe sharing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
logger: Logger,
|
||||
sharing_service: RecipeSharingService,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
self._sharing_service = sharing_service
|
||||
|
||||
async def share_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
result = await self._sharing_service.share_recipe(
|
||||
recipe_scanner=recipe_scanner, recipe_id=recipe_id
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error sharing recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def download_shared_recipe(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
download_info = await self._sharing_service.prepare_download(
|
||||
recipe_scanner=recipe_scanner, recipe_id=recipe_id
|
||||
)
|
||||
return web.FileResponse(
|
||||
download_info.file_path,
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{download_info.download_filename}"'
|
||||
},
|
||||
)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error downloading shared recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
@@ -1,97 +1,245 @@
|
||||
import os
|
||||
from aiohttp import web
|
||||
import jinja2
|
||||
from typing import Dict, List
|
||||
import asyncio
|
||||
import logging
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings # Add this import
|
||||
from aiohttp import web
|
||||
from typing import Dict
|
||||
from server import PromptServer # type: ignore
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from .model_route_registrar import ModelRouteRegistrar
|
||||
from ..services.lora_service import LoraService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.utils import get_lora_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
|
||||
|
||||
class LoraRoutes:
|
||||
"""Route handlers for LoRA management endpoints"""
|
||||
class LoraRoutes(BaseModelRoutes):
|
||||
"""LoRA-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
self.scanner = LoraScanner()
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
"""Initialize LoRA routes with LoRA service"""
|
||||
super().__init__()
|
||||
self.template_name = "loras.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.service = LoraService(lora_scanner)
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Setup LoRA routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
def format_lora_data(self, lora: Dict) -> Dict:
|
||||
"""Format LoRA data for template rendering"""
|
||||
return {
|
||||
"model_name": lora["model_name"],
|
||||
"file_name": lora["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora["preview_url"]),
|
||||
"base_model": lora["base_model"],
|
||||
"folder": lora["folder"],
|
||||
"sha256": lora["sha256"],
|
||||
"file_path": lora["file_path"].replace(os.sep, "/"),
|
||||
"size": lora["size"],
|
||||
"tags": lora["tags"],
|
||||
"modelDescription": lora["modelDescription"],
|
||||
"usage_tips": lora["usage_tips"],
|
||||
"notes": lora["notes"],
|
||||
"modified": lora["modified"],
|
||||
"from_civitai": lora.get("from_civitai", True),
|
||||
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
|
||||
}
|
||||
# Setup common routes with 'loras' prefix (includes page route)
|
||||
super().setup_routes(app, 'loras')
|
||||
|
||||
def _filter_civitai_data(self, data: Dict) -> Dict:
|
||||
"""Filter relevant fields from CivitAI data"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
fields = [
|
||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||
"publishedAt", "trainedWords", "baseModel", "description",
|
||||
"model", "images"
|
||||
]
|
||||
return {k: data[k] for k in fields if k in data}
|
||||
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
|
||||
"""Setup LoRA-specific routes"""
|
||||
# LoRA-specific query routes
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/letter-counts', prefix, self.get_letter_counts)
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/get-trigger-words', prefix, self.get_lora_trigger_words)
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/usage-tips-by-path', prefix, self.get_lora_usage_tips_by_path)
|
||||
|
||||
async def handle_loras_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras request"""
|
||||
# ComfyUI integration
|
||||
registrar.add_prefixed_route('POST', '/api/lm/{prefix}/get_trigger_words', prefix, self.get_trigger_words)
|
||||
|
||||
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||
"""Parse LoRA-specific parameters"""
|
||||
params = {}
|
||||
|
||||
# LoRA-specific parameters
|
||||
if 'first_letter' in request.query:
|
||||
params['first_letter'] = request.query.get('first_letter')
|
||||
|
||||
# Handle fuzzy search parameter name variation
|
||||
if request.query.get('fuzzy') == 'true':
|
||||
params['fuzzy_search'] = True
|
||||
|
||||
# Handle additional filter parameters for LoRAs
|
||||
if 'lora_hash' in request.query:
|
||||
if not params.get('hash_filters'):
|
||||
params['hash_filters'] = {}
|
||||
params['hash_filters']['single_hash'] = request.query['lora_hash'].lower()
|
||||
elif 'lora_hashes' in request.query:
|
||||
if not params.get('hash_filters'):
|
||||
params['hash_filters'] = {}
|
||||
params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')]
|
||||
|
||||
return params
|
||||
|
||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||
"""Validate CivitAI model type for LoRA"""
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
return model_type.lower() in VALID_LORA_TYPES
|
||||
|
||||
def _get_expected_model_types(self) -> str:
|
||||
"""Get expected model types string for error messages"""
|
||||
return "LORA, LoCon, or DORA"
|
||||
|
||||
# LoRA-specific route handlers
|
||||
async def get_letter_counts(self, request: web.Request) -> web.Response:
|
||||
"""Get count of LoRAs for each letter of the alphabet"""
|
||||
try:
|
||||
# 不等待缓存数据,直接检查缓存状态
|
||||
is_initializing = (
|
||||
self.scanner._cache is None and
|
||||
(self.scanner._initialization_task is not None and
|
||||
not self.scanner._initialization_task.done())
|
||||
)
|
||||
|
||||
if is_initializing:
|
||||
# 如果正在初始化,返回一个只包含加载提示的页面
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=[], # 空文件夹列表
|
||||
is_initializing=True, # 新增标志
|
||||
settings=settings # Pass settings to template
|
||||
)
|
||||
else:
|
||||
# 正常流程
|
||||
cache = await self.scanner.get_cached_data()
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=cache.folders,
|
||||
is_initializing=False,
|
||||
settings=settings # Pass settings to template
|
||||
)
|
||||
letter_counts = await self.service.get_letter_counts()
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'letter_counts': letter_counts
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting letter counts: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_notes(self, request: web.Request) -> web.Response:
|
||||
"""Get notes for a specific LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
notes = await self.service.get_lora_notes(lora_name)
|
||||
if notes is not None:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'notes': notes
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'LoRA not found in cache'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora notes: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for a specific LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
trigger_words = await self.service.get_lora_trigger_words(lora_name)
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'trigger_words': trigger_words
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling loras request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading loras page",
|
||||
status=500
|
||||
)
|
||||
logger.error(f"Error getting lora trigger words: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_usage_tips_by_path(self, request: web.Request) -> web.Response:
|
||||
"""Get usage tips for a LoRA by its relative path"""
|
||||
try:
|
||||
relative_path = request.query.get('relative_path')
|
||||
if not relative_path:
|
||||
return web.Response(text='Relative path is required', status=400)
|
||||
|
||||
usage_tips = await self.service.get_lora_usage_tips_by_relative_path(relative_path)
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'usage_tips': usage_tips or ''
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora usage tips by path: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the static preview URL for a LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
preview_url = await self.service.get_lora_preview_url(lora_name)
|
||||
if preview_url:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'preview_url': preview_url
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No preview URL found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the Civitai URL for a LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
result = await self.service.get_lora_civitai_url(lora_name)
|
||||
if result['civitai_url']:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
**result
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No Civitai data found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for specified LoRA models"""
|
||||
try:
|
||||
json_data = await request.json()
|
||||
lora_names = json_data.get("lora_names", [])
|
||||
node_ids = json_data.get("node_ids", [])
|
||||
|
||||
all_trigger_words = []
|
||||
for lora_name in lora_names:
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Format the trigger words
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Send update to all connected trigger word toggle nodes
|
||||
for node_id in node_ids:
|
||||
PromptServer.instance.send_sync("trigger_word_update", {
|
||||
"id": node_id,
|
||||
"message": trigger_words_text
|
||||
})
|
||||
|
||||
return web.json_response({"success": True})
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Register routes with the application"""
|
||||
app.router.add_get('/loras', self.handle_loras_page)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting trigger words: {e}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
1058
py/routes/misc_routes.py
Normal file
1058
py/routes/misc_routes.py
Normal file
File diff suppressed because it is too large
Load Diff
99
py/routes/model_route_registrar.py
Normal file
99
py/routes/model_route_registrar.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Route registrar for model endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RouteDefinition:
|
||||
"""Declarative definition for a HTTP route."""
|
||||
|
||||
method: str
|
||||
path_template: str
|
||||
handler_name: str
|
||||
|
||||
def build_path(self, prefix: str) -> str:
|
||||
return self.path_template.replace("{prefix}", prefix)
|
||||
|
||||
|
||||
COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/list", "get_models"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/delete", "delete_model"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/exclude", "exclude_model"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/fetch-civitai", "fetch_civitai"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/fetch-all-civitai", "fetch_all_civitai"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/relink-civitai", "relink_civitai"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/replace-preview", "replace_preview"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/save-metadata", "save_metadata"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/add-tags", "add_tags"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/rename", "rename_model"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/bulk-delete", "bulk_delete_models"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/verify-duplicates", "verify_duplicates"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/move_model", "move_model"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/move_models_bulk", "move_models_bulk"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize", "auto_organize_models"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/auto-organize", "auto_organize_models"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/scan", "scan_models"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/folder-tree", "get_folder_tree"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/unified-folder-tree", "get_unified_folder_tree"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/find-duplicates", "find_duplicate_models"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/find-filename-conflicts", "find_filename_conflicts"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/get-notes", "get_model_notes"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/preview-url", "get_model_preview_url"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai-url", "get_model_civitai_url"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/metadata", "get_model_metadata"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/model-description", "get_model_description"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/relative-paths", "get_relative_paths"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
|
||||
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
|
||||
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
|
||||
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),
|
||||
RouteDefinition("GET", "/api/lm/download-progress/{download_id}", "get_download_progress"),
|
||||
RouteDefinition("GET", "/{prefix}", "handle_models_page"),
|
||||
)
|
||||
|
||||
|
||||
class ModelRouteRegistrar:
|
||||
"""Bind declarative definitions to an aiohttp router."""
|
||||
|
||||
_METHOD_MAP = {
|
||||
"GET": "add_get",
|
||||
"POST": "add_post",
|
||||
"PUT": "add_put",
|
||||
"DELETE": "add_delete",
|
||||
}
|
||||
|
||||
def __init__(self, app: web.Application) -> None:
|
||||
self._app = app
|
||||
|
||||
def register_common_routes(
|
||||
self,
|
||||
prefix: str,
|
||||
handler_lookup: Mapping[str, Callable[[web.Request], object]],
|
||||
*,
|
||||
definitions: Iterable[RouteDefinition] = COMMON_ROUTE_DEFINITIONS,
|
||||
) -> None:
|
||||
for definition in definitions:
|
||||
self._bind_route(definition.method, definition.build_path(prefix), handler_lookup[definition.handler_name])
|
||||
|
||||
def add_route(self, method: str, path: str, handler: Callable) -> None:
|
||||
self._bind_route(method, path, handler)
|
||||
|
||||
def add_prefixed_route(self, method: str, path_template: str, prefix: str, handler: Callable) -> None:
|
||||
self._bind_route(method, path_template.replace("{prefix}", prefix), handler)
|
||||
|
||||
def _bind_route(self, method: str, path: str, handler: Callable) -> None:
|
||||
add_method_name = self._METHOD_MAP[method.upper()]
|
||||
add_method = getattr(self._app.router, add_method_name)
|
||||
add_method(path, handler)
|
||||
|
||||
64
py/routes/recipe_route_registrar.py
Normal file
64
py/routes/recipe_route_registrar.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Route registrar for recipe endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RouteDefinition:
|
||||
"""Declarative definition for a recipe HTTP route."""
|
||||
|
||||
method: str
|
||||
path: str
|
||||
handler_name: str
|
||||
|
||||
|
||||
ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/loras/recipes", "render_page"),
|
||||
RouteDefinition("GET", "/api/lm/recipes", "list_recipes"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"),
|
||||
RouteDefinition("DELETE", "/api/lm/recipe/{recipe_id}", "delete_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/top-tags", "get_top_tags"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/base-models", "get_base_models"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share", "share_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/syntax", "get_recipe_syntax"),
|
||||
RouteDefinition("PUT", "/api/lm/recipe/{recipe_id}/update", "update_recipe"),
|
||||
RouteDefinition("POST", "/api/lm/recipe/lora/reconnect", "reconnect_lora"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/find-duplicates", "find_duplicates"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/bulk-delete", "bulk_delete"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"),
|
||||
)
|
||||
|
||||
|
||||
class RecipeRouteRegistrar:
|
||||
"""Bind declarative recipe definitions to an aiohttp router."""
|
||||
|
||||
_METHOD_MAP = {
|
||||
"GET": "add_get",
|
||||
"POST": "add_post",
|
||||
"PUT": "add_put",
|
||||
"DELETE": "add_delete",
|
||||
}
|
||||
|
||||
def __init__(self, app: web.Application) -> None:
|
||||
self._app = app
|
||||
|
||||
def register_routes(self, handler_lookup: Mapping[str, Callable[[web.Request], object]]) -> None:
|
||||
for definition in ROUTE_DEFINITIONS:
|
||||
handler = handler_lookup[definition.handler_name]
|
||||
self._bind_route(definition.method, definition.path, handler)
|
||||
|
||||
def _bind_route(self, method: str, path: str, handler: Callable) -> None:
|
||||
add_method_name = self._METHOD_MAP[method.upper()]
|
||||
add_method = getattr(self._app.router, add_method_name)
|
||||
add_method(path, handler)
|
||||
|
||||
21
py/routes/recipe_routes.py
Normal file
21
py/routes/recipe_routes.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Concrete recipe route configuration."""
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from .base_recipe_routes import BaseRecipeRoutes
|
||||
from .recipe_route_registrar import RecipeRouteRegistrar
|
||||
|
||||
|
||||
class RecipeRoutes(BaseRecipeRoutes):
|
||||
"""API route handlers for Recipe management."""
|
||||
|
||||
template_name = "recipes.html"
|
||||
|
||||
@classmethod
|
||||
def setup_routes(cls, app: web.Application):
|
||||
"""Register API routes using the declarative registrar."""
|
||||
|
||||
routes = cls()
|
||||
registrar = RecipeRouteRegistrar(app)
|
||||
registrar.register_routes(routes.to_route_mapping())
|
||||
routes.register_startup_hooks(app)
|
||||
519
py/routes/stats_routes.py
Normal file
519
py/routes/stats_routes.py
Normal file
@@ -0,0 +1,519 @@
|
||||
import os
|
||||
import json
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict, Counter
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.server_i18n import server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.usage_stats import UsageStats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class StatsRoutes:
|
||||
"""Route handlers for Statistics page and API endpoints"""
|
||||
|
||||
def __init__(self):
|
||||
self.lora_scanner = None
|
||||
self.checkpoint_scanner = None
|
||||
self.embedding_scanner = None
|
||||
self.usage_stats = None
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
|
||||
async def init_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
# Only initialize usage stats if we have valid paths configured
|
||||
try:
|
||||
self.usage_stats = UsageStats()
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Could not initialize usage statistics: {e}")
|
||||
self.usage_stats = None
|
||||
|
||||
async def handle_stats_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /statistics request"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Check if scanners are initializing
|
||||
lora_initializing = (
|
||||
self.lora_scanner._cache is None or
|
||||
(hasattr(self.lora_scanner, 'is_initializing') and self.lora_scanner.is_initializing())
|
||||
)
|
||||
|
||||
checkpoint_initializing = (
|
||||
self.checkpoint_scanner._cache is None or
|
||||
(hasattr(self.checkpoint_scanner, '_is_initializing') and self.checkpoint_scanner._is_initializing)
|
||||
)
|
||||
|
||||
embedding_initializing = (
|
||||
self.embedding_scanner._cache is None or
|
||||
(hasattr(self.embedding_scanner, 'is_initializing') and self.embedding_scanner.is_initializing())
|
||||
)
|
||||
|
||||
is_initializing = lora_initializing or checkpoint_initializing or embedding_initializing
|
||||
|
||||
# 获取用户语言设置
|
||||
user_language = settings.get('language', 'en')
|
||||
|
||||
# 设置服务端i18n语言
|
||||
server_i18n.set_locale(user_language)
|
||||
|
||||
# 为模板环境添加i18n过滤器
|
||||
if not hasattr(self.template_env, '_i18n_filter_added'):
|
||||
self.template_env.filters['t'] = server_i18n.create_template_filter()
|
||||
self.template_env._i18n_filter_added = True
|
||||
|
||||
template = self.template_env.get_template('statistics.html')
|
||||
rendered = template.render(
|
||||
is_initializing=is_initializing,
|
||||
settings=settings,
|
||||
request=request,
|
||||
t=server_i18n.get_translation,
|
||||
)
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling statistics request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading statistics page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def get_collection_overview(self, request: web.Request) -> web.Response:
|
||||
"""Get collection overview statistics"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get LoRA statistics
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
lora_count = len(lora_cache.raw_data)
|
||||
lora_size = sum(lora.get('size', 0) for lora in lora_cache.raw_data)
|
||||
|
||||
# Get Checkpoint statistics
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
checkpoint_count = len(checkpoint_cache.raw_data)
|
||||
checkpoint_size = sum(cp.get('size', 0) for cp in checkpoint_cache.raw_data)
|
||||
|
||||
# Get Embedding statistics
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
embedding_count = len(embedding_cache.raw_data)
|
||||
embedding_size = sum(emb.get('size', 0) for emb in embedding_cache.raw_data)
|
||||
|
||||
# Get usage statistics
|
||||
usage_data = await self.usage_stats.get_stats()
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'total_models': lora_count + checkpoint_count + embedding_count,
|
||||
'lora_count': lora_count,
|
||||
'checkpoint_count': checkpoint_count,
|
||||
'embedding_count': embedding_count,
|
||||
'total_size': lora_size + checkpoint_size + embedding_size,
|
||||
'lora_size': lora_size,
|
||||
'checkpoint_size': checkpoint_size,
|
||||
'embedding_size': embedding_size,
|
||||
'total_generations': usage_data.get('total_executions', 0),
|
||||
'unused_loras': self._count_unused_models(lora_cache.raw_data, usage_data.get('loras', {})),
|
||||
'unused_checkpoints': self._count_unused_models(checkpoint_cache.raw_data, usage_data.get('checkpoints', {})),
|
||||
'unused_embeddings': self._count_unused_models(embedding_cache.raw_data, usage_data.get('embeddings', {}))
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting collection overview: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_usage_analytics(self, request: web.Request) -> web.Response:
|
||||
"""Get usage analytics data"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get usage statistics
|
||||
usage_data = await self.usage_stats.get_stats()
|
||||
|
||||
# Get model data for enrichment
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
# Create hash to model mapping
|
||||
lora_map = {lora['sha256']: lora for lora in lora_cache.raw_data}
|
||||
checkpoint_map = {cp['sha256']: cp for cp in checkpoint_cache.raw_data}
|
||||
embedding_map = {emb['sha256']: emb for emb in embedding_cache.raw_data}
|
||||
|
||||
# Prepare top used models
|
||||
top_loras = self._get_top_used_models(usage_data.get('loras', {}), lora_map, 10)
|
||||
top_checkpoints = self._get_top_used_models(usage_data.get('checkpoints', {}), checkpoint_map, 10)
|
||||
top_embeddings = self._get_top_used_models(usage_data.get('embeddings', {}), embedding_map, 10)
|
||||
|
||||
# Prepare usage timeline (last 30 days)
|
||||
timeline = self._get_usage_timeline(usage_data, 30)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'top_loras': top_loras,
|
||||
'top_checkpoints': top_checkpoints,
|
||||
'top_embeddings': top_embeddings,
|
||||
'usage_timeline': timeline,
|
||||
'total_executions': usage_data.get('total_executions', 0)
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage analytics: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_base_model_distribution(self, request: web.Request) -> web.Response:
|
||||
"""Get base model distribution statistics"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get model data
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
# Count by base model
|
||||
lora_base_models = Counter(lora.get('base_model', 'Unknown') for lora in lora_cache.raw_data)
|
||||
checkpoint_base_models = Counter(cp.get('base_model', 'Unknown') for cp in checkpoint_cache.raw_data)
|
||||
embedding_base_models = Counter(emb.get('base_model', 'Unknown') for emb in embedding_cache.raw_data)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'loras': dict(lora_base_models),
|
||||
'checkpoints': dict(checkpoint_base_models),
|
||||
'embeddings': dict(embedding_base_models)
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting base model distribution: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_tag_analytics(self, request: web.Request) -> web.Response:
|
||||
"""Get tag usage analytics"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get model data
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
# Count tag frequencies
|
||||
all_tags = []
|
||||
for lora in lora_cache.raw_data:
|
||||
all_tags.extend(lora.get('tags', []))
|
||||
for cp in checkpoint_cache.raw_data:
|
||||
all_tags.extend(cp.get('tags', []))
|
||||
for emb in embedding_cache.raw_data:
|
||||
all_tags.extend(emb.get('tags', []))
|
||||
|
||||
tag_counts = Counter(all_tags)
|
||||
|
||||
# Get top 50 tags
|
||||
top_tags = [{'tag': tag, 'count': count} for tag, count in tag_counts.most_common(50)]
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'top_tags': top_tags,
|
||||
'total_unique_tags': len(tag_counts)
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting tag analytics: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_storage_analytics(self, request: web.Request) -> web.Response:
|
||||
"""Get storage usage analytics"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get usage statistics
|
||||
usage_data = await self.usage_stats.get_stats()
|
||||
|
||||
# Get model data
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
# Create models with usage data
|
||||
lora_storage = []
|
||||
for lora in lora_cache.raw_data:
|
||||
usage_count = 0
|
||||
if lora['sha256'] in usage_data.get('loras', {}):
|
||||
usage_count = usage_data['loras'][lora['sha256']].get('total', 0)
|
||||
|
||||
lora_storage.append({
|
||||
'name': lora['model_name'],
|
||||
'size': lora.get('size', 0),
|
||||
'usage_count': usage_count,
|
||||
'folder': lora.get('folder', ''),
|
||||
'base_model': lora.get('base_model', 'Unknown')
|
||||
})
|
||||
|
||||
checkpoint_storage = []
|
||||
for cp in checkpoint_cache.raw_data:
|
||||
usage_count = 0
|
||||
if cp['sha256'] in usage_data.get('checkpoints', {}):
|
||||
usage_count = usage_data['checkpoints'][cp['sha256']].get('total', 0)
|
||||
|
||||
checkpoint_storage.append({
|
||||
'name': cp['model_name'],
|
||||
'size': cp.get('size', 0),
|
||||
'usage_count': usage_count,
|
||||
'folder': cp.get('folder', ''),
|
||||
'base_model': cp.get('base_model', 'Unknown')
|
||||
})
|
||||
|
||||
embedding_storage = []
|
||||
for emb in embedding_cache.raw_data:
|
||||
usage_count = 0
|
||||
if emb['sha256'] in usage_data.get('embeddings', {}):
|
||||
usage_count = usage_data['embeddings'][emb['sha256']].get('total', 0)
|
||||
|
||||
embedding_storage.append({
|
||||
'name': emb['model_name'],
|
||||
'size': emb.get('size', 0),
|
||||
'usage_count': usage_count,
|
||||
'folder': emb.get('folder', ''),
|
||||
'base_model': emb.get('base_model', 'Unknown')
|
||||
})
|
||||
|
||||
# Sort by size
|
||||
lora_storage.sort(key=lambda x: x['size'], reverse=True)
|
||||
checkpoint_storage.sort(key=lambda x: x['size'], reverse=True)
|
||||
embedding_storage.sort(key=lambda x: x['size'], reverse=True)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'loras': lora_storage[:20], # Top 20 by size
|
||||
'checkpoints': checkpoint_storage[:20],
|
||||
'embeddings': embedding_storage[:20]
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting storage analytics: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_insights(self, request: web.Request) -> web.Response:
|
||||
"""Get smart insights about the collection"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get usage statistics
|
||||
usage_data = await self.usage_stats.get_stats()
|
||||
|
||||
# Get model data
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
insights = []
|
||||
|
||||
# Calculate unused models
|
||||
unused_loras = self._count_unused_models(lora_cache.raw_data, usage_data.get('loras', {}))
|
||||
unused_checkpoints = self._count_unused_models(checkpoint_cache.raw_data, usage_data.get('checkpoints', {}))
|
||||
unused_embeddings = self._count_unused_models(embedding_cache.raw_data, usage_data.get('embeddings', {}))
|
||||
|
||||
total_loras = len(lora_cache.raw_data)
|
||||
total_checkpoints = len(checkpoint_cache.raw_data)
|
||||
total_embeddings = len(embedding_cache.raw_data)
|
||||
|
||||
if total_loras > 0:
|
||||
unused_lora_percent = (unused_loras / total_loras) * 100
|
||||
if unused_lora_percent > 50:
|
||||
insights.append({
|
||||
'type': 'warning',
|
||||
'title': 'High Number of Unused LoRAs',
|
||||
'description': f'{unused_lora_percent:.1f}% of your LoRAs ({unused_loras}/{total_loras}) have never been used.',
|
||||
'suggestion': 'Consider organizing or archiving unused models to free up storage space.'
|
||||
})
|
||||
|
||||
if total_checkpoints > 0:
|
||||
unused_checkpoint_percent = (unused_checkpoints / total_checkpoints) * 100
|
||||
if unused_checkpoint_percent > 30:
|
||||
insights.append({
|
||||
'type': 'warning',
|
||||
'title': 'Unused Checkpoints Detected',
|
||||
'description': f'{unused_checkpoint_percent:.1f}% of your checkpoints ({unused_checkpoints}/{total_checkpoints}) have never been used.',
|
||||
'suggestion': 'Review and consider removing checkpoints you no longer need.'
|
||||
})
|
||||
|
||||
if total_embeddings > 0:
|
||||
unused_embedding_percent = (unused_embeddings / total_embeddings) * 100
|
||||
if unused_embedding_percent > 50:
|
||||
insights.append({
|
||||
'type': 'warning',
|
||||
'title': 'High Number of Unused Embeddings',
|
||||
'description': f'{unused_embedding_percent:.1f}% of your embeddings ({unused_embeddings}/{total_embeddings}) have never been used.',
|
||||
'suggestion': 'Consider organizing or archiving unused embeddings to optimize your collection.'
|
||||
})
|
||||
|
||||
# Storage insights
|
||||
total_size = sum(lora.get('size', 0) for lora in lora_cache.raw_data) + \
|
||||
sum(cp.get('size', 0) for cp in checkpoint_cache.raw_data) + \
|
||||
sum(emb.get('size', 0) for emb in embedding_cache.raw_data)
|
||||
|
||||
if total_size > 100 * 1024 * 1024 * 1024: # 100GB
|
||||
insights.append({
|
||||
'type': 'info',
|
||||
'title': 'Large Collection Detected',
|
||||
'description': f'Your model collection is using {self._format_size(total_size)} of storage.',
|
||||
'suggestion': 'Consider using external storage or cloud solutions for better organization.'
|
||||
})
|
||||
|
||||
# Recent activity insight
|
||||
if usage_data.get('total_executions', 0) > 100:
|
||||
insights.append({
|
||||
'type': 'success',
|
||||
'title': 'Active User',
|
||||
'description': f'You\'ve completed {usage_data["total_executions"]} generations so far!',
|
||||
'suggestion': 'Keep exploring and creating amazing content with your models.'
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'insights': insights
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting insights: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
def _count_unused_models(self, models: List[Dict], usage_data: Dict) -> int:
|
||||
"""Count models that have never been used"""
|
||||
used_hashes = set(usage_data.keys())
|
||||
unused_count = 0
|
||||
|
||||
for model in models:
|
||||
if model.get('sha256') not in used_hashes:
|
||||
unused_count += 1
|
||||
|
||||
return unused_count
|
||||
|
||||
def _get_top_used_models(self, usage_data: Dict, model_map: Dict, limit: int) -> List[Dict]:
|
||||
"""Get top used models with their metadata"""
|
||||
sorted_usage = sorted(usage_data.items(), key=lambda x: x[1].get('total', 0), reverse=True)
|
||||
|
||||
top_models = []
|
||||
for sha256, usage_info in sorted_usage[:limit]:
|
||||
if sha256 in model_map:
|
||||
model = model_map[sha256]
|
||||
top_models.append({
|
||||
'name': model['model_name'],
|
||||
'usage_count': usage_info.get('total', 0),
|
||||
'base_model': model.get('base_model', 'Unknown'),
|
||||
'preview_url': config.get_preview_static_url(model.get('preview_url', '')),
|
||||
'folder': model.get('folder', '')
|
||||
})
|
||||
|
||||
return top_models
|
||||
|
||||
def _get_usage_timeline(self, usage_data: Dict, days: int) -> List[Dict]:
|
||||
"""Get usage timeline for the past N days"""
|
||||
timeline = []
|
||||
today = datetime.now()
|
||||
|
||||
for i in range(days):
|
||||
date = today - timedelta(days=i)
|
||||
date_str = date.strftime('%Y-%m-%d')
|
||||
|
||||
lora_usage = 0
|
||||
checkpoint_usage = 0
|
||||
embedding_usage = 0
|
||||
|
||||
# Count usage for this date
|
||||
for model_usage in usage_data.get('loras', {}).values():
|
||||
if isinstance(model_usage, dict) and 'history' in model_usage:
|
||||
lora_usage += model_usage['history'].get(date_str, 0)
|
||||
|
||||
for model_usage in usage_data.get('checkpoints', {}).values():
|
||||
if isinstance(model_usage, dict) and 'history' in model_usage:
|
||||
checkpoint_usage += model_usage['history'].get(date_str, 0)
|
||||
|
||||
for model_usage in usage_data.get('embeddings', {}).values():
|
||||
if isinstance(model_usage, dict) and 'history' in model_usage:
|
||||
embedding_usage += model_usage['history'].get(date_str, 0)
|
||||
|
||||
timeline.append({
|
||||
'date': date_str,
|
||||
'lora_usage': lora_usage,
|
||||
'checkpoint_usage': checkpoint_usage,
|
||||
'embedding_usage': embedding_usage,
|
||||
'total_usage': lora_usage + checkpoint_usage + embedding_usage
|
||||
})
|
||||
|
||||
return list(reversed(timeline)) # Oldest to newest
|
||||
|
||||
def _format_size(self, size_bytes: int) -> str:
|
||||
"""Format file size in human readable format"""
|
||||
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
|
||||
if size_bytes < 1024.0:
|
||||
return f"{size_bytes:.1f} {unit}"
|
||||
size_bytes /= 1024.0
|
||||
return f"{size_bytes:.1f} PB"
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Register routes with the application"""
|
||||
# Add an app startup handler to initialize services
|
||||
app.on_startup.append(self._on_startup)
|
||||
|
||||
# Register page route
|
||||
app.router.add_get('/statistics', self.handle_stats_page)
|
||||
|
||||
# Register API routes
|
||||
app.router.add_get('/api/lm/stats/collection-overview', self.get_collection_overview)
|
||||
app.router.add_get('/api/lm/stats/usage-analytics', self.get_usage_analytics)
|
||||
app.router.add_get('/api/lm/stats/base-model-distribution', self.get_base_model_distribution)
|
||||
app.router.add_get('/api/lm/stats/tag-analytics', self.get_tag_analytics)
|
||||
app.router.add_get('/api/lm/stats/storage-analytics', self.get_storage_analytics)
|
||||
app.router.add_get('/api/lm/stats/insights', self.get_insights)
|
||||
|
||||
async def _on_startup(self, app):
|
||||
"""Initialize services when the app starts"""
|
||||
await self.init_services()
|
||||
@@ -1,9 +1,13 @@
|
||||
import os
|
||||
import aiohttp
|
||||
import logging
|
||||
import toml
|
||||
import git
|
||||
import zipfile
|
||||
import shutil
|
||||
import tempfile
|
||||
from aiohttp import web
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, List
|
||||
from ..services.downloader import get_downloader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -13,7 +17,9 @@ class UpdateRoutes:
|
||||
@staticmethod
|
||||
def setup_routes(app):
|
||||
"""Register update check routes"""
|
||||
app.router.add_get('/loras/api/check-updates', UpdateRoutes.check_updates)
|
||||
app.router.add_get('/api/lm/check-updates', UpdateRoutes.check_updates)
|
||||
app.router.add_get('/api/lm/version-info', UpdateRoutes.get_version_info)
|
||||
app.router.add_post('/api/lm/perform-update', UpdateRoutes.perform_update)
|
||||
|
||||
@staticmethod
|
||||
async def check_updates(request):
|
||||
@@ -22,28 +28,39 @@ class UpdateRoutes:
|
||||
Returns update status and version information
|
||||
"""
|
||||
try:
|
||||
nightly = request.query.get('nightly', 'false').lower() == 'true'
|
||||
|
||||
# Read local version from pyproject.toml
|
||||
local_version = UpdateRoutes._get_local_version()
|
||||
logger.info(f"Local version: {local_version}")
|
||||
|
||||
# Get git info (commit hash, branch)
|
||||
git_info = UpdateRoutes._get_git_info()
|
||||
|
||||
# Fetch remote version from GitHub
|
||||
remote_version, changelog = await UpdateRoutes._get_remote_version()
|
||||
logger.info(f"Remote version: {remote_version}")
|
||||
if nightly:
|
||||
remote_version, changelog = await UpdateRoutes._get_nightly_version()
|
||||
else:
|
||||
remote_version, changelog = await UpdateRoutes._get_remote_version()
|
||||
|
||||
# Compare versions
|
||||
update_available = UpdateRoutes._compare_versions(
|
||||
local_version.replace('v', ''),
|
||||
remote_version.replace('v', '')
|
||||
)
|
||||
|
||||
logger.info(f"Update available: {update_available}")
|
||||
if nightly:
|
||||
# For nightly, compare commit hashes
|
||||
update_available = UpdateRoutes._compare_nightly_versions(git_info, remote_version)
|
||||
else:
|
||||
# For stable, compare semantic versions
|
||||
update_available = UpdateRoutes._compare_versions(
|
||||
local_version.replace('v', ''),
|
||||
remote_version.replace('v', '')
|
||||
)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'current_version': local_version,
|
||||
'latest_version': remote_version,
|
||||
'update_available': update_available,
|
||||
'changelog': changelog
|
||||
'changelog': changelog,
|
||||
'git_info': git_info,
|
||||
'nightly': nightly
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
@@ -52,6 +69,301 @@ class UpdateRoutes:
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def get_version_info(request):
|
||||
"""
|
||||
Returns the current version in the format 'version-short_hash'
|
||||
"""
|
||||
try:
|
||||
# Read local version from pyproject.toml
|
||||
local_version = UpdateRoutes._get_local_version().replace('v', '')
|
||||
|
||||
# Get git info (commit hash, branch)
|
||||
git_info = UpdateRoutes._get_git_info()
|
||||
short_hash = git_info['short_hash']
|
||||
|
||||
# Format: version-short_hash
|
||||
version_string = f"{local_version}-{short_hash}"
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'version': version_string
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get version info: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def perform_update(request):
|
||||
"""
|
||||
Perform Git-based update to latest release tag or main branch.
|
||||
If .git is missing, fallback to ZIP download.
|
||||
"""
|
||||
try:
|
||||
body = await request.json() if request.has_body else {}
|
||||
nightly = body.get('nightly', False)
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
plugin_root = os.path.dirname(os.path.dirname(current_dir))
|
||||
|
||||
settings_path = os.path.join(plugin_root, 'settings.json')
|
||||
settings_backup = None
|
||||
if os.path.exists(settings_path):
|
||||
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||
settings_backup = f.read()
|
||||
logger.info("Backed up settings.json")
|
||||
|
||||
git_folder = os.path.join(plugin_root, '.git')
|
||||
if os.path.exists(git_folder):
|
||||
# Git update
|
||||
success, new_version = await UpdateRoutes._perform_git_update(plugin_root, nightly)
|
||||
else:
|
||||
# Fallback: Download ZIP and replace files
|
||||
success, new_version = await UpdateRoutes._download_and_replace_zip(plugin_root)
|
||||
|
||||
if settings_backup and success:
|
||||
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||
f.write(settings_backup)
|
||||
logger.info("Restored settings.json")
|
||||
|
||||
if success:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f'Successfully updated to {new_version}',
|
||||
'new_version': new_version
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Failed to complete update'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to perform update: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def _download_and_replace_zip(plugin_root: str) -> tuple[bool, str]:
|
||||
"""
|
||||
Download latest release ZIP from GitHub and replace plugin files.
|
||||
Skips settings.json and civitai folder. Writes extracted file list to .tracking.
|
||||
"""
|
||||
repo_owner = "willmiao"
|
||||
repo_name = "ComfyUI-Lora-Manager"
|
||||
github_api = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
|
||||
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
|
||||
# Get release info
|
||||
success, data = await downloader.make_request(
|
||||
'GET',
|
||||
github_api,
|
||||
use_auth=False
|
||||
)
|
||||
if not success:
|
||||
logger.error(f"Failed to fetch release info: {data}")
|
||||
return False, ""
|
||||
|
||||
zip_url = data.get("zipball_url")
|
||||
version = data.get("tag_name", "unknown")
|
||||
|
||||
# Download ZIP to temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
|
||||
tmp_zip_path = tmp_zip.name
|
||||
|
||||
success, result = await downloader.download_file(
|
||||
url=zip_url,
|
||||
save_path=tmp_zip_path,
|
||||
use_auth=False,
|
||||
allow_resume=False
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error(f"Failed to download ZIP: {result}")
|
||||
return False, ""
|
||||
|
||||
zip_path = tmp_zip_path
|
||||
|
||||
# Skip both settings.json and civitai folder
|
||||
UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json', 'civitai'])
|
||||
|
||||
# Extract ZIP to temp dir
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(tmp_dir)
|
||||
# Find extracted folder (GitHub ZIP contains a root folder)
|
||||
extracted_root = next(os.scandir(tmp_dir)).path
|
||||
|
||||
# Copy files, skipping settings.json and civitai folder
|
||||
for item in os.listdir(extracted_root):
|
||||
if item == 'settings.json' or item == 'civitai':
|
||||
continue
|
||||
src = os.path.join(extracted_root, item)
|
||||
dst = os.path.join(plugin_root, item)
|
||||
if os.path.isdir(src):
|
||||
if os.path.exists(dst):
|
||||
shutil.rmtree(dst)
|
||||
shutil.copytree(src, dst, ignore=shutil.ignore_patterns('settings.json', 'civitai'))
|
||||
else:
|
||||
shutil.copy2(src, dst)
|
||||
|
||||
# Write .tracking file: list all files under extracted_root, relative to extracted_root
|
||||
# for ComfyUI Manager to work properly
|
||||
tracking_info_file = os.path.join(plugin_root, '.tracking')
|
||||
tracking_files = []
|
||||
for root, dirs, files in os.walk(extracted_root):
|
||||
# Skip civitai folder and its contents
|
||||
rel_root = os.path.relpath(root, extracted_root)
|
||||
if rel_root == 'civitai' or rel_root.startswith('civitai' + os.sep):
|
||||
continue
|
||||
for file in files:
|
||||
rel_path = os.path.relpath(os.path.join(root, file), extracted_root)
|
||||
# Skip settings.json and any file under civitai
|
||||
if rel_path == 'settings.json' or rel_path.startswith('civitai' + os.sep):
|
||||
continue
|
||||
tracking_files.append(rel_path.replace("\\", "/"))
|
||||
with open(tracking_info_file, "w", encoding='utf-8') as file:
|
||||
file.write('\n'.join(tracking_files))
|
||||
|
||||
os.remove(zip_path)
|
||||
logger.info(f"Updated plugin via ZIP to {version}")
|
||||
return True, version
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ZIP update failed: {e}", exc_info=True)
|
||||
return False, ""
|
||||
|
||||
def _clean_plugin_folder(plugin_root, skip_files=None):
|
||||
skip_files = skip_files or []
|
||||
for item in os.listdir(plugin_root):
|
||||
if item in skip_files:
|
||||
continue
|
||||
path = os.path.join(plugin_root, item)
|
||||
if os.path.isdir(path):
|
||||
shutil.rmtree(path)
|
||||
else:
|
||||
os.remove(path)
|
||||
|
||||
@staticmethod
|
||||
async def _get_nightly_version() -> tuple[str, List[str]]:
|
||||
"""
|
||||
Fetch latest commit from main branch
|
||||
"""
|
||||
repo_owner = "willmiao"
|
||||
repo_name = "ComfyUI-Lora-Manager"
|
||||
|
||||
# Use GitHub API to fetch the latest commit from main branch
|
||||
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
|
||||
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
|
||||
|
||||
if not success:
|
||||
logger.warning(f"Failed to fetch GitHub commit: {data}")
|
||||
return "main", []
|
||||
|
||||
commit_sha = data.get('sha', '')[:7] # Short hash
|
||||
commit_message = data.get('commit', {}).get('message', '')
|
||||
|
||||
# Format as "main-{short_hash}"
|
||||
version = f"main-{commit_sha}"
|
||||
|
||||
# Use commit message as changelog
|
||||
changelog = [commit_message] if commit_message else []
|
||||
|
||||
return version, changelog
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching nightly version: {e}", exc_info=True)
|
||||
return "main", []
|
||||
|
||||
@staticmethod
|
||||
def _compare_nightly_versions(local_git_info: Dict[str, str], remote_version: str) -> bool:
|
||||
"""
|
||||
Compare local commit hash with remote main branch
|
||||
"""
|
||||
try:
|
||||
local_hash = local_git_info.get('short_hash', 'unknown')
|
||||
if local_hash == 'unknown':
|
||||
return True # Assume update available if we can't get local hash
|
||||
|
||||
# Extract remote hash from version string (format: "main-{hash}")
|
||||
if '-' in remote_version:
|
||||
remote_hash = remote_version.split('-')[-1]
|
||||
return local_hash != remote_hash
|
||||
|
||||
return True # Default to update available
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error comparing nightly versions: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _perform_git_update(plugin_root: str, nightly: bool = False) -> tuple[bool, str]:
|
||||
"""
|
||||
Perform Git-based update using GitPython
|
||||
|
||||
Args:
|
||||
plugin_root: Path to the plugin root directory
|
||||
nightly: Whether to update to main branch or latest release
|
||||
|
||||
Returns:
|
||||
tuple: (success, new_version)
|
||||
"""
|
||||
try:
|
||||
# Open the Git repository
|
||||
repo = git.Repo(plugin_root)
|
||||
|
||||
# Fetch latest changes
|
||||
origin = repo.remotes.origin
|
||||
origin.fetch()
|
||||
|
||||
if nightly:
|
||||
# Switch to main branch and pull latest
|
||||
main_branch = 'main'
|
||||
if main_branch not in [branch.name for branch in repo.branches]:
|
||||
# Create local main branch if it doesn't exist
|
||||
repo.create_head(main_branch, origin.refs.main)
|
||||
|
||||
repo.heads[main_branch].checkout()
|
||||
origin.pull(main_branch)
|
||||
|
||||
# Get new commit hash
|
||||
new_version = f"main-{repo.head.commit.hexsha[:7]}"
|
||||
|
||||
else:
|
||||
# Get latest release tag
|
||||
tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime, reverse=True)
|
||||
if not tags:
|
||||
logger.error("No tags found in repository")
|
||||
return False, ""
|
||||
|
||||
latest_tag = tags[0]
|
||||
|
||||
# Checkout to latest tag
|
||||
repo.git.checkout(latest_tag.name)
|
||||
|
||||
new_version = latest_tag.name
|
||||
|
||||
logger.info(f"Successfully updated to {new_version}")
|
||||
return True, new_version
|
||||
|
||||
except git.exc.GitError as e:
|
||||
logger.error(f"Git error during update: {e}")
|
||||
return False, ""
|
||||
except Exception as e:
|
||||
logger.error(f"Error during Git update: {e}")
|
||||
return False, ""
|
||||
|
||||
@staticmethod
|
||||
def _get_local_version() -> str:
|
||||
@@ -76,6 +388,35 @@ class UpdateRoutes:
|
||||
logger.error(f"Failed to get local version: {e}", exc_info=True)
|
||||
return "v0.0.0"
|
||||
|
||||
@staticmethod
|
||||
def _get_git_info() -> Dict[str, str]:
|
||||
"""Get Git repository information"""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
plugin_root = os.path.dirname(os.path.dirname(current_dir))
|
||||
|
||||
git_info = {
|
||||
'commit_hash': 'unknown',
|
||||
'short_hash': 'stable',
|
||||
'branch': 'unknown',
|
||||
'commit_date': 'unknown'
|
||||
}
|
||||
|
||||
try:
|
||||
# Check if we're in a git repository
|
||||
if not os.path.exists(os.path.join(plugin_root, '.git')):
|
||||
return git_info
|
||||
|
||||
repo = git.Repo(plugin_root)
|
||||
commit = repo.head.commit
|
||||
git_info['commit_hash'] = commit.hexsha
|
||||
git_info['short_hash'] = commit.hexsha[:7]
|
||||
git_info['branch'] = repo.active_branch.name if not repo.head.is_detached else 'detached'
|
||||
git_info['commit_date'] = commit.committed_datetime.strftime('%Y-%m-%d')
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting git info: {e}")
|
||||
|
||||
return git_info
|
||||
|
||||
@staticmethod
|
||||
async def _get_remote_version() -> tuple[str, List[str]]:
|
||||
"""
|
||||
@@ -90,22 +431,22 @@ class UpdateRoutes:
|
||||
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response:
|
||||
if response.status != 200:
|
||||
logger.warning(f"Failed to fetch GitHub release: {response.status}")
|
||||
return "v0.0.0", []
|
||||
|
||||
data = await response.json()
|
||||
version = data.get('tag_name', '')
|
||||
if not version.startswith('v'):
|
||||
version = f"v{version}"
|
||||
|
||||
# Extract changelog from release notes
|
||||
body = data.get('body', '')
|
||||
changelog = UpdateRoutes._parse_changelog(body)
|
||||
|
||||
return version, changelog
|
||||
downloader = await get_downloader()
|
||||
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
|
||||
|
||||
if not success:
|
||||
logger.warning(f"Failed to fetch GitHub release: {data}")
|
||||
return "v0.0.0", []
|
||||
|
||||
version = data.get('tag_name', '')
|
||||
if not version.startswith('v'):
|
||||
version = f"v{version}"
|
||||
|
||||
# Extract changelog from release notes
|
||||
body = data.get('body', '')
|
||||
changelog = UpdateRoutes._parse_changelog(body)
|
||||
|
||||
return version, changelog
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching remote version: {e}", exc_info=True)
|
||||
@@ -154,11 +495,16 @@ class UpdateRoutes:
|
||||
"""
|
||||
Compare two semantic version strings
|
||||
Returns True if version2 is newer than version1
|
||||
Ignores any suffixes after '-' (e.g., -bugfix, -alpha)
|
||||
"""
|
||||
try:
|
||||
# Clean version strings - remove any suffix after '-'
|
||||
v1_clean = version1.split('-')[0]
|
||||
v2_clean = version2.split('-')[0]
|
||||
|
||||
# Split versions into components
|
||||
v1_parts = [int(x) for x in version1.split('.')]
|
||||
v2_parts = [int(x) for x in version2.split('.')]
|
||||
v1_parts = [int(x) for x in v1_clean.split('.')]
|
||||
v2_parts = [int(x) for x in v2_clean.split('.')]
|
||||
|
||||
# Ensure both have 3 components (major.minor.patch)
|
||||
while len(v1_parts) < 3:
|
||||
|
||||
375
py/services/base_model_service.py
Normal file
375
py/services/base_model_service.py
Normal file
@@ -0,0 +1,375 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Type
|
||||
import logging
|
||||
import os
|
||||
|
||||
from ..utils.models import BaseModelMetadata
|
||||
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
|
||||
from .settings_manager import settings as default_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseModelService(ABC):
|
||||
"""Base service class for all model types"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type: str,
|
||||
scanner,
|
||||
metadata_class: Type[BaseModelMetadata],
|
||||
*,
|
||||
cache_repository: Optional[ModelCacheRepository] = None,
|
||||
filter_set: Optional[ModelFilterSet] = None,
|
||||
search_strategy: Optional[SearchStrategy] = None,
|
||||
settings_provider: Optional[SettingsProvider] = None,
|
||||
):
|
||||
"""Initialize the service.
|
||||
|
||||
Args:
|
||||
model_type: Type of model (lora, checkpoint, etc.).
|
||||
scanner: Model scanner instance.
|
||||
metadata_class: Metadata class for this model type.
|
||||
cache_repository: Custom repository for cache access (primarily for tests).
|
||||
filter_set: Filter component controlling folder/tag/favorites logic.
|
||||
search_strategy: Search component for fuzzy/text matching.
|
||||
settings_provider: Settings object; defaults to the global settings manager.
|
||||
"""
|
||||
self.model_type = model_type
|
||||
self.scanner = scanner
|
||||
self.metadata_class = metadata_class
|
||||
self.settings = settings_provider or default_settings
|
||||
self.cache_repository = cache_repository or ModelCacheRepository(scanner)
|
||||
self.filter_set = filter_set or ModelFilterSet(self.settings)
|
||||
self.search_strategy = search_strategy or SearchStrategy()
|
||||
|
||||
async def get_paginated_data(
|
||||
self,
|
||||
page: int,
|
||||
page_size: int,
|
||||
sort_by: str = 'name',
|
||||
folder: str = None,
|
||||
search: str = None,
|
||||
fuzzy_search: bool = False,
|
||||
base_models: list = None,
|
||||
tags: list = None,
|
||||
search_options: dict = None,
|
||||
hash_filters: dict = None,
|
||||
favorites_only: bool = False,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
"""Get paginated and filtered model data"""
|
||||
sort_params = self.cache_repository.parse_sort(sort_by)
|
||||
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
|
||||
|
||||
if hash_filters:
|
||||
filtered_data = await self._apply_hash_filters(sorted_data, hash_filters)
|
||||
return self._paginate(filtered_data, page, page_size)
|
||||
|
||||
filtered_data = await self._apply_common_filters(
|
||||
sorted_data,
|
||||
folder=folder,
|
||||
base_models=base_models,
|
||||
tags=tags,
|
||||
favorites_only=favorites_only,
|
||||
search_options=search_options,
|
||||
)
|
||||
|
||||
if search:
|
||||
filtered_data = await self._apply_search_filters(
|
||||
filtered_data,
|
||||
search,
|
||||
fuzzy_search,
|
||||
search_options,
|
||||
)
|
||||
|
||||
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
||||
|
||||
return self._paginate(filtered_data, page, page_size)
|
||||
|
||||
|
||||
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
|
||||
"""Apply hash-based filtering"""
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower()
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('sha256', '').lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes)
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
return data
|
||||
|
||||
async def _apply_common_filters(
|
||||
self,
|
||||
data: List[Dict],
|
||||
folder: str = None,
|
||||
base_models: list = None,
|
||||
tags: list = None,
|
||||
favorites_only: bool = False,
|
||||
search_options: dict = None,
|
||||
) -> List[Dict]:
|
||||
"""Apply common filters that work across all model types"""
|
||||
normalized_options = self.search_strategy.normalize_options(search_options)
|
||||
criteria = FilterCriteria(
|
||||
folder=folder,
|
||||
base_models=base_models,
|
||||
tags=tags,
|
||||
favorites_only=favorites_only,
|
||||
search_options=normalized_options,
|
||||
)
|
||||
return self.filter_set.apply(data, criteria)
|
||||
|
||||
async def _apply_search_filters(
|
||||
self,
|
||||
data: List[Dict],
|
||||
search: str,
|
||||
fuzzy_search: bool,
|
||||
search_options: dict,
|
||||
) -> List[Dict]:
|
||||
"""Apply search filtering"""
|
||||
normalized_options = self.search_strategy.normalize_options(search_options)
|
||||
return self.search_strategy.apply(data, search, normalized_options, fuzzy_search)
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
||||
return data
|
||||
|
||||
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
|
||||
"""Apply pagination to filtered data"""
|
||||
total_items = len(data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
return {
|
||||
'items': data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def format_response(self, model_data: Dict) -> Dict:
|
||||
"""Format model data for API response - must be implemented by subclasses"""
|
||||
pass
|
||||
|
||||
# Common service methods that delegate to scanner
|
||||
async def get_top_tags(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get top tags sorted by frequency"""
|
||||
return await self.scanner.get_top_tags(limit)
|
||||
|
||||
async def get_base_models(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get base models sorted by frequency"""
|
||||
return await self.scanner.get_base_models(limit)
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
"""Check if a model with given hash exists"""
|
||||
return self.scanner.has_hash(sha256)
|
||||
|
||||
def get_path_by_hash(self, sha256: str) -> Optional[str]:
|
||||
"""Get file path for a model by its hash"""
|
||||
return self.scanner.get_path_by_hash(sha256)
|
||||
|
||||
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a model by its file path"""
|
||||
return self.scanner.get_hash_by_path(file_path)
|
||||
|
||||
async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False):
|
||||
"""Trigger model scanning"""
|
||||
return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache)
|
||||
|
||||
async def get_model_info_by_name(self, name: str):
|
||||
"""Get model information by name"""
|
||||
return await self.scanner.get_model_info_by_name(name)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get model root directories"""
|
||||
return self.scanner.get_model_roots()
|
||||
|
||||
def filter_civitai_data(self, data: Dict, minimal: bool = False) -> Dict:
|
||||
"""Filter relevant fields from CivitAI data"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
fields = ["id", "modelId", "name", "trainedWords"] if minimal else [
|
||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||
"publishedAt", "trainedWords", "baseModel", "description",
|
||||
"model", "images", "customImages", "creator"
|
||||
]
|
||||
return {k: data[k] for k in fields if k in data}
|
||||
|
||||
async def get_folder_tree(self, model_root: str) -> Dict:
|
||||
"""Get hierarchical folder tree for a specific model root"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Build tree structure from folders
|
||||
tree = {}
|
||||
|
||||
for folder in cache.folders:
|
||||
# Check if this folder belongs to the specified model root
|
||||
folder_belongs_to_root = False
|
||||
for root in self.scanner.get_model_roots():
|
||||
if root == model_root:
|
||||
folder_belongs_to_root = True
|
||||
break
|
||||
|
||||
if not folder_belongs_to_root:
|
||||
continue
|
||||
|
||||
# Split folder path into components
|
||||
parts = folder.split('/') if folder else []
|
||||
current_level = tree
|
||||
|
||||
for part in parts:
|
||||
if part not in current_level:
|
||||
current_level[part] = {}
|
||||
current_level = current_level[part]
|
||||
|
||||
return tree
|
||||
|
||||
async def get_unified_folder_tree(self) -> Dict:
|
||||
"""Get unified folder tree across all model roots"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Build unified tree structure by analyzing all relative paths
|
||||
unified_tree = {}
|
||||
|
||||
# Get all model roots for path normalization
|
||||
model_roots = self.scanner.get_model_roots()
|
||||
|
||||
for folder in cache.folders:
|
||||
if not folder: # Skip empty folders
|
||||
continue
|
||||
|
||||
# Find which root this folder belongs to by checking the actual file paths
|
||||
# This is a simplified approach - we'll use the folder as-is since it should already be relative
|
||||
relative_path = folder
|
||||
|
||||
# Split folder path into components
|
||||
parts = relative_path.split('/')
|
||||
current_level = unified_tree
|
||||
|
||||
for part in parts:
|
||||
if part not in current_level:
|
||||
current_level[part] = {}
|
||||
current_level = current_level[part]
|
||||
|
||||
return unified_tree
|
||||
|
||||
async def get_model_notes(self, model_name: str) -> Optional[str]:
|
||||
"""Get notes for a specific model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
return model.get('notes', '')
|
||||
|
||||
return None
|
||||
|
||||
async def get_model_preview_url(self, model_name: str) -> Optional[str]:
|
||||
"""Get the static preview URL for a model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
preview_url = model.get('preview_url')
|
||||
if preview_url:
|
||||
from ..config import config
|
||||
return config.get_preview_static_url(preview_url)
|
||||
|
||||
return '/loras_static/images/no-preview.png'
|
||||
|
||||
async def get_model_civitai_url(self, model_name: str) -> Dict[str, Optional[str]]:
|
||||
"""Get the Civitai URL for a model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
civitai_data = model.get('civitai', {})
|
||||
model_id = civitai_data.get('modelId')
|
||||
version_id = civitai_data.get('id')
|
||||
|
||||
if model_id:
|
||||
civitai_url = f"https://civitai.com/models/{model_id}"
|
||||
if version_id:
|
||||
civitai_url += f"?modelVersionId={version_id}"
|
||||
|
||||
return {
|
||||
'civitai_url': civitai_url,
|
||||
'model_id': str(model_id),
|
||||
'version_id': str(version_id) if version_id else None
|
||||
}
|
||||
|
||||
return {'civitai_url': None, 'model_id': None, 'version_id': None}
|
||||
|
||||
async def get_model_metadata(self, file_path: str) -> Optional[Dict]:
|
||||
"""Get filtered CivitAI metadata for a model by file path"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model.get('file_path') == file_path:
|
||||
return self.filter_civitai_data(model.get("civitai", {}))
|
||||
|
||||
return None
|
||||
|
||||
async def get_model_description(self, file_path: str) -> Optional[str]:
|
||||
"""Get model description by file path"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model.get('file_path') == file_path:
|
||||
return model.get('modelDescription', '')
|
||||
|
||||
return None
|
||||
|
||||
async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]:
|
||||
"""Search model relative file paths for autocomplete functionality"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
matching_paths = []
|
||||
search_lower = search_term.lower()
|
||||
|
||||
# Get model roots for path calculation
|
||||
model_roots = self.scanner.get_model_roots()
|
||||
|
||||
for model in cache.raw_data:
|
||||
file_path = model.get('file_path', '')
|
||||
if not file_path:
|
||||
continue
|
||||
|
||||
# Calculate relative path from model root
|
||||
relative_path = None
|
||||
for root in model_roots:
|
||||
# Normalize paths for comparison
|
||||
normalized_root = os.path.normpath(root)
|
||||
normalized_file = os.path.normpath(file_path)
|
||||
|
||||
if normalized_file.startswith(normalized_root):
|
||||
# Remove root and leading separator to get relative path
|
||||
relative_path = normalized_file[len(normalized_root):].lstrip(os.sep)
|
||||
break
|
||||
|
||||
if relative_path and search_lower in relative_path.lower():
|
||||
matching_paths.append(relative_path)
|
||||
|
||||
if len(matching_paths) >= limit * 2: # Get more for better sorting
|
||||
break
|
||||
|
||||
# Sort by relevance (exact matches first, then by length)
|
||||
matching_paths.sort(key=lambda x: (
|
||||
not x.lower().startswith(search_lower), # Exact prefix matches first
|
||||
len(x), # Then by length (shorter first)
|
||||
x.lower() # Then alphabetically
|
||||
))
|
||||
|
||||
return matching_paths[:limit]
|
||||
34
py/services/checkpoint_scanner.py
Normal file
34
py/services/checkpoint_scanner.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointScanner(ModelScanner):
|
||||
"""Service for scanning and managing checkpoint files"""
|
||||
|
||||
def __init__(self):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'}
|
||||
super().__init__(
|
||||
model_type="checkpoint",
|
||||
model_class=CheckpointMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
|
||||
def adjust_metadata(self, metadata, file_path, root_path):
|
||||
if hasattr(metadata, "model_type"):
|
||||
if root_path in config.checkpoints_roots:
|
||||
metadata.model_type = "checkpoint"
|
||||
elif root_path in config.unet_roots:
|
||||
metadata.model_type = "diffusion_model"
|
||||
return metadata
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get checkpoint root directories"""
|
||||
return config.base_models_roots
|
||||
49
py/services/checkpoint_service.py
Normal file
49
py/services/checkpoint_service.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointService(BaseModelService):
|
||||
"""Checkpoint-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize Checkpoint service
|
||||
|
||||
Args:
|
||||
scanner: Checkpoint scanner instance
|
||||
"""
|
||||
super().__init__("checkpoint", scanner, CheckpointMetadata)
|
||||
|
||||
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
||||
"""Format Checkpoint data for API response"""
|
||||
return {
|
||||
"model_name": checkpoint_data["model_name"],
|
||||
"file_name": checkpoint_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
|
||||
"base_model": checkpoint_data.get("base_model", ""),
|
||||
"folder": checkpoint_data["folder"],
|
||||
"sha256": checkpoint_data.get("sha256", ""),
|
||||
"file_path": checkpoint_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": checkpoint_data.get("size", 0),
|
||||
"modified": checkpoint_data.get("modified", ""),
|
||||
"tags": checkpoint_data.get("tags", []),
|
||||
"from_civitai": checkpoint_data.get("from_civitai", True),
|
||||
"notes": checkpoint_data.get("notes", ""),
|
||||
"model_type": checkpoint_data.get("model_type", "checkpoint"),
|
||||
"favorite": checkpoint_data.get("favorite", False),
|
||||
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find Checkpoints with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find Checkpoints with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
@@ -1,66 +1,40 @@
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import os
|
||||
import json
|
||||
import copy
|
||||
import logging
|
||||
from email.parser import Parser
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Tuple, List
|
||||
from urllib.parse import unquote
|
||||
from ..utils.models import LoraMetadata
|
||||
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
||||
from .downloader import get_downloader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CivitaiClient:
|
||||
def __init__(self):
|
||||
self.base_url = "https://civitai.com/api/v1"
|
||||
self.headers = {
|
||||
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
|
||||
}
|
||||
self._session = None
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
async def session(self) -> aiohttp.ClientSession:
|
||||
"""Lazy initialize the session"""
|
||||
if self._session is None:
|
||||
connector = aiohttp.TCPConnector(ssl=True)
|
||||
trust_env = True # 允许使用系统环境变量中的代理设置
|
||||
self._session = aiohttp.ClientSession(connector=connector, trust_env=trust_env)
|
||||
return self._session
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance of CivitaiClient"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
|
||||
# Register this client as a metadata provider
|
||||
provider_manager = await ModelMetadataProviderManager.get_instance()
|
||||
provider_manager.register_provider('civitai', CivitaiModelMetadataProvider(cls._instance), True)
|
||||
|
||||
return cls._instance
|
||||
|
||||
def _parse_content_disposition(self, header: str) -> str:
|
||||
"""Parse filename from content-disposition header"""
|
||||
if not header:
|
||||
return None
|
||||
def __init__(self):
|
||||
# Check if already initialized for singleton pattern
|
||||
if hasattr(self, '_initialized'):
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
# Handle quoted filenames
|
||||
if 'filename="' in header:
|
||||
start = header.index('filename="') + 10
|
||||
end = header.index('"', start)
|
||||
return unquote(header[start:end])
|
||||
|
||||
# Fallback to original parsing
|
||||
disposition = Parser().parsestr(f'Content-Disposition: {header}')
|
||||
filename = disposition.get_param('filename')
|
||||
if filename:
|
||||
return unquote(filename)
|
||||
return None
|
||||
|
||||
def _get_request_headers(self) -> dict:
|
||||
"""Get request headers with optional API key"""
|
||||
headers = {
|
||||
'User-Agent': 'ComfyUI-LoRA-Manager/1.0',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
from .settings_manager import settings
|
||||
api_key = settings.get('civitai_api_key')
|
||||
if (api_key):
|
||||
headers['Authorization'] = f'Bearer {api_key}'
|
||||
|
||||
return headers
|
||||
|
||||
async def _download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
|
||||
"""Download file with content-disposition support and progress tracking
|
||||
self.base_url = "https://civitai.com/api/v1"
|
||||
|
||||
async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
|
||||
"""Download file with resumable downloads and retry mechanism
|
||||
|
||||
Args:
|
||||
url: Download URL
|
||||
@@ -71,146 +45,302 @@ class CivitaiClient:
|
||||
Returns:
|
||||
Tuple[bool, str]: (success, save_path or error message)
|
||||
"""
|
||||
session = await self.session
|
||||
downloader = await get_downloader()
|
||||
save_path = os.path.join(save_dir, default_filename)
|
||||
|
||||
# Use unified downloader with CivitAI authentication
|
||||
success, result = await downloader.download_file(
|
||||
url=url,
|
||||
save_path=save_path,
|
||||
progress_callback=progress_callback,
|
||||
use_auth=True, # Enable CivitAI authentication
|
||||
allow_resume=True
|
||||
)
|
||||
|
||||
return success, result
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
try:
|
||||
headers = self._get_request_headers()
|
||||
async with session.get(url, headers=headers, allow_redirects=True) as response:
|
||||
if response.status != 200:
|
||||
return False, f"Download failed with status {response.status}"
|
||||
|
||||
# Get filename from content-disposition header
|
||||
content_disposition = response.headers.get('Content-Disposition')
|
||||
filename = self._parse_content_disposition(content_disposition)
|
||||
if not filename:
|
||||
filename = default_filename
|
||||
|
||||
save_path = os.path.join(save_dir, filename)
|
||||
|
||||
# Get total file size for progress calculation
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
current_size = 0
|
||||
|
||||
# Stream download to file with progress updates
|
||||
with open(save_path, 'wb') as f:
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
current_size += len(chunk)
|
||||
if progress_callback and total_size:
|
||||
progress = (current_size / total_size) * 100
|
||||
await progress_callback(progress)
|
||||
|
||||
# Ensure 100% progress is reported
|
||||
if progress_callback:
|
||||
await progress_callback(100)
|
||||
downloader = await get_downloader()
|
||||
success, result = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
# Get model ID from version data
|
||||
model_id = result.get('modelId')
|
||||
if model_id:
|
||||
# Fetch additional model metadata
|
||||
success_model, data = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if success_model:
|
||||
# Enrich version_info with model data
|
||||
result['model']['description'] = data.get("description")
|
||||
result['model']['tags'] = data.get("tags", [])
|
||||
|
||||
return True, save_path
|
||||
# Add creator from model data
|
||||
result['creator'] = data.get("creator")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Download error: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||
try:
|
||||
session = await self.session
|
||||
async with session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
return None
|
||||
return result, None
|
||||
|
||||
# Handle specific error cases
|
||||
if "not found" in str(result):
|
||||
return None, "Model not found"
|
||||
|
||||
# Other error cases
|
||||
logger.error(f"Failed to fetch model info for {model_hash[:10]}: {result}")
|
||||
return None, str(result)
|
||||
except Exception as e:
|
||||
logger.error(f"API Error: {str(e)}")
|
||||
return None
|
||||
return None, str(e)
|
||||
|
||||
async def download_preview_image(self, image_url: str, save_path: str):
|
||||
try:
|
||||
session = await self.session
|
||||
async with session.get(image_url) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
with open(save_path, 'wb') as f:
|
||||
f.write(content)
|
||||
return True
|
||||
return False
|
||||
downloader = await get_downloader()
|
||||
success, content, headers = await downloader.download_to_memory(
|
||||
image_url,
|
||||
use_auth=False # Preview images don't need auth
|
||||
)
|
||||
if success:
|
||||
# Ensure directory exists
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
with open(save_path, 'wb') as f:
|
||||
f.write(content)
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Download Error: {str(e)}")
|
||||
logger.error(f"Download Error: {str(e)}")
|
||||
return False
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> List[Dict]:
|
||||
"""Get all versions of a model with local availability info"""
|
||||
try:
|
||||
session = await self.session # 等待获取 session
|
||||
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
data = await response.json()
|
||||
return data.get('modelVersions', [])
|
||||
downloader = await get_downloader()
|
||||
success, result = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
# Also return model type along with versions
|
||||
return {
|
||||
'modelVersions': result.get('modelVersions', []),
|
||||
'type': result.get('type', ''),
|
||||
'name': result.get('name', '')
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model versions: {e}")
|
||||
return None
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Optional[Dict]:
|
||||
"""Fetch model version metadata from Civitai"""
|
||||
try:
|
||||
session = await self.session
|
||||
url = f"{self.base_url}/model-versions/{version_id}"
|
||||
headers = self._get_request_headers()
|
||||
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model version info: {e}")
|
||||
return None
|
||||
|
||||
async def get_model_metadata(self, model_id: str) -> Optional[Dict]:
|
||||
"""Fetch model metadata (description and tags) from Civitai API
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
"""Get specific model version with additional metadata
|
||||
|
||||
Args:
|
||||
model_id: The Civitai model ID
|
||||
model_id: The Civitai model ID (optional if version_id is provided)
|
||||
version_id: Optional specific version ID to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: A dictionary containing model metadata or None if not found
|
||||
Optional[Dict]: The model version data with additional fields or None if not found
|
||||
"""
|
||||
try:
|
||||
session = await self.session
|
||||
headers = self._get_request_headers()
|
||||
url = f"{self.base_url}/models/{model_id}"
|
||||
downloader = await get_downloader()
|
||||
|
||||
logger.info(f"Fetching model metadata from {url}")
|
||||
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
logger.warning(f"Failed to fetch model metadata: Status {response.status}")
|
||||
# Case 1: Only version_id is provided
|
||||
if model_id is None and version_id is not None:
|
||||
# First get the version info to extract model_id
|
||||
success, version = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/{version_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if not success:
|
||||
return None
|
||||
|
||||
data = await response.json()
|
||||
model_id = version.get('modelId')
|
||||
if not model_id:
|
||||
logger.error(f"No modelId found in version {version_id}")
|
||||
return None
|
||||
|
||||
# Now get the model data for additional metadata
|
||||
success, model_data = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
# Enrich version with model data
|
||||
version['model']['description'] = model_data.get("description")
|
||||
version['model']['tags'] = model_data.get("tags", [])
|
||||
version['creator'] = model_data.get("creator")
|
||||
|
||||
# Extract relevant metadata
|
||||
metadata = {
|
||||
"description": data.get("description", ""),
|
||||
"tags": data.get("tags", [])
|
||||
}
|
||||
|
||||
if metadata["description"] or metadata["tags"]:
|
||||
logger.info(f"Successfully retrieved metadata for model {model_id}")
|
||||
return metadata
|
||||
return version
|
||||
|
||||
# Case 2: model_id is provided (with or without version_id)
|
||||
elif model_id is not None:
|
||||
# Step 1: Get model data to find version_id if not provided and get additional metadata
|
||||
success, data = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if not success:
|
||||
return None
|
||||
|
||||
model_versions = data.get('modelVersions', [])
|
||||
if not model_versions:
|
||||
logger.warning(f"No model versions found for model {model_id}")
|
||||
return None
|
||||
|
||||
# Step 2: Determine the target version entry to use
|
||||
target_version = None
|
||||
if version_id is not None:
|
||||
target_version = next(
|
||||
(item for item in model_versions if item.get('id') == version_id),
|
||||
None
|
||||
)
|
||||
if target_version is None:
|
||||
logger.warning(
|
||||
f"Version {version_id} not found for model {model_id}, defaulting to first version"
|
||||
)
|
||||
if target_version is None:
|
||||
target_version = model_versions[0]
|
||||
|
||||
target_version_id = target_version.get('id')
|
||||
|
||||
# Step 3: Get detailed version info using the SHA256 hash
|
||||
model_hash = None
|
||||
for file_info in target_version.get('files', []):
|
||||
if file_info.get('type') == 'Model' and file_info.get('primary'):
|
||||
model_hash = file_info.get('hashes', {}).get('SHA256')
|
||||
if model_hash:
|
||||
break
|
||||
|
||||
version = None
|
||||
if model_hash:
|
||||
success, version = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||
use_auth=True
|
||||
)
|
||||
if not success:
|
||||
logger.warning(
|
||||
f"Failed to fetch version by hash for model {model_id} version {target_version_id}: {version}"
|
||||
)
|
||||
version = None
|
||||
else:
|
||||
logger.warning(f"No metadata found for model {model_id}")
|
||||
return None
|
||||
logger.warning(
|
||||
f"No primary model hash found for model {model_id} version {target_version_id}"
|
||||
)
|
||||
|
||||
if version is None:
|
||||
version = copy.deepcopy(target_version)
|
||||
version.pop('index', None)
|
||||
version['modelId'] = model_id
|
||||
version['model'] = {
|
||||
'name': data.get('name'),
|
||||
'type': data.get('type'),
|
||||
'nsfw': data.get('nsfw'),
|
||||
'poi': data.get('poi')
|
||||
}
|
||||
|
||||
# Step 4: Enrich version_info with model data
|
||||
# Add description and tags from model data
|
||||
model_info = version.get('model')
|
||||
if not isinstance(model_info, dict):
|
||||
model_info = {}
|
||||
version['model'] = model_info
|
||||
model_info['description'] = data.get("description")
|
||||
model_info['tags'] = data.get("tags", [])
|
||||
|
||||
# Add creator from model data
|
||||
version['creator'] = data.get("creator")
|
||||
|
||||
return version
|
||||
|
||||
# Case 3: Neither model_id nor version_id provided
|
||||
else:
|
||||
logger.error("Either model_id or version_id must be provided")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model metadata: {e}", exc_info=True)
|
||||
logger.error(f"Error fetching model version: {e}")
|
||||
return None
|
||||
|
||||
# Keep old method for backward compatibility, delegating to the new one
|
||||
async def get_model_description(self, model_id: str) -> Optional[str]:
|
||||
"""Fetch the model description from Civitai API (Legacy method)"""
|
||||
metadata = await self.get_model_metadata(model_id)
|
||||
return metadata.get("description") if metadata else None
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Fetch model version metadata from Civitai
|
||||
|
||||
Args:
|
||||
version_id: The Civitai model version ID
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[Dict], Optional[str]]: A tuple containing:
|
||||
- The model version data or None if not found
|
||||
- An error message if there was an error, or None on success
|
||||
"""
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
url = f"{self.base_url}/model-versions/{version_id}"
|
||||
|
||||
logger.debug(f"Resolving DNS for model version info: {url}")
|
||||
success, result = await downloader.make_request(
|
||||
'GET',
|
||||
url,
|
||||
use_auth=True
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.debug(f"Successfully fetched model version info for: {version_id}")
|
||||
return result, None
|
||||
|
||||
# Handle specific error cases
|
||||
if "not found" in str(result):
|
||||
error_msg = f"Model not found"
|
||||
logger.warning(f"Model version not found: {version_id} - {error_msg}")
|
||||
return None, error_msg
|
||||
|
||||
# Other error cases
|
||||
logger.error(f"Failed to fetch model info for {version_id}: {result}")
|
||||
return None, str(result)
|
||||
except Exception as e:
|
||||
error_msg = f"Error fetching model version info: {e}"
|
||||
logger.error(error_msg)
|
||||
return None, error_msg
|
||||
|
||||
async def close(self):
|
||||
"""Close the session if it exists"""
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
||||
"""Fetch image information from Civitai API
|
||||
|
||||
Args:
|
||||
image_id: The Civitai image ID
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: The image data or None if not found
|
||||
"""
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
|
||||
|
||||
logger.debug(f"Fetching image info for ID: {image_id}")
|
||||
success, result = await downloader.make_request(
|
||||
'GET',
|
||||
url,
|
||||
use_auth=True
|
||||
)
|
||||
|
||||
if success:
|
||||
if result and "items" in result and len(result["items"]) > 0:
|
||||
logger.debug(f"Successfully fetched image info for ID: {image_id}")
|
||||
return result["items"][0]
|
||||
logger.warning(f"No image found with ID: {image_id}")
|
||||
return None
|
||||
|
||||
logger.error(f"Failed to fetch image info for ID: {image_id}: {result}")
|
||||
return None
|
||||
except Exception as e:
|
||||
error_msg = f"Error fetching image info: {e}"
|
||||
logger.error(error_msg)
|
||||
return None
|
||||
|
||||
100
py/services/download_coordinator.py
Normal file
100
py/services/download_coordinator.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Service wrapper for coordinating download lifecycle events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DownloadCoordinator:
|
||||
"""Manage download scheduling, cancellation and introspection."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ws_manager,
|
||||
download_manager_factory: Callable[[], Awaitable],
|
||||
) -> None:
|
||||
self._ws_manager = ws_manager
|
||||
self._download_manager_factory = download_manager_factory
|
||||
|
||||
async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Schedule a download using the provided payload."""
|
||||
|
||||
download_manager = await self._download_manager_factory()
|
||||
|
||||
download_id = payload.get("download_id") or self._ws_manager.generate_download_id()
|
||||
payload.setdefault("download_id", download_id)
|
||||
|
||||
async def progress_callback(progress: Any) -> None:
|
||||
await self._ws_manager.broadcast_download_progress(
|
||||
download_id,
|
||||
{
|
||||
"status": "progress",
|
||||
"progress": progress,
|
||||
"download_id": download_id,
|
||||
},
|
||||
)
|
||||
|
||||
model_id = self._parse_optional_int(payload.get("model_id"), "model_id")
|
||||
model_version_id = self._parse_optional_int(
|
||||
payload.get("model_version_id"), "model_version_id"
|
||||
)
|
||||
|
||||
if model_id is None and model_version_id is None:
|
||||
raise ValueError(
|
||||
"Missing required parameter: Please provide either 'model_id' or 'model_version_id'"
|
||||
)
|
||||
|
||||
result = await download_manager.download_from_civitai(
|
||||
model_id=model_id,
|
||||
model_version_id=model_version_id,
|
||||
save_dir=payload.get("model_root"),
|
||||
relative_path=payload.get("relative_path", ""),
|
||||
use_default_paths=payload.get("use_default_paths", False),
|
||||
progress_callback=progress_callback,
|
||||
download_id=download_id,
|
||||
source=payload.get("source"),
|
||||
)
|
||||
|
||||
result["download_id"] = download_id
|
||||
return result
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||
"""Cancel an active download and emit a broadcast event."""
|
||||
|
||||
download_manager = await self._download_manager_factory()
|
||||
result = await download_manager.cancel_download(download_id)
|
||||
|
||||
await self._ws_manager.broadcast_download_progress(
|
||||
download_id,
|
||||
{
|
||||
"status": "cancelled",
|
||||
"progress": 0,
|
||||
"download_id": download_id,
|
||||
"message": "Download cancelled by user",
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def list_active_downloads(self) -> Dict[str, Any]:
|
||||
"""Return the active download map from the underlying manager."""
|
||||
|
||||
download_manager = await self._download_manager_factory()
|
||||
return await download_manager.get_active_downloads()
|
||||
|
||||
def _parse_optional_int(self, value: Any, field: str) -> Optional[int]:
|
||||
"""Parse an optional integer from user input."""
|
||||
|
||||
if value is None or value == "":
|
||||
return None
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Invalid {field}: Must be an integer") from exc
|
||||
|
||||
@@ -1,79 +1,435 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from typing import Optional, Dict
|
||||
from .civitai_client import CivitaiClient
|
||||
from .file_monitor import LoraFileMonitor
|
||||
from ..utils.models import LoraMetadata
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
import uuid
|
||||
from typing import Dict
|
||||
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .service_registry import ServiceRegistry
|
||||
from .settings_manager import settings
|
||||
from .metadata_service import get_default_metadata_provider
|
||||
from .downloader import get_downloader
|
||||
|
||||
# Download to temporary file first
|
||||
import tempfile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DownloadManager:
|
||||
def __init__(self, file_monitor: Optional[LoraFileMonitor] = None):
|
||||
self.civitai_client = CivitaiClient()
|
||||
self.file_monitor = file_monitor
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance of DownloadManager"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
async def download_from_civitai(self, download_url: str, save_dir: str, relative_path: str = '',
|
||||
progress_callback=None) -> Dict:
|
||||
def __init__(self):
|
||||
# Check if already initialized for singleton pattern
|
||||
if hasattr(self, '_initialized'):
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
# Add download management
|
||||
self._active_downloads = OrderedDict() # download_id -> download_info
|
||||
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
||||
self._download_tasks = {} # download_id -> asyncio.Task
|
||||
|
||||
async def _get_lora_scanner(self):
|
||||
"""Get the lora scanner from registry"""
|
||||
return await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
async def _get_checkpoint_scanner(self):
|
||||
"""Get the checkpoint scanner from registry"""
|
||||
return await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
async def download_from_civitai(self, model_id: int = None, model_version_id: int = None,
|
||||
save_dir: str = None, relative_path: str = '',
|
||||
progress_callback=None, use_default_paths: bool = False,
|
||||
download_id: str = None, source: str = None) -> Dict:
|
||||
"""Download model from Civitai with task tracking and concurrency control
|
||||
|
||||
Args:
|
||||
model_id: Civitai model ID (optional if model_version_id is provided)
|
||||
model_version_id: Civitai model version ID (optional if model_id is provided)
|
||||
save_dir: Directory to save the model
|
||||
relative_path: Relative path within save_dir
|
||||
progress_callback: Callback function for progress updates
|
||||
use_default_paths: Flag to use default paths
|
||||
download_id: Unique identifier for this download task
|
||||
source: Optional source parameter to specify metadata provider
|
||||
|
||||
Returns:
|
||||
Dict with download result
|
||||
"""
|
||||
# Validate that at least one identifier is provided
|
||||
if not model_id and not model_version_id:
|
||||
return {'success': False, 'error': 'Either model_id or model_version_id must be provided'}
|
||||
|
||||
# Use provided download_id or generate new one
|
||||
task_id = download_id or str(uuid.uuid4())
|
||||
|
||||
# Register download task in tracking dict
|
||||
self._active_downloads[task_id] = {
|
||||
'model_id': model_id,
|
||||
'model_version_id': model_version_id,
|
||||
'progress': 0,
|
||||
'status': 'queued'
|
||||
}
|
||||
|
||||
# Create tracking task
|
||||
download_task = asyncio.create_task(
|
||||
self._download_with_semaphore(
|
||||
task_id, model_id, model_version_id, save_dir,
|
||||
relative_path, progress_callback, use_default_paths, source
|
||||
)
|
||||
)
|
||||
|
||||
# Store task for tracking and cancellation
|
||||
self._download_tasks[task_id] = download_task
|
||||
|
||||
try:
|
||||
# Wait for download to complete
|
||||
result = await download_task
|
||||
result['download_id'] = task_id # Include download_id in result
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
return {'success': False, 'error': 'Download was cancelled', 'download_id': task_id}
|
||||
finally:
|
||||
# Clean up task reference
|
||||
if task_id in self._download_tasks:
|
||||
del self._download_tasks[task_id]
|
||||
|
||||
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
|
||||
save_dir: str, relative_path: str,
|
||||
progress_callback=None, use_default_paths: bool = False,
|
||||
source: str = None):
|
||||
"""Execute download with semaphore to limit concurrency"""
|
||||
# Update status to waiting
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'waiting'
|
||||
|
||||
# Wrap progress callback to track progress in active_downloads
|
||||
original_callback = progress_callback
|
||||
async def tracking_callback(progress):
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['progress'] = progress
|
||||
if original_callback:
|
||||
await original_callback(progress)
|
||||
|
||||
# Acquire semaphore to limit concurrent downloads
|
||||
try:
|
||||
async with self._download_semaphore:
|
||||
# Update status to downloading
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'downloading'
|
||||
|
||||
# Use original download implementation
|
||||
try:
|
||||
# Check for cancellation before starting
|
||||
if asyncio.current_task().cancelled():
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
result = await self._execute_original_download(
|
||||
model_id, model_version_id, save_dir,
|
||||
relative_path, tracking_callback, use_default_paths,
|
||||
task_id, source
|
||||
)
|
||||
|
||||
# Update status based on result
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed'
|
||||
if not result['success']:
|
||||
self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error')
|
||||
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
# Handle cancellation
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'cancelled'
|
||||
logger.info(f"Download cancelled for task {task_id}")
|
||||
raise
|
||||
except Exception as e:
|
||||
# Handle other errors
|
||||
logger.error(f"Download error for task {task_id}: {str(e)}", exc_info=True)
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'failed'
|
||||
self._active_downloads[task_id]['error'] = str(e)
|
||||
return {'success': False, 'error': str(e)}
|
||||
finally:
|
||||
# Schedule cleanup of download record after delay
|
||||
asyncio.create_task(self._cleanup_download_record(task_id))
|
||||
|
||||
async def _cleanup_download_record(self, task_id: str):
|
||||
"""Keep completed downloads in history for a short time"""
|
||||
await asyncio.sleep(600) # Keep for 10 minutes
|
||||
if task_id in self._active_downloads:
|
||||
del self._active_downloads[task_id]
|
||||
|
||||
async def _execute_original_download(self, model_id, model_version_id, save_dir,
|
||||
relative_path, progress_callback, use_default_paths,
|
||||
download_id=None, source=None):
|
||||
"""Wrapper for original download_from_civitai implementation"""
|
||||
try:
|
||||
# Check if model version already exists in library
|
||||
if model_version_id is not None:
|
||||
# Check both scanners
|
||||
lora_scanner = await self._get_lora_scanner()
|
||||
checkpoint_scanner = await self._get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
# Check lora scanner first
|
||||
if await lora_scanner.check_model_version_exists(model_version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in lora library'}
|
||||
|
||||
# Check checkpoint scanner
|
||||
if await checkpoint_scanner.check_model_version_exists(model_version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
||||
|
||||
# Check embedding scanner
|
||||
if await embedding_scanner.check_model_version_exists(model_version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in embedding library'}
|
||||
|
||||
# Get metadata provider based on source parameter
|
||||
if source == 'civarchive':
|
||||
from .metadata_service import get_metadata_provider
|
||||
metadata_provider = await get_metadata_provider('civarchive')
|
||||
else:
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
# Get version info based on the provided identifier
|
||||
version_info = await metadata_provider.get_model_version(model_id, model_version_id)
|
||||
|
||||
if not version_info:
|
||||
return {'success': False, 'error': 'Failed to fetch model metadata'}
|
||||
|
||||
model_type_from_info = version_info.get('model', {}).get('type', '').lower()
|
||||
if model_type_from_info == 'checkpoint':
|
||||
model_type = 'checkpoint'
|
||||
elif model_type_from_info in VALID_LORA_TYPES:
|
||||
model_type = 'lora'
|
||||
elif model_type_from_info == 'textualinversion':
|
||||
model_type = 'embedding'
|
||||
else:
|
||||
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
|
||||
|
||||
# Case 2: model_version_id was None, check after getting version_info
|
||||
if model_version_id is None:
|
||||
version_id = version_info.get('id')
|
||||
|
||||
if model_type == 'lora':
|
||||
# Check lora scanner
|
||||
lora_scanner = await self._get_lora_scanner()
|
||||
if await lora_scanner.check_model_version_exists(version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in lora library'}
|
||||
elif model_type == 'checkpoint':
|
||||
# Check checkpoint scanner
|
||||
checkpoint_scanner = await self._get_checkpoint_scanner()
|
||||
if await checkpoint_scanner.check_model_version_exists(version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
||||
elif model_type == 'embedding':
|
||||
# Embeddings are not checked in scanners, but we can still check if it exists
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
if await embedding_scanner.check_model_version_exists(version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in embedding library'}
|
||||
|
||||
# Handle use_default_paths
|
||||
if use_default_paths:
|
||||
# Set save_dir based on model type
|
||||
if model_type == 'checkpoint':
|
||||
default_path = settings.get('default_checkpoint_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
|
||||
save_dir = default_path
|
||||
elif model_type == 'lora':
|
||||
default_path = settings.get('default_lora_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default lora root path not set in settings'}
|
||||
save_dir = default_path
|
||||
elif model_type == 'embedding':
|
||||
default_path = settings.get('default_embedding_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default embedding root path not set in settings'}
|
||||
save_dir = default_path
|
||||
|
||||
# Calculate relative path using template
|
||||
relative_path = self._calculate_relative_path(version_info, model_type)
|
||||
|
||||
# Update save directory with relative path if provided
|
||||
if relative_path:
|
||||
save_dir = os.path.join(save_dir, relative_path)
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Get version info
|
||||
version_id = download_url.split('/')[-1]
|
||||
version_info = await self.civitai_client.get_model_version_info(version_id)
|
||||
if not version_info:
|
||||
return {'success': False, 'error': 'Failed to fetch model metadata'}
|
||||
# Check if this is an early access model
|
||||
if version_info.get('earlyAccessEndsAt'):
|
||||
early_access_date = version_info.get('earlyAccessEndsAt', '')
|
||||
# Convert to a readable date if possible
|
||||
try:
|
||||
from datetime import datetime
|
||||
date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00'))
|
||||
formatted_date = date_obj.strftime('%Y-%m-%d')
|
||||
early_access_msg = f"This model requires payment (until {formatted_date}). "
|
||||
except:
|
||||
early_access_msg = "This model requires payment. "
|
||||
|
||||
early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai."
|
||||
logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}")
|
||||
|
||||
# We'll still try to download, but log a warning and prepare for potential failure
|
||||
if progress_callback:
|
||||
await progress_callback(1) # Show minimal progress to indicate we're trying
|
||||
|
||||
# Report initial progress
|
||||
if progress_callback:
|
||||
await progress_callback(0)
|
||||
|
||||
# 2. 获取文件信息
|
||||
# 2. Get file information
|
||||
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
|
||||
if not file_info:
|
||||
return {'success': False, 'error': 'No primary file found in metadata'}
|
||||
if not file_info.get('downloadUrl'):
|
||||
return {'success': False, 'error': 'No download URL found for primary file'}
|
||||
|
||||
# 3. 准备下载
|
||||
# 3. Prepare download
|
||||
file_name = file_info['name']
|
||||
save_path = os.path.join(save_dir, file_name)
|
||||
file_size = file_info.get('sizeKB', 0) * 1024
|
||||
|
||||
# 4. 通知文件监控系统
|
||||
self.file_monitor.handler.add_ignore_path(
|
||||
save_path.replace(os.sep, '/'),
|
||||
file_size
|
||||
)
|
||||
|
||||
# 5. 准备元数据
|
||||
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
# 5. Prepare metadata based on model type
|
||||
if model_type == "checkpoint":
|
||||
metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
logger.info(f"Creating CheckpointMetadata for {file_name}")
|
||||
elif model_type == "lora":
|
||||
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
logger.info(f"Creating LoraMetadata for {file_name}")
|
||||
elif model_type == "embedding":
|
||||
metadata = EmbeddingMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
logger.info(f"Creating EmbeddingMetadata for {file_name}")
|
||||
|
||||
# 6. 开始下载流程
|
||||
# 6. Start download process
|
||||
result = await self._execute_download(
|
||||
download_url=download_url,
|
||||
download_url=file_info.get('downloadUrl', ''),
|
||||
save_dir=save_dir,
|
||||
metadata=metadata,
|
||||
version_info=version_info,
|
||||
relative_path=relative_path,
|
||||
progress_callback=progress_callback
|
||||
progress_callback=progress_callback,
|
||||
model_type=model_type,
|
||||
download_id=download_id
|
||||
)
|
||||
|
||||
# If early_access_msg exists and download failed, replace error message
|
||||
if 'early_access_msg' in locals() and not result.get('success', False):
|
||||
result['error'] = early_access_msg
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in download_from_civitai: {e}", exc_info=True)
|
||||
# Check if this might be an early access error
|
||||
error_str = str(e).lower()
|
||||
if "403" in error_str or "401" in error_str or "unauthorized" in error_str or "early access" in error_str:
|
||||
return {'success': False, 'error': f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai."}
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
def _calculate_relative_path(self, version_info: Dict, model_type: str = 'lora') -> str:
|
||||
"""Calculate relative path using template from settings
|
||||
|
||||
Args:
|
||||
version_info: Version info from Civitai API
|
||||
model_type: Type of model ('lora', 'checkpoint', 'embedding')
|
||||
|
||||
Returns:
|
||||
Relative path string
|
||||
"""
|
||||
# Get path template from settings for specific model type
|
||||
path_template = settings.get_download_path_template(model_type)
|
||||
|
||||
# If template is empty, return empty path (flat structure)
|
||||
if not path_template:
|
||||
return ''
|
||||
|
||||
# Get base model name
|
||||
base_model = version_info.get('baseModel', '')
|
||||
|
||||
# Get author from creator data
|
||||
creator_info = version_info.get('creator')
|
||||
if creator_info and isinstance(creator_info, dict):
|
||||
author = creator_info.get('username') or 'Anonymous'
|
||||
else:
|
||||
author = 'Anonymous'
|
||||
|
||||
# Apply mapping if available
|
||||
base_model_mappings = settings.get('base_model_path_mappings', {})
|
||||
mapped_base_model = base_model_mappings.get(base_model, base_model)
|
||||
|
||||
# Get model tags
|
||||
model_tags = version_info.get('model', {}).get('tags', [])
|
||||
|
||||
# Find the first Civitai model tag that exists in model_tags
|
||||
first_tag = ''
|
||||
for civitai_tag in CIVITAI_MODEL_TAGS:
|
||||
if civitai_tag in model_tags:
|
||||
first_tag = civitai_tag
|
||||
break
|
||||
|
||||
# If no Civitai model tag found, fallback to first tag
|
||||
if not first_tag and model_tags:
|
||||
first_tag = model_tags[0]
|
||||
|
||||
# Format the template with available data
|
||||
formatted_path = path_template
|
||||
formatted_path = formatted_path.replace('{base_model}', mapped_base_model)
|
||||
formatted_path = formatted_path.replace('{first_tag}', first_tag)
|
||||
formatted_path = formatted_path.replace('{author}', author)
|
||||
|
||||
return formatted_path
|
||||
|
||||
async def _execute_download(self, download_url: str, save_dir: str,
|
||||
metadata: LoraMetadata, version_info: Dict,
|
||||
relative_path: str, progress_callback=None) -> Dict:
|
||||
metadata, version_info: Dict,
|
||||
relative_path: str, progress_callback=None,
|
||||
model_type: str = "lora", download_id: str = None) -> Dict:
|
||||
"""Execute the actual download process including preview images and model files"""
|
||||
try:
|
||||
save_path = metadata.file_path
|
||||
# Extract original filename details
|
||||
original_filename = os.path.basename(metadata.file_path)
|
||||
base_name, extension = os.path.splitext(original_filename)
|
||||
|
||||
# Check for filename conflicts and generate unique filename if needed
|
||||
# Use the hash from metadata for conflict resolution
|
||||
def hash_provider():
|
||||
return metadata.sha256
|
||||
|
||||
unique_filename = metadata.generate_unique_filename(
|
||||
save_dir,
|
||||
base_name,
|
||||
extension,
|
||||
hash_provider=hash_provider
|
||||
)
|
||||
|
||||
# Update paths if filename changed
|
||||
if unique_filename != original_filename:
|
||||
logger.info(f"Filename conflict detected. Changing '{original_filename}' to '{unique_filename}'")
|
||||
save_path = os.path.join(save_dir, unique_filename)
|
||||
# Update metadata with new file path and name
|
||||
metadata.file_path = save_path.replace(os.sep, '/')
|
||||
metadata.file_name = os.path.splitext(unique_filename)[0]
|
||||
else:
|
||||
save_path = metadata.file_path
|
||||
|
||||
part_path = save_path + '.part'
|
||||
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
||||
|
||||
# Store file paths in active_downloads for potential cleanup
|
||||
if download_id and download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]['file_path'] = save_path
|
||||
self._active_downloads[download_id]['part_path'] = part_path
|
||||
|
||||
# Download preview image if available
|
||||
images = version_info.get('images', [])
|
||||
@@ -82,48 +438,121 @@ class DownloadManager:
|
||||
if progress_callback:
|
||||
await progress_callback(1) # 1% progress for starting preview download
|
||||
|
||||
preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png'
|
||||
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
|
||||
if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
|
||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
# Check if it's a video or an image
|
||||
is_video = images[0].get('type') == 'video'
|
||||
|
||||
if (is_video):
|
||||
# For videos, use .mp4 extension
|
||||
preview_ext = '.mp4'
|
||||
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
||||
|
||||
# Download video directly using downloader
|
||||
downloader = await get_downloader()
|
||||
success, result = await downloader.download_file(
|
||||
images[0]['url'],
|
||||
preview_path,
|
||||
use_auth=False # Preview images typically don't need auth
|
||||
)
|
||||
if success:
|
||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
||||
else:
|
||||
# For images, use WebP format for better performance
|
||||
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
|
||||
# Download the original image to temp path using downloader
|
||||
downloader = await get_downloader()
|
||||
success, content, headers = await downloader.download_to_memory(
|
||||
images[0]['url'],
|
||||
use_auth=False
|
||||
)
|
||||
if success:
|
||||
# Save to temp file
|
||||
with open(temp_path, 'wb') as f:
|
||||
f.write(content)
|
||||
# Optimize and convert to WebP
|
||||
preview_path = os.path.splitext(save_path)[0] + '.webp'
|
||||
|
||||
# Use ExifUtils to optimize and convert the image
|
||||
optimized_data, _ = ExifUtils.optimize_image(
|
||||
image_data=temp_path,
|
||||
target_width=CARD_PREVIEW_WIDTH,
|
||||
format='webp',
|
||||
quality=85,
|
||||
preserve_metadata=False
|
||||
)
|
||||
|
||||
# Save the optimized image
|
||||
with open(preview_path, 'wb') as f:
|
||||
f.write(optimized_data)
|
||||
|
||||
# Update metadata
|
||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
||||
|
||||
# Remove temporary file
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete temp file: {e}")
|
||||
|
||||
# Report preview download completion
|
||||
if progress_callback:
|
||||
await progress_callback(3) # 3% progress after preview download
|
||||
|
||||
# Download model file with progress tracking
|
||||
success, result = await self.civitai_client._download_file(
|
||||
# Download model file with progress tracking using downloader
|
||||
downloader = await get_downloader()
|
||||
# Determine if the download URL is from Civitai
|
||||
use_auth = download_url.startswith("https://civitai.com/api/download/")
|
||||
success, result = await downloader.download_file(
|
||||
download_url,
|
||||
save_dir,
|
||||
os.path.basename(save_path),
|
||||
progress_callback=lambda p: self._handle_download_progress(p, progress_callback)
|
||||
save_path, # Use full path instead of separate dir and filename
|
||||
progress_callback=lambda p: self._handle_download_progress(p, progress_callback),
|
||||
use_auth=use_auth # Only use authentication for Civitai downloads
|
||||
)
|
||||
|
||||
if not success:
|
||||
# Clean up files on failure
|
||||
for path in [save_path, metadata_path, metadata.preview_url]:
|
||||
# Clean up files on failure, but preserve .part file for resume
|
||||
cleanup_files = [metadata_path]
|
||||
if metadata.preview_url and os.path.exists(metadata.preview_url):
|
||||
cleanup_files.append(metadata.preview_url)
|
||||
|
||||
for path in cleanup_files:
|
||||
if path and os.path.exists(path):
|
||||
os.remove(path)
|
||||
try:
|
||||
os.remove(path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup file {path}: {e}")
|
||||
|
||||
# Log but don't remove .part file to allow resume
|
||||
if os.path.exists(part_path):
|
||||
logger.info(f"Preserving partial download for resume: {part_path}")
|
||||
|
||||
return {'success': False, 'error': result}
|
||||
|
||||
# 4. 更新文件信息(大小和修改时间)
|
||||
# 4. Update file information (size and modified time)
|
||||
metadata.update_file_info(save_path)
|
||||
|
||||
# 5. 最终更新元数据
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
# 5. Final metadata update
|
||||
await MetadataManager.save_metadata(save_path, metadata)
|
||||
|
||||
# 6. update lora cache
|
||||
cache = await self.file_monitor.scanner.get_cached_data()
|
||||
# 6. Update cache based on model type
|
||||
if model_type == "checkpoint":
|
||||
scanner = await self._get_checkpoint_scanner()
|
||||
logger.info(f"Updating checkpoint cache for {save_path}")
|
||||
elif model_type == "lora":
|
||||
scanner = await self._get_lora_scanner()
|
||||
logger.info(f"Updating lora cache for {save_path}")
|
||||
elif model_type == "embedding":
|
||||
scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
logger.info(f"Updating embedding cache for {save_path}")
|
||||
|
||||
# Convert metadata to dictionary
|
||||
metadata_dict = metadata.to_dict()
|
||||
metadata_dict['folder'] = relative_path
|
||||
cache.raw_data.append(metadata_dict)
|
||||
await cache.resort()
|
||||
all_folders = set(cache.folders)
|
||||
all_folders.add(relative_path)
|
||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
# Add model to cache and save to disk in a single operation
|
||||
await scanner.add_model_to_cache(metadata_dict, relative_path)
|
||||
|
||||
# Report 100% completion
|
||||
if progress_callback:
|
||||
@@ -135,10 +564,18 @@ class DownloadManager:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _execute_download: {e}", exc_info=True)
|
||||
# Clean up partial downloads
|
||||
for path in [save_path, metadata_path]:
|
||||
# Clean up partial downloads except .part file
|
||||
cleanup_files = [metadata_path]
|
||||
if hasattr(metadata, 'preview_url') and metadata.preview_url and os.path.exists(metadata.preview_url):
|
||||
cleanup_files.append(metadata.preview_url)
|
||||
|
||||
for path in cleanup_files:
|
||||
if path and os.path.exists(path):
|
||||
os.remove(path)
|
||||
try:
|
||||
os.remove(path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup file {path}: {e}")
|
||||
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
async def _handle_download_progress(self, file_progress: float, progress_callback):
|
||||
@@ -151,4 +588,99 @@ class DownloadManager:
|
||||
if progress_callback:
|
||||
# Scale file progress to 3-100 range (after preview download)
|
||||
overall_progress = 3 + (file_progress * 0.97) # 97% of progress for file download
|
||||
await progress_callback(round(overall_progress))
|
||||
await progress_callback(round(overall_progress))
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict:
|
||||
"""Cancel an active download by download_id
|
||||
|
||||
Args:
|
||||
download_id: The unique identifier of the download task
|
||||
|
||||
Returns:
|
||||
Dict: Status of the cancellation operation
|
||||
"""
|
||||
if download_id not in self._download_tasks:
|
||||
return {'success': False, 'error': 'Download task not found'}
|
||||
|
||||
try:
|
||||
# Get the task and cancel it
|
||||
task = self._download_tasks[download_id]
|
||||
task.cancel()
|
||||
|
||||
# Update status in active downloads
|
||||
if download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]['status'] = 'cancelling'
|
||||
|
||||
# Wait briefly for the task to acknowledge cancellation
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(task), timeout=2.0)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
pass
|
||||
|
||||
# Clean up ALL files including .part when user cancels
|
||||
download_info = self._active_downloads.get(download_id)
|
||||
if download_info:
|
||||
# Delete the main file
|
||||
if 'file_path' in download_info:
|
||||
file_path = download_info['file_path']
|
||||
if os.path.exists(file_path):
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
logger.debug(f"Deleted cancelled download: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting file: {e}")
|
||||
|
||||
# Delete the .part file (only on user cancellation)
|
||||
if 'part_path' in download_info:
|
||||
part_path = download_info['part_path']
|
||||
if os.path.exists(part_path):
|
||||
try:
|
||||
os.unlink(part_path)
|
||||
logger.debug(f"Deleted partial download: {part_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting part file: {e}")
|
||||
|
||||
# Delete metadata file if exists
|
||||
if 'file_path' in download_info:
|
||||
file_path = download_info['file_path']
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
if os.path.exists(metadata_path):
|
||||
try:
|
||||
os.unlink(metadata_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting metadata file: {e}")
|
||||
|
||||
# Delete preview file if exists (.webp or .mp4)
|
||||
for preview_ext in ['.webp', '.mp4']:
|
||||
preview_path = os.path.splitext(file_path)[0] + preview_ext
|
||||
if os.path.exists(preview_path):
|
||||
try:
|
||||
os.unlink(preview_path)
|
||||
logger.debug(f"Deleted preview file: {preview_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting preview file: {e}")
|
||||
|
||||
return {'success': True, 'message': 'Download cancelled successfully'}
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling download: {e}", exc_info=True)
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
async def get_active_downloads(self) -> Dict:
|
||||
"""Get information about all active downloads
|
||||
|
||||
Returns:
|
||||
Dict: List of active downloads and their status
|
||||
"""
|
||||
return {
|
||||
'downloads': [
|
||||
{
|
||||
'download_id': task_id,
|
||||
'model_id': info.get('model_id'),
|
||||
'model_version_id': info.get('model_version_id'),
|
||||
'progress': info.get('progress', 0),
|
||||
'status': info.get('status', 'unknown'),
|
||||
'error': info.get('error', None)
|
||||
}
|
||||
for task_id, info in self._active_downloads.items()
|
||||
]
|
||||
}
|
||||
539
py/services/downloader.py
Normal file
539
py/services/downloader.py
Normal file
@@ -0,0 +1,539 @@
|
||||
"""
|
||||
Unified download manager for all HTTP/HTTPS downloads in the application.
|
||||
|
||||
This module provides a centralized download service with:
|
||||
- Singleton pattern for global session management
|
||||
- Support for authenticated downloads (e.g., CivitAI API key)
|
||||
- Resumable downloads with automatic retry
|
||||
- Progress tracking and callbacks
|
||||
- Optimized connection pooling and timeouts
|
||||
- Unified error handling and logging
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Tuple, Callable, Union
|
||||
from ..services.settings_manager import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Downloader:
|
||||
"""Unified downloader for all HTTP/HTTPS downloads in the application."""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance of Downloader"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the downloader with optimal settings"""
|
||||
# Check if already initialized for singleton pattern
|
||||
if hasattr(self, '_initialized'):
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
# Session management
|
||||
self._session = None
|
||||
self._session_created_at = None
|
||||
self._proxy_url = None # Store proxy URL for current session
|
||||
|
||||
# Configuration
|
||||
self.chunk_size = 4 * 1024 * 1024 # 4MB chunks for better throughput
|
||||
self.max_retries = 5
|
||||
self.base_delay = 2.0 # Base delay for exponential backoff
|
||||
self.session_timeout = 300 # 5 minutes
|
||||
|
||||
# Default headers
|
||||
self.default_headers = {
|
||||
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
|
||||
}
|
||||
|
||||
@property
|
||||
async def session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create the global aiohttp session with optimized settings"""
|
||||
if self._session is None or self._should_refresh_session():
|
||||
await self._create_session()
|
||||
return self._session
|
||||
|
||||
@property
|
||||
def proxy_url(self) -> Optional[str]:
|
||||
"""Get the current proxy URL (initialize if needed)"""
|
||||
if not hasattr(self, '_proxy_url'):
|
||||
self._proxy_url = None
|
||||
return self._proxy_url
|
||||
|
||||
def _should_refresh_session(self) -> bool:
|
||||
"""Check if session should be refreshed"""
|
||||
if self._session is None:
|
||||
return True
|
||||
|
||||
if not hasattr(self, '_session_created_at') or self._session_created_at is None:
|
||||
return True
|
||||
|
||||
# Refresh if session is older than timeout
|
||||
if (datetime.now() - self._session_created_at).total_seconds() > self.session_timeout:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _create_session(self):
|
||||
"""Create a new aiohttp session with optimized settings"""
|
||||
# Close existing session if any
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
|
||||
# Check for app-level proxy settings
|
||||
proxy_url = None
|
||||
if settings.get('proxy_enabled', False):
|
||||
proxy_host = settings.get('proxy_host', '').strip()
|
||||
proxy_port = settings.get('proxy_port', '').strip()
|
||||
proxy_type = settings.get('proxy_type', 'http').lower()
|
||||
proxy_username = settings.get('proxy_username', '').strip()
|
||||
proxy_password = settings.get('proxy_password', '').strip()
|
||||
|
||||
if proxy_host and proxy_port:
|
||||
# Build proxy URL
|
||||
if proxy_username and proxy_password:
|
||||
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
|
||||
else:
|
||||
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
||||
|
||||
logger.debug(f"Using app-level proxy: {proxy_type}://{proxy_host}:{proxy_port}")
|
||||
logger.debug("Proxy mode: app-level proxy is active.")
|
||||
else:
|
||||
logger.debug("Proxy mode: system-level proxy (trust_env) will be used if configured in environment.")
|
||||
# Optimize TCP connection parameters
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=True,
|
||||
limit=8, # Concurrent connections
|
||||
ttl_dns_cache=300, # DNS cache timeout
|
||||
force_close=False, # Keep connections for reuse
|
||||
enable_cleanup_closed=True
|
||||
)
|
||||
|
||||
# Configure timeout parameters
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=None, # No total timeout for large downloads
|
||||
connect=60, # Connection timeout
|
||||
sock_read=300 # 5 minute socket read timeout
|
||||
)
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
trust_env=proxy_url is None, # Only use system proxy if no app-level proxy is set
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# Store proxy URL for use in requests
|
||||
self._proxy_url = proxy_url
|
||||
self._session_created_at = datetime.now()
|
||||
|
||||
logger.debug("Created new HTTP session with proxy settings. App-level proxy: %s, System-level proxy (trust_env): %s", bool(proxy_url), proxy_url is None)
|
||||
|
||||
def _get_auth_headers(self, use_auth: bool = False) -> Dict[str, str]:
|
||||
"""Get headers with optional authentication"""
|
||||
headers = self.default_headers.copy()
|
||||
|
||||
if use_auth:
|
||||
# Add CivitAI API key if available
|
||||
api_key = settings.get('civitai_api_key')
|
||||
if api_key:
|
||||
headers['Authorization'] = f'Bearer {api_key}'
|
||||
headers['Content-Type'] = 'application/json'
|
||||
|
||||
return headers
|
||||
|
||||
async def download_file(
|
||||
self,
|
||||
url: str,
|
||||
save_path: str,
|
||||
progress_callback: Optional[Callable[[float], None]] = None,
|
||||
use_auth: bool = False,
|
||||
custom_headers: Optional[Dict[str, str]] = None,
|
||||
allow_resume: bool = True
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Download a file with resumable downloads and retry mechanism
|
||||
|
||||
Args:
|
||||
url: Download URL
|
||||
save_path: Full path where the file should be saved
|
||||
progress_callback: Optional callback for progress updates (0-100)
|
||||
use_auth: Whether to include authentication headers (e.g., CivitAI API key)
|
||||
custom_headers: Additional headers to include in request
|
||||
allow_resume: Whether to support resumable downloads
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (success, save_path or error message)
|
||||
"""
|
||||
retry_count = 0
|
||||
part_path = save_path + '.part' if allow_resume else save_path
|
||||
|
||||
# Prepare headers
|
||||
headers = self._get_auth_headers(use_auth)
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
# Get existing file size for resume
|
||||
resume_offset = 0
|
||||
if allow_resume and os.path.exists(part_path):
|
||||
resume_offset = os.path.getsize(part_path)
|
||||
logger.info(f"Resuming download from offset {resume_offset} bytes")
|
||||
|
||||
total_size = 0
|
||||
|
||||
while retry_count <= self.max_retries:
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
if self.proxy_url:
|
||||
logger.debug(f"[download_file] Using app-level proxy: {self.proxy_url}")
|
||||
else:
|
||||
logger.debug("[download_file] Using system-level proxy (trust_env) if configured.")
|
||||
|
||||
# Add Range header for resume if we have partial data
|
||||
request_headers = headers.copy()
|
||||
if allow_resume and resume_offset > 0:
|
||||
request_headers['Range'] = f'bytes={resume_offset}-'
|
||||
|
||||
# Disable compression for better chunked downloads
|
||||
request_headers['Accept-Encoding'] = 'identity'
|
||||
|
||||
logger.debug(f"Download attempt {retry_count + 1}/{self.max_retries + 1} from: {url}")
|
||||
if resume_offset > 0:
|
||||
logger.debug(f"Requesting range from byte {resume_offset}")
|
||||
|
||||
async with session.get(url, headers=request_headers, allow_redirects=True, proxy=self.proxy_url) as response:
|
||||
# Handle different response codes
|
||||
if response.status == 200:
|
||||
# Full content response
|
||||
if resume_offset > 0:
|
||||
# Server doesn't support ranges, restart from beginning
|
||||
logger.warning("Server doesn't support range requests, restarting download")
|
||||
resume_offset = 0
|
||||
if os.path.exists(part_path):
|
||||
os.remove(part_path)
|
||||
elif response.status == 206:
|
||||
# Partial content response (resume successful)
|
||||
content_range = response.headers.get('Content-Range')
|
||||
if content_range:
|
||||
# Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048")
|
||||
range_parts = content_range.split('/')
|
||||
if len(range_parts) == 2:
|
||||
total_size = int(range_parts[1])
|
||||
logger.info(f"Successfully resumed download from byte {resume_offset}")
|
||||
elif response.status == 416:
|
||||
# Range not satisfiable - file might be complete or corrupted
|
||||
if allow_resume and os.path.exists(part_path):
|
||||
part_size = os.path.getsize(part_path)
|
||||
logger.warning(f"Range not satisfiable. Part file size: {part_size}")
|
||||
# Try to get actual file size
|
||||
head_response = await session.head(url, headers=headers, proxy=self.proxy_url)
|
||||
if head_response.status == 200:
|
||||
actual_size = int(head_response.headers.get('content-length', 0))
|
||||
if part_size == actual_size:
|
||||
# File is complete, just rename it
|
||||
if allow_resume:
|
||||
os.rename(part_path, save_path)
|
||||
if progress_callback:
|
||||
await progress_callback(100)
|
||||
return True, save_path
|
||||
# Remove corrupted part file and restart
|
||||
os.remove(part_path)
|
||||
resume_offset = 0
|
||||
continue
|
||||
elif response.status == 401:
|
||||
logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
|
||||
return False, "Invalid or missing API key, or early access restriction."
|
||||
elif response.status == 403:
|
||||
logger.warning(f"Forbidden access to resource: {url} (Status 403)")
|
||||
return False, "Access forbidden: You don't have permission to download this file."
|
||||
elif response.status == 404:
|
||||
logger.warning(f"Resource not found: {url} (Status 404)")
|
||||
return False, "File not found - the download link may be invalid or expired."
|
||||
else:
|
||||
logger.error(f"Download failed for {url} with status {response.status}")
|
||||
return False, f"Download failed with status {response.status}"
|
||||
|
||||
# Get total file size for progress calculation (if not set from Content-Range)
|
||||
if total_size == 0:
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
if response.status == 206:
|
||||
# For partial content, add the offset to get total file size
|
||||
total_size += resume_offset
|
||||
|
||||
current_size = resume_offset
|
||||
last_progress_report_time = datetime.now()
|
||||
|
||||
# Ensure directory exists
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
|
||||
# Stream download to file with progress updates
|
||||
loop = asyncio.get_running_loop()
|
||||
mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb'
|
||||
with open(part_path, mode) as f:
|
||||
async for chunk in response.content.iter_chunked(self.chunk_size):
|
||||
if chunk:
|
||||
# Run blocking file write in executor
|
||||
await loop.run_in_executor(None, f.write, chunk)
|
||||
current_size += len(chunk)
|
||||
|
||||
# Limit progress update frequency to reduce overhead
|
||||
now = datetime.now()
|
||||
time_diff = (now - last_progress_report_time).total_seconds()
|
||||
|
||||
if progress_callback and total_size and time_diff >= 1.0:
|
||||
progress = (current_size / total_size) * 100
|
||||
await progress_callback(progress)
|
||||
last_progress_report_time = now
|
||||
|
||||
# Download completed successfully
|
||||
# Verify file size if total_size was provided
|
||||
final_size = os.path.getsize(part_path)
|
||||
if total_size > 0 and final_size != total_size:
|
||||
logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}")
|
||||
# Don't treat this as fatal error, continue anyway
|
||||
|
||||
# Atomically rename .part to final file (only if using resume)
|
||||
if allow_resume and part_path != save_path:
|
||||
max_rename_attempts = 5
|
||||
rename_attempt = 0
|
||||
rename_success = False
|
||||
|
||||
while rename_attempt < max_rename_attempts and not rename_success:
|
||||
try:
|
||||
# If the destination file exists, remove it first (Windows safe)
|
||||
if os.path.exists(save_path):
|
||||
os.remove(save_path)
|
||||
|
||||
os.rename(part_path, save_path)
|
||||
rename_success = True
|
||||
except PermissionError as e:
|
||||
rename_attempt += 1
|
||||
if rename_attempt < max_rename_attempts:
|
||||
logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})")
|
||||
await asyncio.sleep(2)
|
||||
else:
|
||||
logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}")
|
||||
return False, f"Failed to finalize download: {str(e)}"
|
||||
|
||||
# Ensure 100% progress is reported
|
||||
if progress_callback:
|
||||
await progress_callback(100)
|
||||
|
||||
return True, save_path
|
||||
|
||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError,
|
||||
aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e:
|
||||
retry_count += 1
|
||||
logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}")
|
||||
|
||||
if retry_count <= self.max_retries:
|
||||
# Calculate delay with exponential backoff
|
||||
delay = self.base_delay * (2 ** (retry_count - 1))
|
||||
logger.info(f"Retrying in {delay} seconds...")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# Update resume offset for next attempt
|
||||
if allow_resume and os.path.exists(part_path):
|
||||
resume_offset = os.path.getsize(part_path)
|
||||
logger.info(f"Will resume from byte {resume_offset}")
|
||||
|
||||
# Refresh session to get new connection
|
||||
await self._create_session()
|
||||
continue
|
||||
else:
|
||||
logger.error(f"Max retries exceeded for download: {e}")
|
||||
return False, f"Network error after {self.max_retries + 1} attempts: {str(e)}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected download error: {e}")
|
||||
return False, str(e)
|
||||
|
||||
return False, f"Download failed after {self.max_retries + 1} attempts"
|
||||
|
||||
async def download_to_memory(
|
||||
self,
|
||||
url: str,
|
||||
use_auth: bool = False,
|
||||
custom_headers: Optional[Dict[str, str]] = None,
|
||||
return_headers: bool = False
|
||||
) -> Tuple[bool, Union[bytes, str], Optional[Dict]]:
|
||||
"""
|
||||
Download a file to memory (for small files like preview images)
|
||||
|
||||
Args:
|
||||
url: Download URL
|
||||
use_auth: Whether to include authentication headers
|
||||
custom_headers: Additional headers to include in request
|
||||
return_headers: Whether to return response headers along with content
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[bytes, str], Optional[Dict]]: (success, content or error message, response headers if requested)
|
||||
"""
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
if self.proxy_url:
|
||||
logger.debug(f"[download_to_memory] Using app-level proxy: {self.proxy_url}")
|
||||
else:
|
||||
logger.debug("[download_to_memory] Using system-level proxy (trust_env) if configured.")
|
||||
|
||||
# Prepare headers
|
||||
headers = self._get_auth_headers(use_auth)
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
async with session.get(url, headers=headers, proxy=self.proxy_url) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
if return_headers:
|
||||
return True, content, dict(response.headers)
|
||||
else:
|
||||
return True, content, None
|
||||
elif response.status == 401:
|
||||
error_msg = "Unauthorized access - invalid or missing API key"
|
||||
return False, error_msg, None
|
||||
elif response.status == 403:
|
||||
error_msg = "Access forbidden"
|
||||
return False, error_msg, None
|
||||
elif response.status == 404:
|
||||
error_msg = "File not found"
|
||||
return False, error_msg, None
|
||||
else:
|
||||
error_msg = f"Download failed with status {response.status}"
|
||||
return False, error_msg, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading to memory from {url}: {e}")
|
||||
return False, str(e), None
|
||||
|
||||
async def get_response_headers(
|
||||
self,
|
||||
url: str,
|
||||
use_auth: bool = False,
|
||||
custom_headers: Optional[Dict[str, str]] = None
|
||||
) -> Tuple[bool, Union[Dict, str]]:
|
||||
"""
|
||||
Get response headers without downloading the full content
|
||||
|
||||
Args:
|
||||
url: URL to check
|
||||
use_auth: Whether to include authentication headers
|
||||
custom_headers: Additional headers to include in request
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict, str]]: (success, headers dict or error message)
|
||||
"""
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
if self.proxy_url:
|
||||
logger.debug(f"[get_response_headers] Using app-level proxy: {self.proxy_url}")
|
||||
else:
|
||||
logger.debug("[get_response_headers] Using system-level proxy (trust_env) if configured.")
|
||||
|
||||
# Prepare headers
|
||||
headers = self._get_auth_headers(use_auth)
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
async with session.head(url, headers=headers, proxy=self.proxy_url) as response:
|
||||
if response.status == 200:
|
||||
return True, dict(response.headers)
|
||||
else:
|
||||
return False, f"Head request failed with status {response.status}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting headers from {url}: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def make_request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
use_auth: bool = False,
|
||||
custom_headers: Optional[Dict[str, str]] = None,
|
||||
**kwargs
|
||||
) -> Tuple[bool, Union[Dict, str]]:
|
||||
"""
|
||||
Make a generic HTTP request and return JSON response
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
url: Request URL
|
||||
use_auth: Whether to include authentication headers
|
||||
custom_headers: Additional headers to include in request
|
||||
**kwargs: Additional arguments for aiohttp request
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict, str]]: (success, response data or error message)
|
||||
"""
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
if self.proxy_url:
|
||||
logger.debug(f"[make_request] Using app-level proxy: {self.proxy_url}")
|
||||
else:
|
||||
logger.debug("[make_request] Using system-level proxy (trust_env) if configured.")
|
||||
|
||||
# Prepare headers
|
||||
headers = self._get_auth_headers(use_auth)
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
# Add proxy to kwargs if not already present
|
||||
if 'proxy' not in kwargs:
|
||||
kwargs['proxy'] = self.proxy_url
|
||||
|
||||
async with session.request(method, url, headers=headers, **kwargs) as response:
|
||||
if response.status == 200:
|
||||
# Try to parse as JSON, fall back to text
|
||||
try:
|
||||
data = await response.json()
|
||||
return True, data
|
||||
except:
|
||||
text = await response.text()
|
||||
return True, text
|
||||
elif response.status == 401:
|
||||
return False, "Unauthorized access - invalid or missing API key"
|
||||
elif response.status == 403:
|
||||
return False, "Access forbidden"
|
||||
elif response.status == 404:
|
||||
return False, "Resource not found"
|
||||
else:
|
||||
return False, f"Request failed with status {response.status}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error making {method} request to {url}: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP session"""
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
self._session_created_at = None
|
||||
self._proxy_url = None
|
||||
logger.debug("Closed HTTP session")
|
||||
|
||||
async def refresh_session(self):
|
||||
"""Force refresh the HTTP session (useful when proxy settings change)"""
|
||||
await self._create_session()
|
||||
logger.info("HTTP session refreshed due to settings change")
|
||||
|
||||
|
||||
# Global instance accessor
|
||||
async def get_downloader() -> Downloader:
|
||||
"""Get the global downloader instance"""
|
||||
return await Downloader.get_instance()
|
||||
26
py/services/embedding_scanner.py
Normal file
26
py/services/embedding_scanner.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from ..utils.models import EmbeddingMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EmbeddingScanner(ModelScanner):
|
||||
"""Service for scanning and managing embedding files"""
|
||||
|
||||
def __init__(self):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
||||
super().__init__(
|
||||
model_type="embedding",
|
||||
model_class=EmbeddingMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get embedding root directories"""
|
||||
return config.embeddings_roots
|
||||
49
py/services/embedding_service.py
Normal file
49
py/services/embedding_service.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import EmbeddingMetadata
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EmbeddingService(BaseModelService):
|
||||
"""Embedding-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize Embedding service
|
||||
|
||||
Args:
|
||||
scanner: Embedding scanner instance
|
||||
"""
|
||||
super().__init__("embedding", scanner, EmbeddingMetadata)
|
||||
|
||||
async def format_response(self, embedding_data: Dict) -> Dict:
|
||||
"""Format Embedding data for API response"""
|
||||
return {
|
||||
"model_name": embedding_data["model_name"],
|
||||
"file_name": embedding_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(embedding_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": embedding_data.get("preview_nsfw_level", 0),
|
||||
"base_model": embedding_data.get("base_model", ""),
|
||||
"folder": embedding_data["folder"],
|
||||
"sha256": embedding_data.get("sha256", ""),
|
||||
"file_path": embedding_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": embedding_data.get("size", 0),
|
||||
"modified": embedding_data.get("modified", ""),
|
||||
"tags": embedding_data.get("tags", []),
|
||||
"from_civitai": embedding_data.get("from_civitai", True),
|
||||
"notes": embedding_data.get("notes", ""),
|
||||
"model_type": embedding_data.get("model_type", "embedding"),
|
||||
"favorite": embedding_data.get("favorite", False),
|
||||
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find Embeddings with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find Embeddings with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
246
py/services/example_images_cleanup_service.py
Normal file
246
py/services/example_images_cleanup_service.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""Service for cleaning up example image folders."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from .service_registry import ServiceRegistry
|
||||
from .settings_manager import settings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CleanupResult:
|
||||
"""Structured result returned from cleanup operations."""
|
||||
|
||||
success: bool
|
||||
checked_folders: int
|
||||
moved_empty_folders: int
|
||||
moved_orphaned_folders: int
|
||||
skipped_non_hash: int
|
||||
move_failures: int
|
||||
errors: List[str]
|
||||
deleted_root: str | None
|
||||
partial_success: bool
|
||||
|
||||
def to_dict(self) -> Dict[str, object]:
|
||||
"""Convert the dataclass to a serialisable dictionary."""
|
||||
|
||||
data = {
|
||||
"success": self.success,
|
||||
"checked_folders": self.checked_folders,
|
||||
"moved_empty_folders": self.moved_empty_folders,
|
||||
"moved_orphaned_folders": self.moved_orphaned_folders,
|
||||
"moved_total": self.moved_empty_folders + self.moved_orphaned_folders,
|
||||
"skipped_non_hash": self.skipped_non_hash,
|
||||
"move_failures": self.move_failures,
|
||||
"errors": self.errors,
|
||||
"deleted_root": self.deleted_root,
|
||||
"partial_success": self.partial_success,
|
||||
}
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class ExampleImagesCleanupService:
|
||||
"""Encapsulates logic for cleaning example image folders."""
|
||||
|
||||
DELETED_FOLDER_NAME = "_deleted"
|
||||
|
||||
def __init__(self, deleted_folder_name: str | None = None) -> None:
|
||||
self._deleted_folder_name = deleted_folder_name or self.DELETED_FOLDER_NAME
|
||||
|
||||
async def cleanup_example_image_folders(self) -> Dict[str, object]:
|
||||
"""Clean empty or orphaned example image folders by moving them under a deleted bucket."""
|
||||
|
||||
example_images_path = settings.get("example_images_path")
|
||||
if not example_images_path:
|
||||
logger.debug("Cleanup skipped: example images path not configured")
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Example images path is not configured.",
|
||||
"error_code": "path_not_configured",
|
||||
}
|
||||
|
||||
example_root = Path(example_images_path)
|
||||
if not example_root.exists():
|
||||
logger.debug("Cleanup skipped: example images path missing -> %s", example_root)
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Example images path does not exist.",
|
||||
"error_code": "path_not_found",
|
||||
}
|
||||
|
||||
try:
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.error("Failed to acquire scanners for cleanup: %s", exc, exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Failed to load model scanners: {exc}",
|
||||
"error_code": "scanner_initialization_failed",
|
||||
}
|
||||
|
||||
deleted_bucket = example_root / self._deleted_folder_name
|
||||
deleted_bucket.mkdir(exist_ok=True)
|
||||
|
||||
checked_folders = 0
|
||||
moved_empty = 0
|
||||
moved_orphaned = 0
|
||||
skipped_non_hash = 0
|
||||
move_failures = 0
|
||||
errors: List[str] = []
|
||||
|
||||
for entry in os.scandir(example_root):
|
||||
if not entry.is_dir(follow_symlinks=False):
|
||||
continue
|
||||
|
||||
if entry.name == self._deleted_folder_name:
|
||||
continue
|
||||
|
||||
checked_folders += 1
|
||||
folder_path = Path(entry.path)
|
||||
|
||||
try:
|
||||
if self._is_folder_empty(folder_path):
|
||||
if await self._remove_empty_folder(folder_path):
|
||||
moved_empty += 1
|
||||
else:
|
||||
move_failures += 1
|
||||
continue
|
||||
|
||||
if not self._is_hash_folder(entry.name):
|
||||
skipped_non_hash += 1
|
||||
continue
|
||||
|
||||
hash_exists = (
|
||||
lora_scanner.has_hash(entry.name)
|
||||
or checkpoint_scanner.has_hash(entry.name)
|
||||
or embedding_scanner.has_hash(entry.name)
|
||||
)
|
||||
|
||||
if not hash_exists:
|
||||
if await self._move_folder(folder_path, deleted_bucket):
|
||||
moved_orphaned += 1
|
||||
else:
|
||||
move_failures += 1
|
||||
|
||||
except Exception as exc: # pragma: no cover - filesystem guard
|
||||
move_failures += 1
|
||||
error_message = f"{entry.name}: {exc}"
|
||||
errors.append(error_message)
|
||||
logger.error("Error processing example images folder %s: %s", folder_path, exc, exc_info=True)
|
||||
|
||||
partial_success = move_failures > 0 and (moved_empty > 0 or moved_orphaned > 0)
|
||||
success = move_failures == 0 and not errors
|
||||
|
||||
result = CleanupResult(
|
||||
success=success,
|
||||
checked_folders=checked_folders,
|
||||
moved_empty_folders=moved_empty,
|
||||
moved_orphaned_folders=moved_orphaned,
|
||||
skipped_non_hash=skipped_non_hash,
|
||||
move_failures=move_failures,
|
||||
errors=errors,
|
||||
deleted_root=str(deleted_bucket),
|
||||
partial_success=partial_success,
|
||||
)
|
||||
|
||||
summary = result.to_dict()
|
||||
if success:
|
||||
logger.info(
|
||||
"Example images cleanup complete: checked=%s, moved_empty=%s, moved_orphaned=%s",
|
||||
checked_folders,
|
||||
moved_empty,
|
||||
moved_orphaned,
|
||||
)
|
||||
elif partial_success:
|
||||
logger.warning(
|
||||
"Example images cleanup partially complete: moved=%s, failures=%s",
|
||||
summary["moved_total"],
|
||||
move_failures,
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
"Example images cleanup failed: move_failures=%s, errors=%s",
|
||||
move_failures,
|
||||
errors,
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
@staticmethod
|
||||
def _is_folder_empty(folder_path: Path) -> bool:
|
||||
try:
|
||||
with os.scandir(folder_path) as iterator:
|
||||
return not any(iterator)
|
||||
except FileNotFoundError:
|
||||
return True
|
||||
except OSError as exc: # pragma: no cover - defensive guard
|
||||
logger.debug("Failed to inspect folder %s: %s", folder_path, exc)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_hash_folder(name: str) -> bool:
|
||||
if len(name) != 64:
|
||||
return False
|
||||
hex_chars = set("0123456789abcdefABCDEF")
|
||||
return all(char in hex_chars for char in name)
|
||||
|
||||
async def _remove_empty_folder(self, folder_path: Path) -> bool:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
try:
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
shutil.rmtree,
|
||||
str(folder_path),
|
||||
)
|
||||
logger.debug("Removed empty example images folder %s", folder_path)
|
||||
return True
|
||||
except Exception as exc: # pragma: no cover - filesystem guard
|
||||
logger.error("Failed to remove empty example images folder %s: %s", folder_path, exc, exc_info=True)
|
||||
return False
|
||||
|
||||
async def _move_folder(self, folder_path: Path, deleted_bucket: Path) -> bool:
|
||||
destination = self._build_destination(folder_path.name, deleted_bucket)
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
try:
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
shutil.move,
|
||||
str(folder_path),
|
||||
str(destination),
|
||||
)
|
||||
logger.debug("Moved example images folder %s -> %s", folder_path, destination)
|
||||
return True
|
||||
except Exception as exc: # pragma: no cover - filesystem guard
|
||||
logger.error(
|
||||
"Failed to move example images folder %s to %s: %s",
|
||||
folder_path,
|
||||
destination,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
return False
|
||||
|
||||
def _build_destination(self, folder_name: str, deleted_bucket: Path) -> Path:
|
||||
destination = deleted_bucket / folder_name
|
||||
suffix = 1
|
||||
|
||||
while destination.exists():
|
||||
destination = deleted_bucket / f"{folder_name}_{suffix}"
|
||||
suffix += 1
|
||||
|
||||
return destination
|
||||
@@ -1,204 +0,0 @@
|
||||
from operator import itemgetter
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler, FileCreatedEvent, FileDeletedEvent
|
||||
from typing import List
|
||||
from threading import Lock
|
||||
from .lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraFileHandler(FileSystemEventHandler):
|
||||
"""Handler for LoRA file system events"""
|
||||
|
||||
def __init__(self, scanner: LoraScanner, loop: asyncio.AbstractEventLoop):
|
||||
self.scanner = scanner
|
||||
self.loop = loop # 存储事件循环引用
|
||||
self.pending_changes = set() # 待处理的变更
|
||||
self.lock = Lock() # 线程安全锁
|
||||
self.update_task = None # 异步更新任务
|
||||
self._ignore_paths = set() # Add ignore paths set
|
||||
self._min_ignore_timeout = 5 # minimum timeout in seconds
|
||||
self._download_speed = 1024 * 1024 # assume 1MB/s as base speed
|
||||
|
||||
def _should_ignore(self, path: str) -> bool:
|
||||
"""Check if path should be ignored"""
|
||||
real_path = os.path.realpath(path) # Resolve any symbolic links
|
||||
return real_path.replace(os.sep, '/') in self._ignore_paths
|
||||
|
||||
def add_ignore_path(self, path: str, file_size: int = 0):
|
||||
"""Add path to ignore list with dynamic timeout based on file size"""
|
||||
real_path = os.path.realpath(path) # Resolve any symbolic links
|
||||
self._ignore_paths.add(real_path.replace(os.sep, '/'))
|
||||
|
||||
# Short timeout (e.g. 5 seconds) is sufficient to ignore the CREATE event
|
||||
timeout = 5
|
||||
|
||||
asyncio.get_event_loop().call_later(
|
||||
timeout,
|
||||
self._ignore_paths.discard,
|
||||
real_path.replace(os.sep, '/')
|
||||
)
|
||||
|
||||
def on_created(self, event):
|
||||
if event.is_directory or not event.src_path.endswith('.safetensors'):
|
||||
return
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
logger.info(f"LoRA file created: {event.src_path}")
|
||||
self._schedule_update('add', event.src_path)
|
||||
|
||||
def on_deleted(self, event):
|
||||
if event.is_directory or not event.src_path.endswith('.safetensors'):
|
||||
return
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
logger.info(f"LoRA file deleted: {event.src_path}")
|
||||
self._schedule_update('remove', event.src_path)
|
||||
|
||||
def _schedule_update(self, action: str, file_path: str): #file_path is a real path
|
||||
"""Schedule a cache update"""
|
||||
with self.lock:
|
||||
# 使用 config 中的方法映射路径
|
||||
mapped_path = config.map_path_to_link(file_path)
|
||||
normalized_path = mapped_path.replace(os.sep, '/')
|
||||
self.pending_changes.add((action, normalized_path))
|
||||
|
||||
self.loop.call_soon_threadsafe(self._create_update_task)
|
||||
|
||||
def _create_update_task(self):
|
||||
"""Create update task in the event loop"""
|
||||
if self.update_task is None or self.update_task.done():
|
||||
self.update_task = asyncio.create_task(self._process_changes())
|
||||
|
||||
async def _process_changes(self, delay: float = 2.0):
|
||||
"""Process pending changes with debouncing"""
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
try:
|
||||
with self.lock:
|
||||
changes = self.pending_changes.copy()
|
||||
self.pending_changes.clear()
|
||||
|
||||
if not changes:
|
||||
return
|
||||
|
||||
logger.info(f"Processing {len(changes)} file changes")
|
||||
|
||||
cache = await self.scanner.get_cached_data()
|
||||
needs_resort = False
|
||||
new_folders = set()
|
||||
|
||||
for action, file_path in changes:
|
||||
try:
|
||||
if action == 'add':
|
||||
# Scan new file
|
||||
lora_data = await self.scanner.scan_single_lora(file_path)
|
||||
if lora_data:
|
||||
# Update tags count
|
||||
for tag in lora_data.get('tags', []):
|
||||
self.scanner._tags_count[tag] = self.scanner._tags_count.get(tag, 0) + 1
|
||||
|
||||
cache.raw_data.append(lora_data)
|
||||
new_folders.add(lora_data['folder'])
|
||||
# Update hash index
|
||||
if 'sha256' in lora_data:
|
||||
self.scanner._hash_index.add_entry(
|
||||
lora_data['sha256'],
|
||||
lora_data['file_path']
|
||||
)
|
||||
needs_resort = True
|
||||
|
||||
elif action == 'remove':
|
||||
# Find the lora to remove so we can update tags count
|
||||
lora_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
||||
if lora_to_remove:
|
||||
# Update tags count by reducing counts
|
||||
for tag in lora_to_remove.get('tags', []):
|
||||
if tag in self.scanner._tags_count:
|
||||
self.scanner._tags_count[tag] = max(0, self.scanner._tags_count[tag] - 1)
|
||||
if self.scanner._tags_count[tag] == 0:
|
||||
del self.scanner._tags_count[tag]
|
||||
|
||||
# Remove from cache and hash index
|
||||
logger.info(f"Removing {file_path} from cache")
|
||||
self.scanner._hash_index.remove_by_path(file_path)
|
||||
cache.raw_data = [
|
||||
item for item in cache.raw_data
|
||||
if item['file_path'] != file_path
|
||||
]
|
||||
needs_resort = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {action} for {file_path}: {e}")
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort()
|
||||
|
||||
# Update folder list
|
||||
all_folders = set(cache.folders) | new_folders
|
||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in process_changes: {e}")
|
||||
|
||||
|
||||
class LoraFileMonitor:
|
||||
"""Monitor for LoRA file changes"""
|
||||
|
||||
def __init__(self, scanner: LoraScanner, roots: List[str]):
|
||||
self.scanner = scanner
|
||||
scanner.set_file_monitor(self)
|
||||
self.observer = Observer()
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.handler = LoraFileHandler(scanner, self.loop)
|
||||
|
||||
# 使用已存在的路径映射
|
||||
self.monitor_paths = set()
|
||||
for root in roots:
|
||||
self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/'))
|
||||
|
||||
# 添加所有已映射的目标路径
|
||||
for target_path in config._path_mappings.keys():
|
||||
self.monitor_paths.add(target_path)
|
||||
|
||||
def start(self):
|
||||
"""Start monitoring"""
|
||||
for path_info in self.monitor_paths:
|
||||
try:
|
||||
if isinstance(path_info, tuple):
|
||||
# 对于链接,监控目标路径
|
||||
_, target_path = path_info
|
||||
self.observer.schedule(self.handler, target_path, recursive=True)
|
||||
logger.info(f"Started monitoring target path: {target_path}")
|
||||
else:
|
||||
# 对于普通路径,直接监控
|
||||
self.observer.schedule(self.handler, path_info, recursive=True)
|
||||
logger.info(f"Started monitoring: {path_info}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error monitoring {path_info}: {e}")
|
||||
|
||||
self.observer.start()
|
||||
|
||||
def stop(self):
|
||||
"""Stop monitoring"""
|
||||
self.observer.stop()
|
||||
self.observer.join()
|
||||
|
||||
def rescan_links(self):
|
||||
"""重新扫描链接(当添加新的链接时调用)"""
|
||||
new_paths = set()
|
||||
for path in self.monitor_paths.copy():
|
||||
self._add_link_targets(path)
|
||||
|
||||
# 添加新发现的路径到监控
|
||||
new_paths = self.monitor_paths - set(self.observer.watches.keys())
|
||||
for path in new_paths:
|
||||
try:
|
||||
self.observer.schedule(self.handler, path, recursive=True)
|
||||
logger.info(f"Added new monitoring path: {path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding new monitor for {path}: {e}")
|
||||
@@ -1,64 +0,0 @@
|
||||
import asyncio
|
||||
from typing import List, Dict
|
||||
from dataclasses import dataclass
|
||||
from operator import itemgetter
|
||||
|
||||
@dataclass
|
||||
class LoraCache:
|
||||
"""Cache structure for LoRA data"""
|
||||
raw_data: List[Dict]
|
||||
sorted_by_name: List[Dict]
|
||||
sorted_by_date: List[Dict]
|
||||
folders: List[str]
|
||||
|
||||
def __post_init__(self):
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def resort(self, name_only: bool = False):
|
||||
"""Resort all cached data views"""
|
||||
async with self._lock:
|
||||
self.sorted_by_name = sorted(
|
||||
self.raw_data,
|
||||
key=lambda x: x['model_name'].lower() # Case-insensitive sort
|
||||
)
|
||||
if not name_only:
|
||||
self.sorted_by_date = sorted(
|
||||
self.raw_data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=True
|
||||
)
|
||||
# Update folder list
|
||||
all_folders = set(l['folder'] for l in self.raw_data)
|
||||
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
|
||||
"""Update preview_url for a specific lora in all cached data
|
||||
|
||||
Args:
|
||||
file_path: The file path of the lora to update
|
||||
preview_url: The new preview URL
|
||||
|
||||
Returns:
|
||||
bool: True if the update was successful, False if the lora wasn't found
|
||||
"""
|
||||
async with self._lock:
|
||||
# Update in raw_data
|
||||
for item in self.raw_data:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
else:
|
||||
return False # Lora not found
|
||||
|
||||
# Update in sorted lists (references to the same dict objects)
|
||||
for item in self.sorted_by_name:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
|
||||
for item in self.sorted_by_date:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
|
||||
return True
|
||||
@@ -1,48 +0,0 @@
|
||||
from typing import Dict, Optional
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class LoraHashIndex:
|
||||
"""Index for mapping LoRA file hashes to their file paths"""
|
||||
|
||||
def __init__(self):
|
||||
self._hash_to_path: Dict[str, str] = {}
|
||||
|
||||
def add_entry(self, sha256: str, file_path: str) -> None:
|
||||
"""Add or update a hash -> path mapping"""
|
||||
if not sha256 or not file_path:
|
||||
return
|
||||
self._hash_to_path[sha256] = file_path
|
||||
|
||||
def remove_entry(self, sha256: str) -> None:
|
||||
"""Remove a hash entry"""
|
||||
self._hash_to_path.pop(sha256, None)
|
||||
|
||||
def remove_by_path(self, file_path: str) -> None:
|
||||
"""Remove entry by file path"""
|
||||
for sha256, path in list(self._hash_to_path.items()):
|
||||
if path == file_path:
|
||||
del self._hash_to_path[sha256]
|
||||
break
|
||||
|
||||
def get_path(self, sha256: str) -> Optional[str]:
|
||||
"""Get file path for a given hash"""
|
||||
return self._hash_to_path.get(sha256)
|
||||
|
||||
def get_hash(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a given file path"""
|
||||
for sha256, path in self._hash_to_path.items():
|
||||
if path == file_path:
|
||||
return sha256
|
||||
return None
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
"""Check if hash exists in index"""
|
||||
return sha256 in self._hash_to_path
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all entries"""
|
||||
self._hash_to_path.clear()
|
||||
@@ -1,656 +1,65 @@
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
import shutil
|
||||
from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
from operator import itemgetter
|
||||
from typing import List
|
||||
|
||||
from ..utils.models import LoraMetadata
|
||||
from ..config import config
|
||||
from ..utils.file_utils import load_metadata, get_file_info
|
||||
from .lora_cache import LoraCache
|
||||
from difflib import SequenceMatcher
|
||||
from .lora_hash_index import LoraHashIndex
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraScanner:
|
||||
class LoraScanner(ModelScanner):
|
||||
"""Service for scanning and managing LoRA files"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
# 确保初始化只执行一次
|
||||
if not hasattr(self, '_initialized'):
|
||||
self._cache: Optional[LoraCache] = None
|
||||
self._hash_index = LoraHashIndex()
|
||||
self._initialization_lock = asyncio.Lock()
|
||||
self._initialization_task: Optional[asyncio.Task] = None
|
||||
self._initialized = True
|
||||
self.file_monitor = None # Add this line
|
||||
self._tags_count = {} # Add a dictionary to store tag counts
|
||||
|
||||
def set_file_monitor(self, monitor):
|
||||
"""Set file monitor instance"""
|
||||
self.file_monitor = monitor
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
|
||||
"""Get cached LoRA data, refresh if needed"""
|
||||
async with self._initialization_lock:
|
||||
|
||||
# 如果缓存未初始化但需要响应请求,返回空缓存
|
||||
if self._cache is None and not force_refresh:
|
||||
return LoraCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
# 如果正在初始化,等待完成
|
||||
if self._initialization_task and not self._initialization_task.done():
|
||||
try:
|
||||
await self._initialization_task
|
||||
except Exception as e:
|
||||
logger.error(f"Cache initialization failed: {e}")
|
||||
self._initialization_task = None
|
||||
|
||||
if (self._cache is None or force_refresh):
|
||||
|
||||
# 创建新的初始化任务
|
||||
if not self._initialization_task or self._initialization_task.done():
|
||||
self._initialization_task = asyncio.create_task(self._initialize_cache())
|
||||
|
||||
try:
|
||||
await self._initialization_task
|
||||
except Exception as e:
|
||||
logger.error(f"Cache initialization failed: {e}")
|
||||
# 如果缓存已存在,继续使用旧缓存
|
||||
if self._cache is None:
|
||||
raise # 如果没有缓存,则抛出异常
|
||||
|
||||
return self._cache
|
||||
|
||||
async def _initialize_cache(self) -> None:
|
||||
"""Initialize or refresh the cache"""
|
||||
try:
|
||||
# Clear existing hash index
|
||||
self._hash_index.clear()
|
||||
|
||||
# Clear existing tags count
|
||||
self._tags_count = {}
|
||||
|
||||
# Scan for new data
|
||||
raw_data = await self.scan_all_loras()
|
||||
|
||||
# Build hash index and tags count
|
||||
for lora_data in raw_data:
|
||||
if 'sha256' in lora_data and 'file_path' in lora_data:
|
||||
self._hash_index.add_entry(lora_data['sha256'], lora_data['file_path'])
|
||||
|
||||
# Count tags
|
||||
if 'tags' in lora_data and lora_data['tags']:
|
||||
for tag in lora_data['tags']:
|
||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||
|
||||
# Update cache
|
||||
self._cache = LoraCache(
|
||||
raw_data=raw_data,
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
# Call resort_cache to create sorted views
|
||||
await self._cache.resort()
|
||||
|
||||
self._initialization_task = None
|
||||
logger.info("LoRA Manager: Cache initialization completed")
|
||||
except Exception as e:
|
||||
logger.error(f"LoRA Manager: Error initializing cache: {e}")
|
||||
self._cache = LoraCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
def fuzzy_match(self, text: str, pattern: str, threshold: float = 0.7) -> bool:
|
||||
"""
|
||||
Check if text matches pattern using fuzzy matching.
|
||||
Returns True if similarity ratio is above threshold.
|
||||
"""
|
||||
if not pattern or not text:
|
||||
return False
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors'}
|
||||
|
||||
# Convert both to lowercase for case-insensitive matching
|
||||
text = text.lower()
|
||||
pattern = pattern.lower()
|
||||
|
||||
# Split pattern into words
|
||||
search_words = pattern.split()
|
||||
|
||||
# Check each word
|
||||
for word in search_words:
|
||||
# First check if word is a substring (faster)
|
||||
if word in text:
|
||||
continue
|
||||
|
||||
# If not found as substring, try fuzzy matching
|
||||
# Check if any part of the text matches this word
|
||||
found_match = False
|
||||
for text_part in text.split():
|
||||
ratio = SequenceMatcher(None, text_part, word).ratio()
|
||||
if ratio >= threshold:
|
||||
found_match = True
|
||||
break
|
||||
|
||||
if not found_match:
|
||||
return False
|
||||
|
||||
# All words found either as substrings or fuzzy matches
|
||||
return True
|
||||
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
||||
folder: str = None, search: str = None, fuzzy: bool = False,
|
||||
recursive: bool = False, base_models: list = None, tags: list = None,
|
||||
search_options: dict = None) -> Dict:
|
||||
"""Get paginated and filtered lora data
|
||||
|
||||
Args:
|
||||
page: Current page number (1-based)
|
||||
page_size: Number of items per page
|
||||
sort_by: Sort method ('name' or 'date')
|
||||
folder: Filter by folder path
|
||||
search: Search term
|
||||
fuzzy: Use fuzzy matching for search
|
||||
recursive: Include subfolders when folder filter is applied
|
||||
base_models: List of base models to filter by
|
||||
tags: List of tags to filter by
|
||||
search_options: Dictionary with search options (filename, modelname, tags)
|
||||
"""
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# Get default search options if not provided
|
||||
if search_options is None:
|
||||
search_options = {
|
||||
'filename': True,
|
||||
'modelname': True,
|
||||
'tags': False
|
||||
}
|
||||
|
||||
# Get the base data set
|
||||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
||||
|
||||
# Apply folder filtering
|
||||
if folder is not None:
|
||||
if recursive:
|
||||
# Recursive mode: match all paths starting with this folder
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if item['folder'].startswith(folder + '/') or item['folder'] == folder
|
||||
]
|
||||
else:
|
||||
# Non-recursive mode: match exact folder
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if item['folder'] == folder
|
||||
]
|
||||
|
||||
# Apply base model filtering
|
||||
if base_models and len(base_models) > 0:
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if item.get('base_model') in base_models
|
||||
]
|
||||
|
||||
# Apply tag filtering
|
||||
if tags and len(tags) > 0:
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if any(tag in item.get('tags', []) for tag in tags)
|
||||
]
|
||||
|
||||
# Apply search filtering
|
||||
if search:
|
||||
if fuzzy:
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if self._fuzzy_search_match(item, search, search_options)
|
||||
]
|
||||
else:
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if self._exact_search_match(item, search, search_options)
|
||||
]
|
||||
|
||||
# Calculate pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def _fuzzy_search_match(self, item: Dict, search: str, search_options: Dict) -> bool:
|
||||
"""Check if an item matches the search term using fuzzy matching with search options"""
|
||||
# Check filename if enabled
|
||||
if search_options.get('filename', True) and self.fuzzy_match(item.get('file_name', ''), search):
|
||||
return True
|
||||
|
||||
# Check model name if enabled
|
||||
if search_options.get('modelname', True) and self.fuzzy_match(item.get('model_name', ''), search):
|
||||
return True
|
||||
|
||||
# Check tags if enabled
|
||||
if search_options.get('tags', False) and item.get('tags'):
|
||||
for tag in item['tags']:
|
||||
if self.fuzzy_match(tag, search):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _exact_search_match(self, item: Dict, search: str, search_options: Dict) -> bool:
|
||||
"""Check if an item matches the search term using exact matching with search options"""
|
||||
search = search.lower()
|
||||
|
||||
# Check filename if enabled
|
||||
if search_options.get('filename', True) and search in item.get('file_name', '').lower():
|
||||
return True
|
||||
|
||||
# Check model name if enabled
|
||||
if search_options.get('modelname', True) and search in item.get('model_name', '').lower():
|
||||
return True
|
||||
|
||||
# Check tags if enabled
|
||||
if search_options.get('tags', False) and item.get('tags'):
|
||||
for tag in item['tags']:
|
||||
if search in tag.lower():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def invalidate_cache(self):
|
||||
"""Invalidate the current cache"""
|
||||
self._cache = None
|
||||
|
||||
async def scan_all_loras(self) -> List[Dict]:
|
||||
"""Scan all LoRA directories and return metadata"""
|
||||
all_loras = []
|
||||
|
||||
# 分目录异步扫描
|
||||
scan_tasks = []
|
||||
for loras_root in config.loras_roots:
|
||||
task = asyncio.create_task(self._scan_directory(loras_root))
|
||||
scan_tasks.append(task)
|
||||
|
||||
for task in scan_tasks:
|
||||
try:
|
||||
loras = await task
|
||||
all_loras.extend(loras)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning directory: {e}")
|
||||
|
||||
return all_loras
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Scan a single directory for LoRA files"""
|
||||
loras = []
|
||||
original_root = root_path # 保存原始根路径
|
||||
|
||||
async def scan_recursive(path: str, visited_paths: set):
|
||||
"""递归扫描目录,避免循环链接"""
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
logger.debug(f"Skipping already visited path: {path}")
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True) and entry.name.endswith('.safetensors'):
|
||||
# 使用原始路径而不是真实路径
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
await self._process_single_file(file_path, original_root, loras)
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
# 对于目录,使用原始路径继续扫描
|
||||
await scan_recursive(entry.path, visited_paths)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
await scan_recursive(root_path, set())
|
||||
return loras
|
||||
|
||||
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
|
||||
"""处理单个文件并添加到结果列表"""
|
||||
try:
|
||||
result = await self._process_lora_file(file_path, root_path)
|
||||
if result:
|
||||
loras.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
|
||||
async def _process_lora_file(self, file_path: str, root_path: str) -> Dict:
|
||||
"""Process a single LoRA file and return its metadata"""
|
||||
# Try loading existing metadata
|
||||
metadata = await load_metadata(file_path)
|
||||
|
||||
if metadata is None:
|
||||
# Create new metadata if none exists
|
||||
metadata = await get_file_info(file_path)
|
||||
|
||||
# Convert to dict and add folder info
|
||||
lora_data = metadata.to_dict()
|
||||
# Try to fetch missing metadata from Civitai if needed
|
||||
await self._fetch_missing_metadata(file_path, lora_data)
|
||||
rel_path = os.path.relpath(file_path, root_path)
|
||||
folder = os.path.dirname(rel_path)
|
||||
lora_data['folder'] = folder.replace(os.path.sep, '/')
|
||||
|
||||
return lora_data
|
||||
|
||||
async def _fetch_missing_metadata(self, file_path: str, lora_data: Dict) -> None:
|
||||
"""Fetch missing description and tags from Civitai if needed
|
||||
|
||||
Args:
|
||||
file_path: Path to the lora file
|
||||
lora_data: Lora metadata dictionary to update
|
||||
"""
|
||||
try:
|
||||
# Check if we need to fetch additional metadata from Civitai
|
||||
needs_metadata_update = False
|
||||
model_id = None
|
||||
|
||||
# Check if we have Civitai model ID but missing metadata
|
||||
if lora_data.get('civitai'):
|
||||
# Try to get model ID directly from the correct location
|
||||
model_id = lora_data['civitai'].get('modelId')
|
||||
|
||||
if model_id:
|
||||
model_id = str(model_id)
|
||||
# Check if tags are missing or empty
|
||||
tags_missing = not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0
|
||||
|
||||
# Check if description is missing or empty
|
||||
desc_missing = not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")
|
||||
|
||||
needs_metadata_update = tags_missing or desc_missing
|
||||
|
||||
# Fetch missing metadata if needed
|
||||
if needs_metadata_update and model_id:
|
||||
logger.info(f"Fetching missing metadata for {file_path} with model ID {model_id}")
|
||||
from ..services.civitai_client import CivitaiClient
|
||||
client = CivitaiClient()
|
||||
model_metadata = await client.get_model_metadata(model_id)
|
||||
await client.close()
|
||||
|
||||
if (model_metadata):
|
||||
logger.info(f"Updating metadata for {file_path} with model ID {model_id}")
|
||||
|
||||
# Update tags if they were missing
|
||||
if model_metadata.get('tags') and (not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0):
|
||||
lora_data['tags'] = model_metadata['tags']
|
||||
|
||||
# Update description if it was missing
|
||||
if model_metadata.get('description') and (not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")):
|
||||
lora_data['modelDescription'] = model_metadata['description']
|
||||
|
||||
# Save the updated metadata back to file
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(lora_data, f, indent=2, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
|
||||
|
||||
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
|
||||
"""Update preview URL in cache for a specific lora
|
||||
|
||||
Args:
|
||||
file_path: The file path of the lora to update
|
||||
preview_url: The new preview URL
|
||||
|
||||
Returns:
|
||||
bool: True if the update was successful, False if cache doesn't exist or lora wasn't found
|
||||
"""
|
||||
if self._cache is None:
|
||||
return False
|
||||
|
||||
return await self._cache.update_preview_url(file_path, preview_url)
|
||||
|
||||
async def scan_single_lora(self, file_path: str) -> Optional[Dict]:
|
||||
"""Scan a single LoRA file and return its metadata"""
|
||||
try:
|
||||
if not os.path.exists(os.path.realpath(file_path)):
|
||||
return None
|
||||
|
||||
# 获取基本文件信息
|
||||
metadata = await get_file_info(file_path)
|
||||
if not metadata:
|
||||
return None
|
||||
|
||||
folder = self._calculate_folder(file_path)
|
||||
|
||||
# 确保 folder 字段存在
|
||||
metadata_dict = metadata.to_dict()
|
||||
metadata_dict['folder'] = folder or ''
|
||||
|
||||
return metadata_dict
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {file_path}: {e}")
|
||||
return None
|
||||
|
||||
def _calculate_folder(self, file_path: str) -> str:
|
||||
"""Calculate the folder path for a LoRA file"""
|
||||
# 使用原始路径计算相对路径
|
||||
for root in config.loras_roots:
|
||||
if file_path.startswith(root):
|
||||
rel_path = os.path.relpath(file_path, root)
|
||||
return os.path.dirname(rel_path).replace(os.path.sep, '/')
|
||||
return ''
|
||||
|
||||
async def move_model(self, source_path: str, target_path: str) -> bool:
|
||||
"""Move a model and its associated files to a new location"""
|
||||
try:
|
||||
# 保持原始路径格式
|
||||
source_path = source_path.replace(os.sep, '/')
|
||||
target_path = target_path.replace(os.sep, '/')
|
||||
|
||||
# 其余代码保持不变
|
||||
base_name = os.path.splitext(os.path.basename(source_path))[0]
|
||||
source_dir = os.path.dirname(source_path)
|
||||
|
||||
os.makedirs(target_path, exist_ok=True)
|
||||
|
||||
target_lora = os.path.join(target_path, f"{base_name}.safetensors").replace(os.sep, '/')
|
||||
|
||||
# 使用真实路径进行文件操作
|
||||
real_source = os.path.realpath(source_path)
|
||||
real_target = os.path.realpath(target_lora)
|
||||
|
||||
file_size = os.path.getsize(real_source)
|
||||
|
||||
if self.file_monitor:
|
||||
self.file_monitor.handler.add_ignore_path(
|
||||
real_source,
|
||||
file_size
|
||||
)
|
||||
self.file_monitor.handler.add_ignore_path(
|
||||
real_target,
|
||||
file_size
|
||||
)
|
||||
|
||||
# 使用真实路径进行文件操作
|
||||
shutil.move(real_source, real_target)
|
||||
|
||||
# Move associated files
|
||||
source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
|
||||
if os.path.exists(source_metadata):
|
||||
target_metadata = os.path.join(target_path, f"{base_name}.metadata.json")
|
||||
shutil.move(source_metadata, target_metadata)
|
||||
metadata = await self._update_metadata_paths(target_metadata, target_lora)
|
||||
|
||||
# Move preview file if exists
|
||||
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
|
||||
'.png', '.jpeg', '.jpg', '.mp4']
|
||||
for ext in preview_extensions:
|
||||
source_preview = os.path.join(source_dir, f"{base_name}{ext}")
|
||||
if os.path.exists(source_preview):
|
||||
target_preview = os.path.join(target_path, f"{base_name}{ext}")
|
||||
shutil.move(source_preview, target_preview)
|
||||
break
|
||||
|
||||
# Update cache
|
||||
await self.update_single_lora_cache(source_path, target_lora, metadata)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving model: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def update_single_lora_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool:
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# Find the existing item to remove its tags from count
|
||||
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
|
||||
if existing_item and 'tags' in existing_item:
|
||||
for tag in existing_item.get('tags', []):
|
||||
if tag in self._tags_count:
|
||||
self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
|
||||
if self._tags_count[tag] == 0:
|
||||
del self._tags_count[tag]
|
||||
|
||||
# Remove old path from hash index if exists
|
||||
self._hash_index.remove_by_path(original_path)
|
||||
|
||||
# Remove the old entry from raw_data
|
||||
cache.raw_data = [
|
||||
item for item in cache.raw_data
|
||||
if item['file_path'] != original_path
|
||||
]
|
||||
|
||||
if metadata:
|
||||
# If this is an update to an existing path (not a move), ensure folder is preserved
|
||||
if original_path == new_path:
|
||||
# Find the folder from existing entries or calculate it
|
||||
existing_folder = next((item['folder'] for item in cache.raw_data
|
||||
if item['file_path'] == original_path), None)
|
||||
if existing_folder:
|
||||
metadata['folder'] = existing_folder
|
||||
else:
|
||||
metadata['folder'] = self._calculate_folder(new_path)
|
||||
else:
|
||||
# For moved files, recalculate the folder
|
||||
metadata['folder'] = self._calculate_folder(new_path)
|
||||
|
||||
# Add the updated metadata to raw_data
|
||||
cache.raw_data.append(metadata)
|
||||
|
||||
# Update hash index with new path
|
||||
if 'sha256' in metadata:
|
||||
self._hash_index.add_entry(metadata['sha256'], new_path)
|
||||
|
||||
# Update folders list
|
||||
all_folders = set(item['folder'] for item in cache.raw_data)
|
||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
# Update tags count with the new/updated tags
|
||||
if 'tags' in metadata:
|
||||
for tag in metadata.get('tags', []):
|
||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||
|
||||
# Resort cache
|
||||
await cache.resort()
|
||||
|
||||
return True
|
||||
|
||||
async def _update_metadata_paths(self, metadata_path: str, lora_path: str) -> Dict:
|
||||
"""Update file paths in metadata file"""
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
# Update file_path
|
||||
metadata['file_path'] = lora_path.replace(os.sep, '/')
|
||||
|
||||
# Update preview_url if exists
|
||||
if 'preview_url' in metadata:
|
||||
preview_dir = os.path.dirname(lora_path)
|
||||
preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
|
||||
preview_ext = os.path.splitext(metadata['preview_url'])[1]
|
||||
new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
|
||||
metadata['preview_url'] = new_preview_path.replace(os.sep, '/')
|
||||
|
||||
# Save updated metadata
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
|
||||
return metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating metadata paths: {e}", exc_info=True)
|
||||
|
||||
# Add new methods for hash index functionality
|
||||
def has_lora_hash(self, sha256: str) -> bool:
|
||||
"""Check if a LoRA with given hash exists"""
|
||||
return self._hash_index.has_hash(sha256)
|
||||
|
||||
def get_lora_path_by_hash(self, sha256: str) -> Optional[str]:
|
||||
"""Get file path for a LoRA by its hash"""
|
||||
return self._hash_index.get_path(sha256)
|
||||
|
||||
def get_lora_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a LoRA by its file path"""
|
||||
return self._hash_index.get_hash(file_path)
|
||||
|
||||
# Add new method to get top tags
|
||||
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||
"""Get top tags sorted by count
|
||||
|
||||
Args:
|
||||
limit: Maximum number of tags to return
|
||||
|
||||
Returns:
|
||||
List of dictionaries with tag name and count, sorted by count
|
||||
"""
|
||||
# Make sure cache is initialized
|
||||
await self.get_cached_data()
|
||||
|
||||
# Sort tags by count in descending order
|
||||
sorted_tags = sorted(
|
||||
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
|
||||
key=lambda x: x['count'],
|
||||
reverse=True
|
||||
# Initialize parent class with ModelHashIndex
|
||||
super().__init__(
|
||||
model_type="lora",
|
||||
model_class=LoraMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
||||
)
|
||||
|
||||
# Return limited number
|
||||
return sorted_tags[:limit]
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get lora root directories"""
|
||||
return config.loras_roots
|
||||
|
||||
async def diagnose_hash_index(self):
|
||||
"""Diagnostic method to verify hash index functionality"""
|
||||
print("\n\n*** DIAGNOSING LORA HASH INDEX ***\n\n", file=sys.stderr)
|
||||
|
||||
# First check if the hash index has any entries
|
||||
if hasattr(self, '_hash_index'):
|
||||
index_entries = len(self._hash_index._hash_to_path)
|
||||
print(f"Hash index has {index_entries} entries", file=sys.stderr)
|
||||
|
||||
# Print a few example entries if available
|
||||
if index_entries > 0:
|
||||
print("\nSample hash index entries:", file=sys.stderr)
|
||||
count = 0
|
||||
for hash_val, path in self._hash_index._hash_to_path.items():
|
||||
if count < 5: # Just show the first 5
|
||||
print(f"Hash: {hash_val[:8]}... -> Path: {path}", file=sys.stderr)
|
||||
count += 1
|
||||
else:
|
||||
break
|
||||
else:
|
||||
print("Hash index not initialized", file=sys.stderr)
|
||||
|
||||
# Try looking up by a known hash for testing
|
||||
if not hasattr(self, '_hash_index') or not self._hash_index._hash_to_path:
|
||||
print("No hash entries to test lookup with", file=sys.stderr)
|
||||
return
|
||||
|
||||
test_hash = next(iter(self._hash_index._hash_to_path.keys()))
|
||||
test_path = self._hash_index.get_path(test_hash)
|
||||
print(f"\nTest lookup by hash: {test_hash[:8]}... -> {test_path}", file=sys.stderr)
|
||||
|
||||
# Also test reverse lookup
|
||||
test_hash_result = self._hash_index.get_hash(test_path)
|
||||
print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)
|
||||
|
||||
|
||||
181
py/services/lora_service.py
Normal file
181
py/services/lora_service.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import LoraMetadata
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraService(BaseModelService):
|
||||
"""LoRA-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize LoRA service
|
||||
|
||||
Args:
|
||||
scanner: LoRA scanner instance
|
||||
"""
|
||||
super().__init__("lora", scanner, LoraMetadata)
|
||||
|
||||
async def format_response(self, lora_data: Dict) -> Dict:
|
||||
"""Format LoRA data for API response"""
|
||||
return {
|
||||
"model_name": lora_data["model_name"],
|
||||
"file_name": lora_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
|
||||
"base_model": lora_data.get("base_model", ""),
|
||||
"folder": lora_data["folder"],
|
||||
"sha256": lora_data.get("sha256", ""),
|
||||
"file_path": lora_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": lora_data.get("size", 0),
|
||||
"modified": lora_data.get("modified", ""),
|
||||
"tags": lora_data.get("tags", []),
|
||||
"from_civitai": lora_data.get("from_civitai", True),
|
||||
"usage_tips": lora_data.get("usage_tips", ""),
|
||||
"notes": lora_data.get("notes", ""),
|
||||
"favorite": lora_data.get("favorite", False),
|
||||
"civitai": self.filter_civitai_data(lora_data.get("civitai", {}), minimal=True)
|
||||
}
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
"""Apply LoRA-specific filters"""
|
||||
# Handle first_letter filter for LoRAs
|
||||
first_letter = kwargs.get('first_letter')
|
||||
if first_letter:
|
||||
data = self._filter_by_first_letter(data, first_letter)
|
||||
|
||||
return data
|
||||
|
||||
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
|
||||
"""Filter data by first letter of model name
|
||||
|
||||
Special handling:
|
||||
- '#': Numbers (0-9)
|
||||
- '@': Special characters (not alphanumeric)
|
||||
- '漢': CJK characters
|
||||
"""
|
||||
filtered_data = []
|
||||
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if letter == '#' and first_char.isdigit():
|
||||
filtered_data.append(lora)
|
||||
elif letter == '@' and not first_char.isalnum():
|
||||
# Special characters (not alphanumeric)
|
||||
filtered_data.append(lora)
|
||||
elif letter == '漢' and self._is_cjk_character(first_char):
|
||||
# CJK characters
|
||||
filtered_data.append(lora)
|
||||
elif letter.upper() == first_char:
|
||||
# Regular alphabet matching
|
||||
filtered_data.append(lora)
|
||||
|
||||
return filtered_data
|
||||
|
||||
def _is_cjk_character(self, char: str) -> bool:
|
||||
"""Check if character is a CJK character"""
|
||||
# Define Unicode ranges for CJK characters
|
||||
cjk_ranges = [
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
||||
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
||||
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
||||
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
||||
(0x3300, 0x33FF), # CJK Compatibility
|
||||
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
||||
(0x3100, 0x312F), # Bopomofo
|
||||
(0x31A0, 0x31BF), # Bopomofo Extended
|
||||
(0x3040, 0x309F), # Hiragana
|
||||
(0x30A0, 0x30FF), # Katakana
|
||||
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
||||
(0xAC00, 0xD7AF), # Hangul Syllables
|
||||
(0x1100, 0x11FF), # Hangul Jamo
|
||||
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
||||
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
||||
]
|
||||
|
||||
code_point = ord(char)
|
||||
return any(start <= code_point <= end for start, end in cjk_ranges)
|
||||
|
||||
# LoRA-specific methods
|
||||
async def get_letter_counts(self) -> Dict[str, int]:
|
||||
"""Get count of LoRAs for each letter of the alphabet"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
data = cache.raw_data
|
||||
|
||||
# Define letter categories
|
||||
letters = {
|
||||
'#': 0, # Numbers
|
||||
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
|
||||
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
|
||||
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
|
||||
'Y': 0, 'Z': 0,
|
||||
'@': 0, # Special characters
|
||||
'漢': 0 # CJK characters
|
||||
}
|
||||
|
||||
# Count models for each letter
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if first_char.isdigit():
|
||||
letters['#'] += 1
|
||||
elif first_char in letters:
|
||||
letters[first_char] += 1
|
||||
elif self._is_cjk_character(first_char):
|
||||
letters['漢'] += 1
|
||||
elif not first_char.isalnum():
|
||||
letters['@'] += 1
|
||||
|
||||
return letters
|
||||
|
||||
async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
|
||||
"""Get trigger words for a specific LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
civitai_data = lora.get('civitai', {})
|
||||
return civitai_data.get('trainedWords', [])
|
||||
|
||||
return []
|
||||
|
||||
async def get_lora_usage_tips_by_relative_path(self, relative_path: str) -> Optional[str]:
|
||||
"""Get usage tips for a LoRA by its relative path"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
file_path = lora.get('file_path', '')
|
||||
if file_path:
|
||||
# Convert to forward slashes and extract relative path
|
||||
file_path_normalized = file_path.replace('\\', '/')
|
||||
relative_path = relative_path.replace('\\', '/')
|
||||
# Find the relative path part by looking for the relative_path in the full path
|
||||
if file_path_normalized.endswith(relative_path) or relative_path in file_path_normalized:
|
||||
return lora.get('usage_tips', '')
|
||||
|
||||
return None
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find LoRAs with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find LoRAs with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
151
py/services/metadata_archive_manager.py
Normal file
151
py/services/metadata_archive_manager.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import zipfile
|
||||
import logging
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from .downloader import get_downloader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MetadataArchiveManager:
|
||||
"""Manages downloading and extracting Civitai metadata archive database"""
|
||||
|
||||
DOWNLOAD_URLS = [
|
||||
"https://github.com/willmiao/civitai-metadata-archive-db/releases/download/db-2025-08-08/civitai.zip",
|
||||
"https://huggingface.co/datasets/willmiao/civitai-metadata-archive-db/blob/main/civitai.zip"
|
||||
]
|
||||
|
||||
def __init__(self, base_path: str):
|
||||
"""Initialize with base path where files will be stored"""
|
||||
self.base_path = Path(base_path)
|
||||
self.civitai_folder = self.base_path / "civitai"
|
||||
self.archive_path = self.base_path / "civitai.zip"
|
||||
self.db_path = self.civitai_folder / "civitai.sqlite"
|
||||
|
||||
def is_database_available(self) -> bool:
|
||||
"""Check if the SQLite database is available and valid"""
|
||||
return self.db_path.exists() and self.db_path.stat().st_size > 0
|
||||
|
||||
def get_database_path(self) -> Optional[str]:
|
||||
"""Get the path to the SQLite database if available"""
|
||||
if self.is_database_available():
|
||||
return str(self.db_path)
|
||||
return None
|
||||
|
||||
async def download_and_extract_database(self, progress_callback=None) -> bool:
|
||||
"""Download and extract the metadata archive database
|
||||
|
||||
Args:
|
||||
progress_callback: Optional callback function to report progress
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Create directories if they don't exist
|
||||
self.base_path.mkdir(parents=True, exist_ok=True)
|
||||
self.civitai_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Download the archive
|
||||
if not await self._download_archive(progress_callback):
|
||||
return False
|
||||
|
||||
# Extract the archive
|
||||
if not await self._extract_archive(progress_callback):
|
||||
return False
|
||||
|
||||
# Clean up the archive file
|
||||
if self.archive_path.exists():
|
||||
self.archive_path.unlink()
|
||||
|
||||
logger.info(f"Successfully downloaded and extracted metadata database to {self.db_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading and extracting metadata database: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def _download_archive(self, progress_callback=None) -> bool:
|
||||
"""Download the zip archive from one of the available URLs"""
|
||||
downloader = await get_downloader()
|
||||
|
||||
for url in self.DOWNLOAD_URLS:
|
||||
try:
|
||||
logger.info(f"Attempting to download from {url}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback("download", f"Downloading from {url}")
|
||||
|
||||
# Custom progress callback to report download progress
|
||||
async def download_progress(progress):
|
||||
if progress_callback:
|
||||
progress_callback("download", f"Downloading archive... {progress:.1f}%")
|
||||
|
||||
success, result = await downloader.download_file(
|
||||
url=url,
|
||||
save_path=str(self.archive_path),
|
||||
progress_callback=download_progress,
|
||||
use_auth=False, # Public download, no auth needed
|
||||
allow_resume=True
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Successfully downloaded archive from {url}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Failed to download from {url}: {result}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error downloading from {url}: {e}")
|
||||
continue
|
||||
|
||||
logger.error("Failed to download archive from any URL")
|
||||
return False
|
||||
|
||||
async def _extract_archive(self, progress_callback=None) -> bool:
|
||||
"""Extract the zip archive to the civitai folder"""
|
||||
try:
|
||||
if progress_callback:
|
||||
progress_callback("extract", "Extracting archive...")
|
||||
|
||||
# Run extraction in thread pool to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, self._extract_zip_sync)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback("extract", "Extraction completed")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting archive: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def _extract_zip_sync(self):
|
||||
"""Synchronous zip extraction (runs in thread pool)"""
|
||||
with zipfile.ZipFile(self.archive_path, 'r') as archive:
|
||||
archive.extractall(path=self.base_path)
|
||||
|
||||
async def remove_database(self) -> bool:
|
||||
"""Remove the metadata database and folder"""
|
||||
try:
|
||||
if self.civitai_folder.exists():
|
||||
# Remove all files in the civitai folder
|
||||
for file_path in self.civitai_folder.iterdir():
|
||||
if file_path.is_file():
|
||||
file_path.unlink()
|
||||
|
||||
# Remove the folder itself
|
||||
self.civitai_folder.rmdir()
|
||||
|
||||
# Also remove the archive file if it exists
|
||||
if self.archive_path.exists():
|
||||
self.archive_path.unlink()
|
||||
|
||||
logger.info("Successfully removed metadata database")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing metadata database: {e}", exc_info=True)
|
||||
return False
|
||||
117
py/services/metadata_service.py
Normal file
117
py/services/metadata_service.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import os
|
||||
import logging
|
||||
from .model_metadata_provider import (
|
||||
ModelMetadataProviderManager,
|
||||
SQLiteModelMetadataProvider,
|
||||
CivitaiModelMetadataProvider,
|
||||
FallbackMetadataProvider
|
||||
)
|
||||
from .settings_manager import settings
|
||||
from .metadata_archive_manager import MetadataArchiveManager
|
||||
from .service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def initialize_metadata_providers():
|
||||
"""Initialize and configure all metadata providers based on settings"""
|
||||
provider_manager = await ModelMetadataProviderManager.get_instance()
|
||||
|
||||
# Clear existing providers to allow reinitialization
|
||||
provider_manager.providers.clear()
|
||||
provider_manager.default_provider = None
|
||||
|
||||
# Get settings
|
||||
enable_archive_db = settings.get('enable_metadata_archive_db', False)
|
||||
|
||||
providers = []
|
||||
|
||||
# Initialize archive database provider if enabled
|
||||
if enable_archive_db:
|
||||
try:
|
||||
# Initialize archive manager
|
||||
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
archive_manager = MetadataArchiveManager(base_path)
|
||||
|
||||
db_path = archive_manager.get_database_path()
|
||||
if db_path and os.path.exists(db_path):
|
||||
sqlite_provider = SQLiteModelMetadataProvider(db_path)
|
||||
provider_manager.register_provider('sqlite', sqlite_provider)
|
||||
providers.append(('sqlite', sqlite_provider))
|
||||
logger.info(f"SQLite metadata provider registered with database: {db_path}")
|
||||
else:
|
||||
logger.warning("Metadata archive database is enabled but database file not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize SQLite metadata provider: {e}")
|
||||
|
||||
# Initialize Civitai API provider (always available as fallback)
|
||||
try:
|
||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
civitai_provider = CivitaiModelMetadataProvider(civitai_client)
|
||||
provider_manager.register_provider('civitai_api', civitai_provider)
|
||||
providers.append(('civitai_api', civitai_provider))
|
||||
logger.debug("Civitai API metadata provider registered")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Civitai API metadata provider: {e}")
|
||||
|
||||
# Register CivArchive provider, but do NOT add to fallback providers
|
||||
try:
|
||||
from .model_metadata_provider import CivArchiveModelMetadataProvider
|
||||
civarchive_provider = CivArchiveModelMetadataProvider()
|
||||
provider_manager.register_provider('civarchive', civarchive_provider)
|
||||
logger.debug("CivArchive metadata provider registered (not included in fallback)")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize CivArchive metadata provider: {e}")
|
||||
|
||||
# Set up fallback provider based on available providers
|
||||
if len(providers) > 1:
|
||||
# Always use Civitai API first, then Archive DB
|
||||
ordered_providers = []
|
||||
ordered_providers.extend([p[1] for p in providers if p[0] == 'civitai_api'])
|
||||
ordered_providers.extend([p[1] for p in providers if p[0] == 'sqlite'])
|
||||
|
||||
if ordered_providers:
|
||||
fallback_provider = FallbackMetadataProvider(ordered_providers)
|
||||
provider_manager.register_provider('fallback', fallback_provider, is_default=True)
|
||||
logger.info(f"Fallback metadata provider registered with {len(ordered_providers)} providers, Civitai API first")
|
||||
elif len(providers) == 1:
|
||||
# Only one provider available, set it as default
|
||||
provider_name, provider = providers[0]
|
||||
provider_manager.register_provider(provider_name, provider, is_default=True)
|
||||
logger.debug(f"Single metadata provider registered as default: {provider_name}")
|
||||
else:
|
||||
logger.warning("No metadata providers available - this may cause metadata lookup failures")
|
||||
|
||||
return provider_manager
|
||||
|
||||
async def update_metadata_providers():
|
||||
"""Update metadata providers based on current settings"""
|
||||
try:
|
||||
# Get current settings
|
||||
enable_archive_db = settings.get('enable_metadata_archive_db', False)
|
||||
|
||||
# Reinitialize all providers with new settings
|
||||
provider_manager = await initialize_metadata_providers()
|
||||
|
||||
logger.info(f"Updated metadata providers, archive_db enabled: {enable_archive_db}")
|
||||
return provider_manager
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update metadata providers: {e}")
|
||||
return await ModelMetadataProviderManager.get_instance()
|
||||
|
||||
async def get_metadata_archive_manager():
|
||||
"""Get metadata archive manager instance"""
|
||||
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
return MetadataArchiveManager(base_path)
|
||||
|
||||
async def get_metadata_provider(provider_name: str = None):
|
||||
"""Get a specific metadata provider or default provider"""
|
||||
provider_manager = await ModelMetadataProviderManager.get_instance()
|
||||
|
||||
if provider_name:
|
||||
return provider_manager._get_provider(provider_name)
|
||||
|
||||
return provider_manager._get_provider()
|
||||
|
||||
async def get_default_metadata_provider():
|
||||
"""Get the default metadata provider (fallback or single provider)"""
|
||||
return await get_metadata_provider()
|
||||
355
py/services/metadata_sync_service.py
Normal file
355
py/services/metadata_sync_service.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""Services for synchronising metadata with remote providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
|
||||
|
||||
from ..services.settings_manager import SettingsManager
|
||||
from ..utils.model_utils import determine_base_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetadataProviderProtocol:
|
||||
"""Subset of metadata provider interface consumed by the sync service."""
|
||||
|
||||
async def get_model_by_hash(self, sha256: str) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||
...
|
||||
|
||||
async def get_model_version(
|
||||
self, model_id: int, model_version_id: Optional[int]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
|
||||
class MetadataSyncService:
|
||||
"""High level orchestration for metadata synchronisation flows."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
metadata_manager,
|
||||
preview_service,
|
||||
settings: SettingsManager,
|
||||
default_metadata_provider_factory: Callable[[], Awaitable[MetadataProviderProtocol]],
|
||||
metadata_provider_selector: Callable[[str], Awaitable[MetadataProviderProtocol]],
|
||||
) -> None:
|
||||
self._metadata_manager = metadata_manager
|
||||
self._preview_service = preview_service
|
||||
self._settings = settings
|
||||
self._get_default_provider = default_metadata_provider_factory
|
||||
self._get_provider = metadata_provider_selector
|
||||
|
||||
async def load_local_metadata(self, metadata_path: str) -> Dict[str, Any]:
|
||||
"""Load metadata JSON from disk, returning an empty structure when missing."""
|
||||
|
||||
if not os.path.exists(metadata_path):
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(metadata_path, "r", encoding="utf-8") as handle:
|
||||
return json.load(handle)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Error loading metadata from %s: %s", metadata_path, exc)
|
||||
return {}
|
||||
|
||||
async def mark_not_found_on_civitai(
|
||||
self, metadata_path: str, local_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Persist the not-found flag for a metadata payload."""
|
||||
|
||||
local_metadata["from_civitai"] = False
|
||||
await self._metadata_manager.save_metadata(metadata_path, local_metadata)
|
||||
|
||||
@staticmethod
|
||||
def is_civitai_api_metadata(meta: Dict[str, Any]) -> bool:
|
||||
"""Determine if the metadata originated from the CivitAI public API."""
|
||||
|
||||
if not isinstance(meta, dict):
|
||||
return False
|
||||
files = meta.get("files")
|
||||
images = meta.get("images")
|
||||
source = meta.get("source")
|
||||
return bool(files) and bool(images) and source != "archive_db"
|
||||
|
||||
async def update_model_metadata(
|
||||
self,
|
||||
metadata_path: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
civitai_metadata: Dict[str, Any],
|
||||
metadata_provider: Optional[MetadataProviderProtocol] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Merge remote metadata into the local record and persist the result."""
|
||||
|
||||
existing_civitai = local_metadata.get("civitai") or {}
|
||||
|
||||
if (
|
||||
civitai_metadata.get("source") == "archive_db"
|
||||
and self.is_civitai_api_metadata(existing_civitai)
|
||||
):
|
||||
logger.info(
|
||||
"Skip civitai update for %s (%s)",
|
||||
local_metadata.get("model_name", ""),
|
||||
existing_civitai.get("name", ""),
|
||||
)
|
||||
else:
|
||||
merged_civitai = existing_civitai.copy()
|
||||
merged_civitai.update(civitai_metadata)
|
||||
|
||||
if civitai_metadata.get("source") == "archive_db":
|
||||
model_name = civitai_metadata.get("model", {}).get("name", "")
|
||||
version_name = civitai_metadata.get("name", "")
|
||||
logger.info(
|
||||
"Recovered metadata from archive_db for deleted model: %s (%s)",
|
||||
model_name,
|
||||
version_name,
|
||||
)
|
||||
|
||||
if "trainedWords" in existing_civitai:
|
||||
existing_trained = existing_civitai.get("trainedWords", [])
|
||||
new_trained = civitai_metadata.get("trainedWords", [])
|
||||
merged_trained = list(set(existing_trained + new_trained))
|
||||
merged_civitai["trainedWords"] = merged_trained
|
||||
|
||||
local_metadata["civitai"] = merged_civitai
|
||||
|
||||
if "model" in civitai_metadata and civitai_metadata["model"]:
|
||||
model_data = civitai_metadata["model"]
|
||||
|
||||
if model_data.get("name"):
|
||||
local_metadata["model_name"] = model_data["name"]
|
||||
|
||||
if not local_metadata.get("modelDescription") and model_data.get("description"):
|
||||
local_metadata["modelDescription"] = model_data["description"]
|
||||
|
||||
if not local_metadata.get("tags") and model_data.get("tags"):
|
||||
local_metadata["tags"] = model_data["tags"]
|
||||
|
||||
if model_data.get("creator") and not local_metadata.get("civitai", {}).get(
|
||||
"creator"
|
||||
):
|
||||
local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"]
|
||||
|
||||
local_metadata["base_model"] = determine_base_model(
|
||||
civitai_metadata.get("baseModel")
|
||||
)
|
||||
|
||||
await self._preview_service.ensure_preview_for_metadata(
|
||||
metadata_path, local_metadata, civitai_metadata.get("images", [])
|
||||
)
|
||||
|
||||
await self._metadata_manager.save_metadata(metadata_path, local_metadata)
|
||||
return local_metadata
|
||||
|
||||
async def fetch_and_update_model(
|
||||
self,
|
||||
*,
|
||||
sha256: str,
|
||||
file_path: str,
|
||||
model_data: Dict[str, Any],
|
||||
update_cache_func: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
"""Fetch metadata for a model and update both disk and cache state."""
|
||||
|
||||
if not isinstance(model_data, dict):
|
||||
error = f"Invalid model_data type: {type(model_data)}"
|
||||
logger.error(error)
|
||||
return False, error
|
||||
|
||||
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||
enable_archive = self._settings.get("enable_metadata_archive_db", False)
|
||||
|
||||
try:
|
||||
if model_data.get("civitai_deleted") is True:
|
||||
if not enable_archive or model_data.get("db_checked") is True:
|
||||
return (
|
||||
False,
|
||||
"CivitAI model is deleted and metadata archive DB is not enabled",
|
||||
)
|
||||
metadata_provider = await self._get_provider("sqlite")
|
||||
else:
|
||||
metadata_provider = await self._get_default_provider()
|
||||
|
||||
civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256)
|
||||
if not civitai_metadata:
|
||||
if error == "Model not found":
|
||||
model_data["from_civitai"] = False
|
||||
model_data["civitai_deleted"] = True
|
||||
model_data["db_checked"] = enable_archive
|
||||
model_data["last_checked_at"] = datetime.now().timestamp()
|
||||
|
||||
data_to_save = model_data.copy()
|
||||
data_to_save.pop("folder", None)
|
||||
await self._metadata_manager.save_metadata(file_path, data_to_save)
|
||||
|
||||
error_msg = (
|
||||
f"Error fetching metadata: {error} (model_name={model_data.get('model_name', '')})"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
model_data["from_civitai"] = True
|
||||
model_data["civitai_deleted"] = civitai_metadata.get("source") == "archive_db"
|
||||
model_data["db_checked"] = enable_archive
|
||||
model_data["last_checked_at"] = datetime.now().timestamp()
|
||||
|
||||
local_metadata = model_data.copy()
|
||||
local_metadata.pop("folder", None)
|
||||
|
||||
await self.update_model_metadata(
|
||||
metadata_path,
|
||||
local_metadata,
|
||||
civitai_metadata,
|
||||
metadata_provider,
|
||||
)
|
||||
|
||||
update_payload = {
|
||||
"model_name": local_metadata.get("model_name"),
|
||||
"preview_url": local_metadata.get("preview_url"),
|
||||
"civitai": local_metadata.get("civitai"),
|
||||
}
|
||||
model_data.update(update_payload)
|
||||
|
||||
await update_cache_func(file_path, file_path, local_metadata)
|
||||
return True, None
|
||||
except KeyError as exc:
|
||||
error_msg = f"Error fetching metadata - Missing key: {exc} in model_data={model_data}"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
except Exception as exc: # pragma: no cover - error path
|
||||
error_msg = f"Error fetching metadata: {exc}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return False, error_msg
|
||||
|
||||
async def fetch_metadata_by_sha(
|
||||
self, sha256: str, metadata_provider: Optional[MetadataProviderProtocol] = None
|
||||
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||
"""Fetch metadata for a SHA256 hash from the configured provider."""
|
||||
|
||||
provider = metadata_provider or await self._get_default_provider()
|
||||
return await provider.get_model_by_hash(sha256)
|
||||
|
||||
async def relink_metadata(
|
||||
self,
|
||||
*,
|
||||
file_path: str,
|
||||
metadata: Dict[str, Any],
|
||||
model_id: int,
|
||||
model_version_id: Optional[int],
|
||||
) -> Dict[str, Any]:
|
||||
"""Relink a local metadata record to a specific CivitAI model version."""
|
||||
|
||||
provider = await self._get_default_provider()
|
||||
civitai_metadata = await provider.get_model_version(model_id, model_version_id)
|
||||
if not civitai_metadata:
|
||||
raise ValueError(
|
||||
f"Model version not found on CivitAI for ID: {model_id}"
|
||||
+ (f" with version: {model_version_id}" if model_version_id else "")
|
||||
)
|
||||
|
||||
primary_model_file: Optional[Dict[str, Any]] = None
|
||||
for file_info in civitai_metadata.get("files", []):
|
||||
if file_info.get("primary", False) and file_info.get("type") == "Model":
|
||||
primary_model_file = file_info
|
||||
break
|
||||
|
||||
if primary_model_file and primary_model_file.get("hashes", {}).get("SHA256"):
|
||||
metadata["sha256"] = primary_model_file["hashes"]["SHA256"].lower()
|
||||
|
||||
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||
await self.update_model_metadata(
|
||||
metadata_path,
|
||||
metadata,
|
||||
civitai_metadata,
|
||||
provider,
|
||||
)
|
||||
|
||||
return metadata
|
||||
|
||||
async def save_metadata_updates(
|
||||
self,
|
||||
*,
|
||||
file_path: str,
|
||||
updates: Dict[str, Any],
|
||||
metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]],
|
||||
update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply metadata updates and persist to disk and cache."""
|
||||
|
||||
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||
metadata = await metadata_loader(metadata_path)
|
||||
|
||||
for key, value in updates.items():
|
||||
if isinstance(value, dict) and isinstance(metadata.get(key), dict):
|
||||
metadata[key].update(value)
|
||||
else:
|
||||
metadata[key] = value
|
||||
|
||||
await self._metadata_manager.save_metadata(file_path, metadata)
|
||||
await update_cache(file_path, file_path, metadata)
|
||||
|
||||
if "model_name" in updates:
|
||||
logger.debug("Metadata update touched model_name; cache resort required")
|
||||
|
||||
return metadata
|
||||
|
||||
async def verify_duplicate_hashes(
|
||||
self,
|
||||
*,
|
||||
file_paths: Iterable[str],
|
||||
metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]],
|
||||
hash_calculator: Callable[[str], Awaitable[str]],
|
||||
update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
|
||||
) -> Dict[str, Any]:
|
||||
"""Verify a collection of files share the same SHA256 hash."""
|
||||
|
||||
file_paths = list(file_paths)
|
||||
if not file_paths:
|
||||
raise ValueError("No file paths provided for verification")
|
||||
|
||||
results = {
|
||||
"verified_as_duplicates": True,
|
||||
"mismatched_files": [],
|
||||
"new_hash_map": {},
|
||||
}
|
||||
|
||||
expected_hash: Optional[str] = None
|
||||
first_metadata_path = os.path.splitext(file_paths[0])[0] + ".metadata.json"
|
||||
first_metadata = await metadata_loader(first_metadata_path)
|
||||
if first_metadata and "sha256" in first_metadata:
|
||||
expected_hash = first_metadata["sha256"].lower()
|
||||
|
||||
for path in file_paths:
|
||||
if not os.path.exists(path):
|
||||
continue
|
||||
|
||||
try:
|
||||
actual_hash = await hash_calculator(path)
|
||||
metadata_path = os.path.splitext(path)[0] + ".metadata.json"
|
||||
metadata = await metadata_loader(metadata_path)
|
||||
stored_hash = metadata.get("sha256", "").lower()
|
||||
|
||||
if not expected_hash:
|
||||
expected_hash = stored_hash
|
||||
|
||||
if actual_hash != expected_hash:
|
||||
results["verified_as_duplicates"] = False
|
||||
results["mismatched_files"].append(path)
|
||||
results["new_hash_map"][path] = actual_hash
|
||||
|
||||
if actual_hash != stored_hash:
|
||||
metadata["sha256"] = actual_hash
|
||||
await self._metadata_manager.save_metadata(path, metadata)
|
||||
await update_cache(path, path, metadata)
|
||||
except Exception as exc: # pragma: no cover - defensive path
|
||||
logger.error("Error verifying hash for %s: %s", path, exc)
|
||||
results["mismatched_files"].append(path)
|
||||
results["new_hash_map"][path] = "error_calculating_hash"
|
||||
results["verified_as_duplicates"] = False
|
||||
|
||||
return results
|
||||
|
||||
104
py/services/model_cache.py
Normal file
104
py/services/model_cache.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import asyncio
|
||||
from typing import List, Dict, Tuple
|
||||
from dataclasses import dataclass
|
||||
from operator import itemgetter
|
||||
from natsort import natsorted
|
||||
|
||||
# Supported sort modes: (sort_key, order)
|
||||
# order: 'asc' for ascending, 'desc' for descending
|
||||
SUPPORTED_SORT_MODES = [
|
||||
('name', 'asc'),
|
||||
('name', 'desc'),
|
||||
('date', 'asc'),
|
||||
('date', 'desc'),
|
||||
('size', 'asc'),
|
||||
('size', 'desc'),
|
||||
]
|
||||
|
||||
@dataclass
|
||||
class ModelCache:
|
||||
"""Cache structure for model data with extensible sorting"""
|
||||
raw_data: List[Dict]
|
||||
folders: List[str]
|
||||
|
||||
def __post_init__(self):
|
||||
self._lock = asyncio.Lock()
|
||||
# Cache for last sort: (sort_key, order) -> sorted list
|
||||
self._last_sort: Tuple[str, str] = (None, None)
|
||||
self._last_sorted_data: List[Dict] = []
|
||||
# Default sort on init
|
||||
asyncio.create_task(self.resort())
|
||||
|
||||
async def resort(self):
|
||||
"""Resort cached data according to last sort mode if set"""
|
||||
async with self._lock:
|
||||
if self._last_sort != (None, None):
|
||||
sort_key, order = self._last_sort
|
||||
sorted_data = self._sort_data(self.raw_data, sort_key, order)
|
||||
self._last_sorted_data = sorted_data
|
||||
# Update folder list
|
||||
# else: do nothing
|
||||
|
||||
all_folders = set(l['folder'] for l in self.raw_data)
|
||||
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
|
||||
"""Sort data by sort_key and order"""
|
||||
reverse = (order == 'desc')
|
||||
if sort_key == 'name':
|
||||
# Natural sort by model_name, case-insensitive
|
||||
return natsorted(
|
||||
data,
|
||||
key=lambda x: x['model_name'].lower(),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'date':
|
||||
# Sort by modified timestamp
|
||||
return sorted(
|
||||
data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'size':
|
||||
# Sort by file size
|
||||
return sorted(
|
||||
data,
|
||||
key=itemgetter('size'),
|
||||
reverse=reverse
|
||||
)
|
||||
else:
|
||||
# Fallback: no sort
|
||||
return list(data)
|
||||
|
||||
async def get_sorted_data(self, sort_key: str = 'name', order: str = 'asc') -> List[Dict]:
|
||||
"""Get sorted data by sort_key and order, using cache if possible"""
|
||||
async with self._lock:
|
||||
if (sort_key, order) == self._last_sort:
|
||||
return self._last_sorted_data
|
||||
sorted_data = self._sort_data(self.raw_data, sort_key, order)
|
||||
self._last_sort = (sort_key, order)
|
||||
self._last_sorted_data = sorted_data
|
||||
return sorted_data
|
||||
|
||||
async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool:
|
||||
"""Update preview_url for a specific model in all cached data
|
||||
|
||||
Args:
|
||||
file_path: The file path of the model to update
|
||||
preview_url: The new preview URL
|
||||
preview_nsfw_level: The NSFW level of the preview
|
||||
|
||||
Returns:
|
||||
bool: True if the update was successful, False if the model wasn't found
|
||||
"""
|
||||
async with self._lock:
|
||||
# Update in raw_data
|
||||
for item in self.raw_data:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
item['preview_nsfw_level'] = preview_nsfw_level
|
||||
break
|
||||
else:
|
||||
return False # Model not found
|
||||
|
||||
return True
|
||||
463
py/services/model_file_service.py
Normal file
463
py/services/model_file_service.py
Normal file
@@ -0,0 +1,463 @@
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
from typing import List, Dict, Optional, Any, Set
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ..utils.utils import calculate_relative_path_for_model, remove_empty_dirs
|
||||
from ..utils.constants import AUTO_ORGANIZE_BATCH_SIZE
|
||||
from ..services.settings_manager import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProgressCallback(ABC):
|
||||
"""Abstract callback interface for progress reporting"""
|
||||
|
||||
@abstractmethod
|
||||
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
|
||||
"""Called when progress is updated"""
|
||||
pass
|
||||
|
||||
|
||||
class AutoOrganizeResult:
|
||||
"""Result object for auto-organize operations"""
|
||||
|
||||
def __init__(self):
|
||||
self.total: int = 0
|
||||
self.processed: int = 0
|
||||
self.success_count: int = 0
|
||||
self.failure_count: int = 0
|
||||
self.skipped_count: int = 0
|
||||
self.operation_type: str = 'unknown'
|
||||
self.cleanup_counts: Dict[str, int] = {}
|
||||
self.results: List[Dict[str, Any]] = []
|
||||
self.results_truncated: bool = False
|
||||
self.sample_results: List[Dict[str, Any]] = []
|
||||
self.is_flat_structure: bool = False
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert result to dictionary"""
|
||||
result = {
|
||||
'success': True,
|
||||
'message': f'Auto-organize {self.operation_type} completed: {self.success_count} moved, {self.skipped_count} skipped, {self.failure_count} failed out of {self.total} total',
|
||||
'summary': {
|
||||
'total': self.total,
|
||||
'success': self.success_count,
|
||||
'skipped': self.skipped_count,
|
||||
'failures': self.failure_count,
|
||||
'organization_type': 'flat' if self.is_flat_structure else 'structured',
|
||||
'cleaned_dirs': self.cleanup_counts,
|
||||
'operation_type': self.operation_type
|
||||
}
|
||||
}
|
||||
|
||||
if self.results_truncated:
|
||||
result['results_truncated'] = True
|
||||
result['sample_results'] = self.sample_results
|
||||
else:
|
||||
result['results'] = self.results
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ModelFileService:
|
||||
"""Service for handling model file operations and organization"""
|
||||
|
||||
def __init__(self, scanner, model_type: str):
|
||||
"""Initialize the service
|
||||
|
||||
Args:
|
||||
scanner: Model scanner instance
|
||||
model_type: Type of model (e.g., 'lora', 'checkpoint')
|
||||
"""
|
||||
self.scanner = scanner
|
||||
self.model_type = model_type
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get model root directories"""
|
||||
return self.scanner.get_model_roots()
|
||||
|
||||
async def auto_organize_models(
|
||||
self,
|
||||
file_paths: Optional[List[str]] = None,
|
||||
progress_callback: Optional[ProgressCallback] = None
|
||||
) -> AutoOrganizeResult:
|
||||
"""Auto-organize models based on current settings
|
||||
|
||||
Args:
|
||||
file_paths: Optional list of specific file paths to organize.
|
||||
If None, organizes all models.
|
||||
progress_callback: Optional callback for progress updates
|
||||
|
||||
Returns:
|
||||
AutoOrganizeResult object with operation results
|
||||
"""
|
||||
result = AutoOrganizeResult()
|
||||
source_directories: Set[str] = set()
|
||||
|
||||
try:
|
||||
# Get all models from cache
|
||||
cache = await self.scanner.get_cached_data()
|
||||
all_models = cache.raw_data
|
||||
|
||||
# Filter models if specific file paths are provided
|
||||
if file_paths:
|
||||
all_models = [model for model in all_models if model.get('file_path') in file_paths]
|
||||
result.operation_type = 'bulk'
|
||||
else:
|
||||
result.operation_type = 'all'
|
||||
|
||||
# Get model roots for this scanner
|
||||
model_roots = self.get_model_roots()
|
||||
if not model_roots:
|
||||
raise ValueError('No model roots configured')
|
||||
|
||||
# Check if flat structure is configured for this model type
|
||||
path_template = settings.get_download_path_template(self.model_type)
|
||||
result.is_flat_structure = not path_template
|
||||
|
||||
# Initialize tracking
|
||||
result.total = len(all_models)
|
||||
|
||||
# Send initial progress
|
||||
if progress_callback:
|
||||
await progress_callback.on_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'started',
|
||||
'total': result.total,
|
||||
'processed': 0,
|
||||
'success': 0,
|
||||
'failures': 0,
|
||||
'skipped': 0,
|
||||
'operation_type': result.operation_type
|
||||
})
|
||||
|
||||
# Process models in batches
|
||||
await self._process_models_in_batches(
|
||||
all_models,
|
||||
model_roots,
|
||||
result,
|
||||
progress_callback,
|
||||
source_directories # Pass the set to track source directories
|
||||
)
|
||||
|
||||
# Send cleanup progress
|
||||
if progress_callback:
|
||||
await progress_callback.on_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'cleaning',
|
||||
'total': result.total,
|
||||
'processed': result.processed,
|
||||
'success': result.success_count,
|
||||
'failures': result.failure_count,
|
||||
'skipped': result.skipped_count,
|
||||
'message': 'Cleaning up empty directories...',
|
||||
'operation_type': result.operation_type
|
||||
})
|
||||
|
||||
# Clean up empty directories - only in affected directories for bulk operations
|
||||
cleanup_paths = list(source_directories) if result.operation_type == 'bulk' else model_roots
|
||||
result.cleanup_counts = await self._cleanup_empty_directories(cleanup_paths)
|
||||
|
||||
# Send completion message
|
||||
if progress_callback:
|
||||
await progress_callback.on_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'completed',
|
||||
'total': result.total,
|
||||
'processed': result.processed,
|
||||
'success': result.success_count,
|
||||
'failures': result.failure_count,
|
||||
'skipped': result.skipped_count,
|
||||
'cleanup': result.cleanup_counts,
|
||||
'operation_type': result.operation_type
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in auto_organize_models: {e}", exc_info=True)
|
||||
|
||||
# Send error message
|
||||
if progress_callback:
|
||||
await progress_callback.on_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'error',
|
||||
'error': str(e),
|
||||
'operation_type': result.operation_type
|
||||
})
|
||||
|
||||
raise e
|
||||
|
||||
async def _process_models_in_batches(
|
||||
self,
|
||||
all_models: List[Dict[str, Any]],
|
||||
model_roots: List[str],
|
||||
result: AutoOrganizeResult,
|
||||
progress_callback: Optional[ProgressCallback],
|
||||
source_directories: Optional[Set[str]] = None
|
||||
) -> None:
|
||||
"""Process models in batches to avoid overwhelming the system"""
|
||||
|
||||
for i in range(0, result.total, AUTO_ORGANIZE_BATCH_SIZE):
|
||||
batch = all_models[i:i + AUTO_ORGANIZE_BATCH_SIZE]
|
||||
|
||||
for model in batch:
|
||||
await self._process_single_model(model, model_roots, result, source_directories)
|
||||
result.processed += 1
|
||||
|
||||
# Send progress update after each batch
|
||||
if progress_callback:
|
||||
await progress_callback.on_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'processing',
|
||||
'total': result.total,
|
||||
'processed': result.processed,
|
||||
'success': result.success_count,
|
||||
'failures': result.failure_count,
|
||||
'skipped': result.skipped_count,
|
||||
'operation_type': result.operation_type
|
||||
})
|
||||
|
||||
# Small delay between batches
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _process_single_model(
|
||||
self,
|
||||
model: Dict[str, Any],
|
||||
model_roots: List[str],
|
||||
result: AutoOrganizeResult,
|
||||
source_directories: Optional[Set[str]] = None
|
||||
) -> None:
|
||||
"""Process a single model for organization"""
|
||||
try:
|
||||
file_path = model.get('file_path')
|
||||
model_name = model.get('model_name', 'Unknown')
|
||||
|
||||
if not file_path:
|
||||
self._add_result(result, model_name, False, "No file path found")
|
||||
result.failure_count += 1
|
||||
return
|
||||
|
||||
# Find which model root this file belongs to
|
||||
current_root = self._find_model_root(file_path, model_roots)
|
||||
if not current_root:
|
||||
self._add_result(result, model_name, False,
|
||||
"Model file not found in any configured root directory")
|
||||
result.failure_count += 1
|
||||
return
|
||||
|
||||
# Determine target directory
|
||||
target_dir = await self._calculate_target_directory(
|
||||
model, current_root, result.is_flat_structure
|
||||
)
|
||||
|
||||
if target_dir is None:
|
||||
self._add_result(result, model_name, False,
|
||||
"Skipped - insufficient metadata for organization")
|
||||
result.skipped_count += 1
|
||||
return
|
||||
|
||||
current_dir = os.path.dirname(file_path)
|
||||
|
||||
# Skip if already in correct location
|
||||
if current_dir.replace(os.sep, '/') == target_dir.replace(os.sep, '/'):
|
||||
result.skipped_count += 1
|
||||
return
|
||||
|
||||
# Check for conflicts
|
||||
file_name = os.path.basename(file_path)
|
||||
target_file_path = os.path.join(target_dir, file_name)
|
||||
|
||||
if os.path.exists(target_file_path):
|
||||
self._add_result(result, model_name, False,
|
||||
f"Target file already exists: {target_file_path}")
|
||||
result.failure_count += 1
|
||||
return
|
||||
|
||||
# Store the source directory for potential cleanup
|
||||
if source_directories is not None:
|
||||
source_directories.add(current_dir)
|
||||
|
||||
# Perform the move
|
||||
success = await self.scanner.move_model(file_path, target_dir)
|
||||
|
||||
if success:
|
||||
result.success_count += 1
|
||||
else:
|
||||
self._add_result(result, model_name, False, "Failed to move model")
|
||||
result.failure_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing model {model.get('model_name', 'Unknown')}: {e}", exc_info=True)
|
||||
self._add_result(result, model.get('model_name', 'Unknown'), False, f"Error: {str(e)}")
|
||||
result.failure_count += 1
|
||||
|
||||
def _find_model_root(self, file_path: str, model_roots: List[str]) -> Optional[str]:
|
||||
"""Find which model root the file belongs to"""
|
||||
for root in model_roots:
|
||||
# Normalize paths for comparison
|
||||
normalized_root = os.path.normpath(root).replace(os.sep, '/')
|
||||
normalized_file = os.path.normpath(file_path).replace(os.sep, '/')
|
||||
|
||||
if normalized_file.startswith(normalized_root):
|
||||
return root
|
||||
return None
|
||||
|
||||
async def _calculate_target_directory(
|
||||
self,
|
||||
model: Dict[str, Any],
|
||||
current_root: str,
|
||||
is_flat_structure: bool
|
||||
) -> Optional[str]:
|
||||
"""Calculate the target directory for a model"""
|
||||
if is_flat_structure:
|
||||
file_path = model.get('file_path')
|
||||
current_dir = os.path.dirname(file_path)
|
||||
|
||||
# Check if already in root directory
|
||||
if os.path.normpath(current_dir) == os.path.normpath(current_root):
|
||||
return None # Signal to skip
|
||||
|
||||
return current_root
|
||||
else:
|
||||
# Calculate new relative path based on settings
|
||||
new_relative_path = calculate_relative_path_for_model(model, self.model_type)
|
||||
|
||||
if not new_relative_path:
|
||||
return None # Signal to skip
|
||||
|
||||
return os.path.join(current_root, new_relative_path).replace(os.sep, '/')
|
||||
|
||||
def _add_result(
|
||||
self,
|
||||
result: AutoOrganizeResult,
|
||||
model_name: str,
|
||||
success: bool,
|
||||
message: str
|
||||
) -> None:
|
||||
"""Add a result entry if under the limit"""
|
||||
if len(result.results) < 100: # Limit detailed results
|
||||
result.results.append({
|
||||
"model": model_name,
|
||||
"success": success,
|
||||
"message": message
|
||||
})
|
||||
elif len(result.results) == 100:
|
||||
# Mark as truncated and save sample
|
||||
result.results_truncated = True
|
||||
result.sample_results = result.results[:50]
|
||||
|
||||
async def _cleanup_empty_directories(self, paths: List[str]) -> Dict[str, int]:
|
||||
"""Clean up empty directories after organizing
|
||||
|
||||
Args:
|
||||
paths: List of paths to check for empty directories
|
||||
|
||||
Returns:
|
||||
Dictionary with counts of removed directories by root path
|
||||
"""
|
||||
cleanup_counts = {}
|
||||
for path in paths:
|
||||
removed = remove_empty_dirs(path)
|
||||
cleanup_counts[path] = removed
|
||||
return cleanup_counts
|
||||
|
||||
|
||||
class ModelMoveService:
|
||||
"""Service for handling individual model moves"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize the service
|
||||
|
||||
Args:
|
||||
scanner: Model scanner instance
|
||||
"""
|
||||
self.scanner = scanner
|
||||
|
||||
async def move_model(self, file_path: str, target_path: str) -> Dict[str, Any]:
|
||||
"""Move a single model file
|
||||
|
||||
Args:
|
||||
file_path: Source file path
|
||||
target_path: Target directory path
|
||||
|
||||
Returns:
|
||||
Dictionary with move result
|
||||
"""
|
||||
try:
|
||||
source_dir = os.path.dirname(file_path)
|
||||
if os.path.normpath(source_dir) == os.path.normpath(target_path):
|
||||
logger.info(f"Source and target directories are the same: {source_dir}")
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Source and target directories are the same',
|
||||
'original_file_path': file_path,
|
||||
'new_file_path': file_path
|
||||
}
|
||||
|
||||
new_file_path = await self.scanner.move_model(file_path, target_path)
|
||||
if new_file_path:
|
||||
return {
|
||||
'success': True,
|
||||
'original_file_path': file_path,
|
||||
'new_file_path': new_file_path
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'success': False,
|
||||
'error': 'Failed to move model',
|
||||
'original_file_path': file_path,
|
||||
'new_file_path': None
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving model: {e}", exc_info=True)
|
||||
return {
|
||||
'success': False,
|
||||
'error': str(e),
|
||||
'original_file_path': file_path,
|
||||
'new_file_path': None
|
||||
}
|
||||
|
||||
async def move_models_bulk(self, file_paths: List[str], target_path: str) -> Dict[str, Any]:
|
||||
"""Move multiple model files
|
||||
|
||||
Args:
|
||||
file_paths: List of source file paths
|
||||
target_path: Target directory path
|
||||
|
||||
Returns:
|
||||
Dictionary with bulk move results
|
||||
"""
|
||||
try:
|
||||
results = []
|
||||
|
||||
for file_path in file_paths:
|
||||
result = await self.move_model(file_path, target_path)
|
||||
results.append({
|
||||
"original_file_path": file_path,
|
||||
"new_file_path": result.get('new_file_path'),
|
||||
"success": result['success'],
|
||||
"message": result.get('message', result.get('error', 'Unknown'))
|
||||
})
|
||||
|
||||
success_count = sum(1 for r in results if r["success"])
|
||||
failure_count = len(results) - success_count
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': f'Moved {success_count} of {len(file_paths)} models',
|
||||
'results': results,
|
||||
'success_count': success_count,
|
||||
'failure_count': failure_count
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving models in bulk: {e}", exc_info=True)
|
||||
return {
|
||||
'success': False,
|
||||
'error': str(e),
|
||||
'results': [],
|
||||
'success_count': 0,
|
||||
'failure_count': len(file_paths)
|
||||
}
|
||||
234
py/services/model_hash_index.py
Normal file
234
py/services/model_hash_index.py
Normal file
@@ -0,0 +1,234 @@
|
||||
from typing import Dict, Optional, Set, List
|
||||
import os
|
||||
|
||||
class ModelHashIndex:
|
||||
"""Index for looking up models by hash or filename"""
|
||||
|
||||
def __init__(self):
|
||||
self._hash_to_path: Dict[str, str] = {}
|
||||
self._filename_to_hash: Dict[str, str] = {}
|
||||
# New data structures for tracking duplicates
|
||||
self._duplicate_hashes: Dict[str, List[str]] = {} # sha256 -> list of paths
|
||||
self._duplicate_filenames: Dict[str, List[str]] = {} # filename -> list of paths
|
||||
|
||||
def add_entry(self, sha256: str, file_path: str) -> None:
|
||||
"""Add or update hash index entry"""
|
||||
if not sha256 or not file_path:
|
||||
return
|
||||
|
||||
# Ensure hash is lowercase for consistency
|
||||
sha256 = sha256.lower()
|
||||
|
||||
# Extract filename without extension
|
||||
filename = self._get_filename_from_path(file_path)
|
||||
|
||||
# Track duplicates by hash
|
||||
if sha256 in self._hash_to_path:
|
||||
old_path = self._hash_to_path[sha256]
|
||||
if old_path != file_path: # Only record if it's actually a different path
|
||||
if sha256 not in self._duplicate_hashes:
|
||||
self._duplicate_hashes[sha256] = [old_path]
|
||||
if file_path not in self._duplicate_hashes.get(sha256, []):
|
||||
self._duplicate_hashes.setdefault(sha256, []).append(file_path)
|
||||
|
||||
# Track duplicates by filename - FIXED LOGIC
|
||||
if filename in self._filename_to_hash:
|
||||
existing_hash = self._filename_to_hash[filename]
|
||||
existing_path = self._hash_to_path.get(existing_hash)
|
||||
|
||||
# If this is a different file with the same filename
|
||||
if existing_path and existing_path != file_path:
|
||||
# Initialize duplicates tracking if needed
|
||||
if filename not in self._duplicate_filenames:
|
||||
self._duplicate_filenames[filename] = [existing_path]
|
||||
|
||||
# Add current file to duplicates if not already present
|
||||
if file_path not in self._duplicate_filenames[filename]:
|
||||
self._duplicate_filenames[filename].append(file_path)
|
||||
|
||||
# Remove old path mapping if hash exists
|
||||
if sha256 in self._hash_to_path:
|
||||
old_path = self._hash_to_path[sha256]
|
||||
old_filename = self._get_filename_from_path(old_path)
|
||||
if old_filename in self._filename_to_hash and self._filename_to_hash[old_filename] == sha256:
|
||||
del self._filename_to_hash[old_filename]
|
||||
|
||||
# Remove old hash mapping if filename exists and points to different hash
|
||||
if filename in self._filename_to_hash:
|
||||
old_hash = self._filename_to_hash[filename]
|
||||
if old_hash != sha256 and old_hash in self._hash_to_path:
|
||||
# Don't delete the old hash mapping, just update filename mapping
|
||||
pass
|
||||
|
||||
# Add new mappings
|
||||
self._hash_to_path[sha256] = file_path
|
||||
self._filename_to_hash[filename] = sha256
|
||||
|
||||
def _get_filename_from_path(self, file_path: str) -> str:
|
||||
"""Extract filename without extension from path"""
|
||||
return os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
def remove_by_path(self, file_path: str, hash_val: str = None) -> None:
|
||||
"""Remove entry by file path"""
|
||||
filename = self._get_filename_from_path(file_path)
|
||||
|
||||
# Find the hash for this file path
|
||||
if hash_val is None:
|
||||
for h, p in self._hash_to_path.items():
|
||||
if p == file_path:
|
||||
hash_val = h
|
||||
break
|
||||
|
||||
# If we didn't find a hash, nothing to do
|
||||
if not hash_val:
|
||||
return
|
||||
|
||||
# Update duplicates tracking for hash
|
||||
if hash_val in self._duplicate_hashes:
|
||||
# Remove the current path from duplicates
|
||||
self._duplicate_hashes[hash_val] = [p for p in self._duplicate_hashes[hash_val] if p != file_path]
|
||||
|
||||
# Update or remove hash mapping based on remaining duplicates
|
||||
if len(self._duplicate_hashes[hash_val]) > 0:
|
||||
# Replace with one of the remaining paths
|
||||
new_path = self._duplicate_hashes[hash_val][0]
|
||||
new_filename = self._get_filename_from_path(new_path)
|
||||
|
||||
# Update hash-to-path mapping
|
||||
self._hash_to_path[hash_val] = new_path
|
||||
|
||||
# IMPORTANT: Update filename-to-hash mapping for consistency
|
||||
# Remove old filename mapping if it points to this hash
|
||||
if filename in self._filename_to_hash and self._filename_to_hash[filename] == hash_val:
|
||||
del self._filename_to_hash[filename]
|
||||
|
||||
# Add new filename mapping
|
||||
self._filename_to_hash[new_filename] = hash_val
|
||||
|
||||
# If only one duplicate left, remove from duplicates tracking
|
||||
if len(self._duplicate_hashes[hash_val]) == 1:
|
||||
del self._duplicate_hashes[hash_val]
|
||||
else:
|
||||
# No duplicates left, remove hash entry completely
|
||||
del self._duplicate_hashes[hash_val]
|
||||
del self._hash_to_path[hash_val]
|
||||
|
||||
# Remove corresponding filename entry if it points to this hash
|
||||
if filename in self._filename_to_hash and self._filename_to_hash[filename] == hash_val:
|
||||
del self._filename_to_hash[filename]
|
||||
else:
|
||||
# No duplicates, simply remove the hash entry
|
||||
del self._hash_to_path[hash_val]
|
||||
|
||||
# Remove corresponding filename entry if it points to this hash
|
||||
if filename in self._filename_to_hash and self._filename_to_hash[filename] == hash_val:
|
||||
del self._filename_to_hash[filename]
|
||||
|
||||
# Update duplicates tracking for filename
|
||||
if filename in self._duplicate_filenames:
|
||||
# Remove the current path from duplicates
|
||||
self._duplicate_filenames[filename] = [p for p in self._duplicate_filenames[filename] if p != file_path]
|
||||
|
||||
# Update or remove filename mapping based on remaining duplicates
|
||||
if len(self._duplicate_filenames[filename]) > 0:
|
||||
# Get the hash for the first remaining duplicate path
|
||||
first_dup_path = self._duplicate_filenames[filename][0]
|
||||
first_dup_hash = None
|
||||
for h, p in self._hash_to_path.items():
|
||||
if p == first_dup_path:
|
||||
first_dup_hash = h
|
||||
break
|
||||
|
||||
# Update the filename to hash mapping if we found a hash
|
||||
if first_dup_hash:
|
||||
self._filename_to_hash[filename] = first_dup_hash
|
||||
|
||||
# If only one duplicate left, remove from duplicates tracking
|
||||
if len(self._duplicate_filenames[filename]) == 1:
|
||||
del self._duplicate_filenames[filename]
|
||||
else:
|
||||
# No duplicates left, remove filename entry completely
|
||||
del self._duplicate_filenames[filename]
|
||||
if filename in self._filename_to_hash:
|
||||
del self._filename_to_hash[filename]
|
||||
|
||||
def remove_by_hash(self, sha256: str) -> None:
|
||||
"""Remove entry by hash"""
|
||||
sha256 = sha256.lower()
|
||||
if sha256 not in self._hash_to_path:
|
||||
return
|
||||
|
||||
# Get the path and filename
|
||||
path = self._hash_to_path[sha256]
|
||||
filename = self._get_filename_from_path(path)
|
||||
|
||||
# Get all paths for this hash (including duplicates)
|
||||
paths_to_remove = [path]
|
||||
if sha256 in self._duplicate_hashes:
|
||||
paths_to_remove.extend(self._duplicate_hashes[sha256])
|
||||
del self._duplicate_hashes[sha256]
|
||||
|
||||
# Remove hash-to-path mapping
|
||||
del self._hash_to_path[sha256]
|
||||
|
||||
# Update filename-to-hash and duplicate filenames for all paths
|
||||
for path_to_remove in paths_to_remove:
|
||||
fname = self._get_filename_from_path(path_to_remove)
|
||||
|
||||
# If this filename maps to the hash we're removing, remove it
|
||||
if fname in self._filename_to_hash and self._filename_to_hash[fname] == sha256:
|
||||
del self._filename_to_hash[fname]
|
||||
|
||||
# Update duplicate filenames tracking
|
||||
if fname in self._duplicate_filenames:
|
||||
self._duplicate_filenames[fname] = [p for p in self._duplicate_filenames[fname] if p != path_to_remove]
|
||||
|
||||
if not self._duplicate_filenames[fname]:
|
||||
del self._duplicate_filenames[fname]
|
||||
elif len(self._duplicate_filenames[fname]) == 1:
|
||||
# If only one entry remains, it's no longer a duplicate
|
||||
del self._duplicate_filenames[fname]
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
"""Check if hash exists in index"""
|
||||
return sha256.lower() in self._hash_to_path
|
||||
|
||||
def get_path(self, sha256: str) -> Optional[str]:
|
||||
"""Get file path for a hash"""
|
||||
return self._hash_to_path.get(sha256.lower())
|
||||
|
||||
def get_hash(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a file path"""
|
||||
filename = self._get_filename_from_path(file_path)
|
||||
return self._filename_to_hash.get(filename)
|
||||
|
||||
def get_hash_by_filename(self, filename: str) -> Optional[str]:
|
||||
"""Get hash for a filename without extension"""
|
||||
return self._filename_to_hash.get(filename)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all entries"""
|
||||
self._hash_to_path.clear()
|
||||
self._filename_to_hash.clear()
|
||||
self._duplicate_hashes.clear()
|
||||
self._duplicate_filenames.clear()
|
||||
|
||||
def get_all_hashes(self) -> Set[str]:
|
||||
"""Get all hashes in the index"""
|
||||
return set(self._hash_to_path.keys())
|
||||
|
||||
def get_all_filenames(self) -> Set[str]:
|
||||
"""Get all filenames in the index"""
|
||||
return set(self._filename_to_hash.keys())
|
||||
|
||||
def get_duplicate_hashes(self) -> Dict[str, List[str]]:
|
||||
"""Get dictionary of duplicate hashes and their paths"""
|
||||
return self._duplicate_hashes
|
||||
|
||||
def get_duplicate_filenames(self) -> Dict[str, List[str]]:
|
||||
"""Get dictionary of duplicate filenames and their paths"""
|
||||
return self._duplicate_filenames
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Get number of entries"""
|
||||
return len(self._hash_to_path)
|
||||
245
py/services/model_lifecycle_service.py
Normal file
245
py/services/model_lifecycle_service.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Service routines for model lifecycle mutations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Awaitable, Callable, Dict, Iterable, List, Optional
|
||||
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]:
|
||||
"""Delete the primary model artefacts within ``target_dir``."""
|
||||
|
||||
patterns = [
|
||||
f"{file_name}.safetensors",
|
||||
f"{file_name}.metadata.json",
|
||||
]
|
||||
for ext in PREVIEW_EXTENSIONS:
|
||||
patterns.append(f"{file_name}{ext}")
|
||||
|
||||
deleted: List[str] = []
|
||||
main_file = patterns[0]
|
||||
main_path = os.path.join(target_dir, main_file).replace(os.sep, "/")
|
||||
|
||||
if os.path.exists(main_path):
|
||||
os.remove(main_path)
|
||||
deleted.append(main_path)
|
||||
else:
|
||||
logger.warning("Model file not found: %s", main_file)
|
||||
|
||||
for pattern in patterns[1:]:
|
||||
path = os.path.join(target_dir, pattern)
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
os.remove(path)
|
||||
deleted.append(pattern)
|
||||
except Exception as exc: # pragma: no cover - defensive path
|
||||
logger.warning("Failed to delete %s: %s", pattern, exc)
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
class ModelLifecycleService:
|
||||
"""Co-ordinate destructive and mutating model operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
scanner,
|
||||
metadata_manager,
|
||||
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
||||
recipe_scanner_factory: Callable[[], Awaitable] | None = None,
|
||||
) -> None:
|
||||
self._scanner = scanner
|
||||
self._metadata_manager = metadata_manager
|
||||
self._metadata_loader = metadata_loader
|
||||
self._recipe_scanner_factory = (
|
||||
recipe_scanner_factory or ServiceRegistry.get_recipe_scanner
|
||||
)
|
||||
|
||||
async def delete_model(self, file_path: str) -> Dict[str, object]:
|
||||
"""Delete a model file and associated artefacts."""
|
||||
|
||||
if not file_path:
|
||||
raise ValueError("Model path is required")
|
||||
|
||||
target_dir = os.path.dirname(file_path)
|
||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
deleted_files = await delete_model_artifacts(target_dir, file_name)
|
||||
|
||||
cache = await self._scanner.get_cached_data()
|
||||
cache.raw_data = [item for item in cache.raw_data if item["file_path"] != file_path]
|
||||
await cache.resort()
|
||||
|
||||
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
|
||||
self._scanner._hash_index.remove_by_path(file_path)
|
||||
|
||||
return {"success": True, "deleted_files": deleted_files}
|
||||
|
||||
async def exclude_model(self, file_path: str) -> Dict[str, object]:
|
||||
"""Mark a model as excluded and prune cache references."""
|
||||
|
||||
if not file_path:
|
||||
raise ValueError("Model path is required")
|
||||
|
||||
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||
metadata = await self._metadata_loader(metadata_path)
|
||||
metadata["exclude"] = True
|
||||
|
||||
await self._metadata_manager.save_metadata(file_path, metadata)
|
||||
|
||||
cache = await self._scanner.get_cached_data()
|
||||
model_to_remove = next(
|
||||
(item for item in cache.raw_data if item["file_path"] == file_path),
|
||||
None,
|
||||
)
|
||||
|
||||
if model_to_remove:
|
||||
for tag in model_to_remove.get("tags", []):
|
||||
if tag in getattr(self._scanner, "_tags_count", {}):
|
||||
self._scanner._tags_count[tag] = max(
|
||||
0, self._scanner._tags_count[tag] - 1
|
||||
)
|
||||
if self._scanner._tags_count[tag] == 0:
|
||||
del self._scanner._tags_count[tag]
|
||||
|
||||
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
|
||||
self._scanner._hash_index.remove_by_path(file_path)
|
||||
|
||||
cache.raw_data = [
|
||||
item for item in cache.raw_data if item["file_path"] != file_path
|
||||
]
|
||||
await cache.resort()
|
||||
|
||||
excluded = getattr(self._scanner, "_excluded_models", None)
|
||||
if isinstance(excluded, list):
|
||||
excluded.append(file_path)
|
||||
|
||||
message = f"Model {os.path.basename(file_path)} excluded"
|
||||
return {"success": True, "message": message}
|
||||
|
||||
async def bulk_delete_models(self, file_paths: Iterable[str]) -> Dict[str, object]:
|
||||
"""Delete a collection of models via the scanner bulk operation."""
|
||||
|
||||
file_paths = list(file_paths)
|
||||
if not file_paths:
|
||||
raise ValueError("No file paths provided for deletion")
|
||||
|
||||
return await self._scanner.bulk_delete_models(file_paths)
|
||||
|
||||
async def rename_model(
|
||||
self, *, file_path: str, new_file_name: str
|
||||
) -> Dict[str, object]:
|
||||
"""Rename a model and its companion artefacts."""
|
||||
|
||||
if not file_path or not new_file_name:
|
||||
raise ValueError("File path and new file name are required")
|
||||
|
||||
invalid_chars = {"/", "\\", ":", "*", "?", '"', "<", ">", "|"}
|
||||
if any(char in new_file_name for char in invalid_chars):
|
||||
raise ValueError("Invalid characters in file name")
|
||||
|
||||
target_dir = os.path.dirname(file_path)
|
||||
old_file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
new_file_path = os.path.join(target_dir, f"{new_file_name}.safetensors").replace(
|
||||
os.sep, "/"
|
||||
)
|
||||
|
||||
if os.path.exists(new_file_path):
|
||||
raise ValueError("A file with this name already exists")
|
||||
|
||||
patterns = [
|
||||
f"{old_file_name}.safetensors",
|
||||
f"{old_file_name}.metadata.json",
|
||||
f"{old_file_name}.metadata.json.bak",
|
||||
]
|
||||
for ext in PREVIEW_EXTENSIONS:
|
||||
patterns.append(f"{old_file_name}{ext}")
|
||||
|
||||
existing_files: List[tuple[str, str]] = []
|
||||
for pattern in patterns:
|
||||
path = os.path.join(target_dir, pattern)
|
||||
if os.path.exists(path):
|
||||
existing_files.append((path, pattern))
|
||||
|
||||
metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json")
|
||||
metadata: Optional[Dict[str, object]] = None
|
||||
hash_value: Optional[str] = None
|
||||
|
||||
if os.path.exists(metadata_path):
|
||||
metadata = await self._metadata_loader(metadata_path)
|
||||
hash_value = metadata.get("sha256") if isinstance(metadata, dict) else None
|
||||
|
||||
renamed_files: List[str] = []
|
||||
new_metadata_path: Optional[str] = None
|
||||
new_preview: Optional[str] = None
|
||||
|
||||
for old_path, pattern in existing_files:
|
||||
ext = self._get_multipart_ext(pattern)
|
||||
new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(
|
||||
os.sep, "/"
|
||||
)
|
||||
os.rename(old_path, new_path)
|
||||
renamed_files.append(new_path)
|
||||
|
||||
if ext == ".metadata.json":
|
||||
new_metadata_path = new_path
|
||||
|
||||
if metadata and new_metadata_path:
|
||||
metadata["file_name"] = new_file_name
|
||||
metadata["file_path"] = new_file_path
|
||||
|
||||
if metadata.get("preview_url"):
|
||||
old_preview = str(metadata["preview_url"])
|
||||
ext = self._get_multipart_ext(old_preview)
|
||||
new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(
|
||||
os.sep, "/"
|
||||
)
|
||||
metadata["preview_url"] = new_preview
|
||||
|
||||
await self._metadata_manager.save_metadata(new_file_path, metadata)
|
||||
|
||||
if metadata:
|
||||
await self._scanner.update_single_model_cache(
|
||||
file_path, new_file_path, metadata
|
||||
)
|
||||
|
||||
if hash_value and getattr(self._scanner, "model_type", "") == "lora":
|
||||
recipe_scanner = await self._recipe_scanner_factory()
|
||||
if recipe_scanner:
|
||||
try:
|
||||
await recipe_scanner.update_lora_filename_by_hash(
|
||||
hash_value, new_file_name
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error(
|
||||
"Error updating recipe references for %s: %s",
|
||||
file_path,
|
||||
exc,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"new_file_path": new_file_path,
|
||||
"new_preview_path": new_preview,
|
||||
"renamed_files": renamed_files,
|
||||
"reload_required": False,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_multipart_ext(filename: str) -> str:
|
||||
"""Return the extension for files with compound suffixes."""
|
||||
|
||||
parts = filename.split(".")
|
||||
if len(parts) == 3:
|
||||
return "." + ".".join(parts[-2:])
|
||||
if len(parts) >= 4:
|
||||
return "." + ".".join(parts[-3:])
|
||||
return os.path.splitext(filename)[1]
|
||||
|
||||
495
py/services/model_metadata_provider.py
Normal file
495
py/services/model_metadata_provider.py
Normal file
@@ -0,0 +1,495 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Tuple, Any
|
||||
from .downloader import get_downloader
|
||||
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
except ImportError as exc:
|
||||
BeautifulSoup = None # type: ignore[assignment]
|
||||
_BS4_IMPORT_ERROR = exc
|
||||
else:
|
||||
_BS4_IMPORT_ERROR = None
|
||||
|
||||
try:
|
||||
import aiosqlite
|
||||
except ImportError as exc:
|
||||
aiosqlite = None # type: ignore[assignment]
|
||||
_AIOSQLITE_IMPORT_ERROR = exc
|
||||
else:
|
||||
_AIOSQLITE_IMPORT_ERROR = None
|
||||
|
||||
def _require_beautifulsoup() -> Any:
|
||||
if BeautifulSoup is None:
|
||||
raise RuntimeError(
|
||||
"BeautifulSoup (bs4) is required for CivArchiveModelMetadataProvider. "
|
||||
"Install it with 'pip install beautifulsoup4'."
|
||||
) from _BS4_IMPORT_ERROR
|
||||
return BeautifulSoup
|
||||
|
||||
def _require_aiosqlite() -> Any:
|
||||
if aiosqlite is None:
|
||||
raise RuntimeError(
|
||||
"aiosqlite is required for SQLiteModelMetadataProvider. "
|
||||
"Install it with 'pip install aiosqlite'."
|
||||
) from _AIOSQLITE_IMPORT_ERROR
|
||||
return aiosqlite
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelMetadataProvider(ABC):
|
||||
"""Base abstract class for all model metadata providers"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Find model by hash value"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Get all versions of a model with their details"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
"""Get specific model version with additional metadata"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Fetch model version metadata"""
|
||||
pass
|
||||
|
||||
class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses Civitai API for metadata"""
|
||||
|
||||
def __init__(self, civitai_client):
|
||||
self.client = civitai_client
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
return await self.client.get_model_by_hash(model_hash)
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
return await self.client.get_model_versions(model_id)
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
return await self.client.get_model_version(model_id, version_id)
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
return await self.client.get_model_version_info(version_id)
|
||||
|
||||
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses CivArchive HTML page parsing for metadata"""
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Not supported by CivArchive provider"""
|
||||
return None, "CivArchive provider does not support hash lookup"
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Not supported by CivArchive provider"""
|
||||
return None
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
"""Get specific model version by parsing CivArchive HTML page"""
|
||||
if model_id is None or version_id is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Construct CivArchive URL
|
||||
url = f"https://civarchive.com/models/{model_id}?modelVersionId={version_id}"
|
||||
|
||||
downloader = await get_downloader()
|
||||
session = await downloader.session
|
||||
async with session.get(url) as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
|
||||
html_content = await response.text()
|
||||
|
||||
# Parse HTML to extract JSON data
|
||||
soup_parser = _require_beautifulsoup()
|
||||
soup = soup_parser(html_content, 'html.parser')
|
||||
script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'})
|
||||
|
||||
if not script_tag:
|
||||
return None
|
||||
|
||||
# Parse JSON content
|
||||
json_data = json.loads(script_tag.string)
|
||||
model_data = json_data.get('props', {}).get('pageProps', {}).get('model')
|
||||
|
||||
if not model_data or 'version' not in model_data:
|
||||
return None
|
||||
|
||||
# Extract version data as base
|
||||
version = model_data['version'].copy()
|
||||
|
||||
# Restructure stats
|
||||
if 'downloadCount' in version and 'ratingCount' in version and 'rating' in version:
|
||||
version['stats'] = {
|
||||
'downloadCount': version.pop('downloadCount'),
|
||||
'ratingCount': version.pop('ratingCount'),
|
||||
'rating': version.pop('rating')
|
||||
}
|
||||
|
||||
# Rename trigger to trainedWords
|
||||
if 'trigger' in version:
|
||||
version['trainedWords'] = version.pop('trigger')
|
||||
|
||||
# Transform files data to expected format
|
||||
if 'files' in version:
|
||||
transformed_files = []
|
||||
for file_data in version['files']:
|
||||
# Find first available mirror (deletedAt is null)
|
||||
available_mirror = None
|
||||
for mirror in file_data.get('mirrors', []):
|
||||
if mirror.get('deletedAt') is None:
|
||||
available_mirror = mirror
|
||||
break
|
||||
|
||||
# Create transformed file entry
|
||||
transformed_file = {
|
||||
'id': file_data.get('id'),
|
||||
'sizeKB': file_data.get('sizeKB'),
|
||||
'name': available_mirror.get('filename', file_data.get('name')) if available_mirror else file_data.get('name'),
|
||||
'type': file_data.get('type'),
|
||||
'downloadUrl': available_mirror.get('url') if available_mirror else None,
|
||||
'primary': True,
|
||||
'mirrors': file_data.get('mirrors', [])
|
||||
}
|
||||
|
||||
# Transform hash format
|
||||
if 'sha256' in file_data:
|
||||
transformed_file['hashes'] = {
|
||||
'SHA256': file_data['sha256'].upper()
|
||||
}
|
||||
|
||||
transformed_files.append(transformed_file)
|
||||
|
||||
version['files'] = transformed_files
|
||||
|
||||
# Add model information
|
||||
version['model'] = {
|
||||
'name': model_data.get('name'),
|
||||
'type': model_data.get('type'),
|
||||
'nsfw': model_data.get('is_nsfw', False),
|
||||
'description': model_data.get('description'),
|
||||
'tags': model_data.get('tags', [])
|
||||
}
|
||||
|
||||
version['creator'] = {
|
||||
'username': model_data.get('username'),
|
||||
'image': ''
|
||||
}
|
||||
|
||||
# Add source identifier
|
||||
version['source'] = 'civarchive'
|
||||
version['is_deleted'] = json_data.get('query', {}).get('is_deleted', False)
|
||||
|
||||
return version
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivArchive model version {model_id}/{version_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Not supported by CivArchive provider - requires both model_id and version_id"""
|
||||
return None, "CivArchive provider requires both model_id and version_id"
|
||||
|
||||
class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses SQLite database for metadata"""
|
||||
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self._aiosqlite = _require_aiosqlite()
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Find model by hash value from SQLite database"""
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
# Look up in model_files table to get model_id and version_id
|
||||
query = """
|
||||
SELECT model_id, version_id
|
||||
FROM model_files
|
||||
WHERE sha256 = ?
|
||||
LIMIT 1
|
||||
"""
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
cursor = await db.execute(query, (model_hash.upper(),))
|
||||
file_row = await cursor.fetchone()
|
||||
|
||||
if not file_row:
|
||||
return None, "Model not found"
|
||||
|
||||
# Get version details
|
||||
model_id = file_row['model_id']
|
||||
version_id = file_row['version_id']
|
||||
|
||||
# Build response in the same format as Civitai API
|
||||
result = await self._get_version_with_model_data(db, model_id, version_id)
|
||||
return result, None if result else "Error retrieving model data"
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Get all versions of a model from SQLite database"""
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
|
||||
# First check if model exists
|
||||
model_query = "SELECT * FROM models WHERE id = ?"
|
||||
cursor = await db.execute(model_query, (model_id,))
|
||||
model_row = await cursor.fetchone()
|
||||
|
||||
if not model_row:
|
||||
return None
|
||||
|
||||
model_data = json.loads(model_row['data'])
|
||||
model_type = model_row['type']
|
||||
model_name = model_row['name']
|
||||
|
||||
# Get all versions for this model
|
||||
versions_query = """
|
||||
SELECT id, name, base_model, data, position, published_at
|
||||
FROM model_versions
|
||||
WHERE model_id = ?
|
||||
ORDER BY position ASC
|
||||
"""
|
||||
cursor = await db.execute(versions_query, (model_id,))
|
||||
version_rows = await cursor.fetchall()
|
||||
|
||||
if not version_rows:
|
||||
return {'modelVersions': [], 'type': model_type}
|
||||
|
||||
# Format versions similar to Civitai API
|
||||
model_versions = []
|
||||
for row in version_rows:
|
||||
version_data = json.loads(row['data'])
|
||||
# Add fields from the row to ensure we have the basic fields
|
||||
version_entry = {
|
||||
'id': row['id'],
|
||||
'modelId': int(model_id),
|
||||
'name': row['name'],
|
||||
'baseModel': row['base_model'],
|
||||
'model': {
|
||||
'name': model_row['name'],
|
||||
'type': model_type,
|
||||
},
|
||||
'source': 'archive_db'
|
||||
}
|
||||
# Update with any additional data
|
||||
version_entry.update(version_data)
|
||||
model_versions.append(version_entry)
|
||||
|
||||
return {
|
||||
'modelVersions': model_versions,
|
||||
'type': model_type,
|
||||
'name': model_name
|
||||
}
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
"""Get specific model version with additional metadata from SQLite database"""
|
||||
if not model_id and not version_id:
|
||||
return None
|
||||
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
|
||||
# Case 1: Only version_id is provided
|
||||
if model_id is None and version_id is not None:
|
||||
# First get the version info to extract model_id
|
||||
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
||||
cursor = await db.execute(version_query, (version_id,))
|
||||
version_row = await cursor.fetchone()
|
||||
|
||||
if not version_row:
|
||||
return None
|
||||
|
||||
model_id = version_row['model_id']
|
||||
|
||||
# Case 2: model_id is provided but version_id is not
|
||||
elif model_id is not None and version_id is None:
|
||||
# Find the latest version
|
||||
version_query = """
|
||||
SELECT id FROM model_versions
|
||||
WHERE model_id = ?
|
||||
ORDER BY position ASC
|
||||
LIMIT 1
|
||||
"""
|
||||
cursor = await db.execute(version_query, (model_id,))
|
||||
version_row = await cursor.fetchone()
|
||||
|
||||
if not version_row:
|
||||
return None
|
||||
|
||||
version_id = version_row['id']
|
||||
|
||||
# Now we have both model_id and version_id, get the full data
|
||||
return await self._get_version_with_model_data(db, model_id, version_id)
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Fetch model version metadata from SQLite database"""
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
|
||||
# Get version details
|
||||
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
||||
cursor = await db.execute(version_query, (version_id,))
|
||||
version_row = await cursor.fetchone()
|
||||
|
||||
if not version_row:
|
||||
return None, "Model version not found"
|
||||
|
||||
model_id = version_row['model_id']
|
||||
|
||||
# Build complete version data with model info
|
||||
version_data = await self._get_version_with_model_data(db, model_id, version_id)
|
||||
return version_data, None
|
||||
|
||||
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
|
||||
"""Helper to build version data with model information"""
|
||||
# Get version details
|
||||
version_query = "SELECT name, base_model, data FROM model_versions WHERE id = ? AND model_id = ?"
|
||||
cursor = await db.execute(version_query, (version_id, model_id))
|
||||
version_row = await cursor.fetchone()
|
||||
|
||||
if not version_row:
|
||||
return None
|
||||
|
||||
# Get model details
|
||||
model_query = "SELECT name, type, data, username FROM models WHERE id = ?"
|
||||
cursor = await db.execute(model_query, (model_id,))
|
||||
model_row = await cursor.fetchone()
|
||||
|
||||
if not model_row:
|
||||
return None
|
||||
|
||||
# Parse JSON data
|
||||
try:
|
||||
version_data = json.loads(version_row['data'])
|
||||
model_data = json.loads(model_row['data'])
|
||||
|
||||
# Build response
|
||||
result = {
|
||||
"id": int(version_id),
|
||||
"modelId": int(model_id),
|
||||
"name": version_row['name'],
|
||||
"baseModel": version_row['base_model'],
|
||||
"model": {
|
||||
"name": model_row['name'],
|
||||
"description": model_data.get("description"),
|
||||
"type": model_row['type'],
|
||||
"tags": model_data.get("tags", [])
|
||||
},
|
||||
"creator": {
|
||||
"username": model_row['username'] or model_data.get("creator", {}).get("username"),
|
||||
"image": model_data.get("creator", {}).get("image")
|
||||
},
|
||||
"source": "archive_db"
|
||||
}
|
||||
|
||||
# Add any additional fields from version data
|
||||
result.update(version_data)
|
||||
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
class FallbackMetadataProvider(ModelMetadataProvider):
|
||||
"""Try providers in order, return first successful result."""
|
||||
def __init__(self, providers: list):
|
||||
self.providers = providers
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
for provider in self.providers:
|
||||
try:
|
||||
result, error = await provider.get_model_by_hash(model_hash)
|
||||
if result:
|
||||
return result, error
|
||||
except Exception as e:
|
||||
logger.debug(f"Provider failed for get_model_by_hash: {e}")
|
||||
continue
|
||||
return None, "Model not found"
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
for provider in self.providers:
|
||||
try:
|
||||
result = await provider.get_model_versions(model_id)
|
||||
if result:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.debug(f"Provider failed for get_model_versions: {e}")
|
||||
continue
|
||||
return None
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
for provider in self.providers:
|
||||
try:
|
||||
result = await provider.get_model_version(model_id, version_id)
|
||||
if result:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.debug(f"Provider failed for get_model_version: {e}")
|
||||
continue
|
||||
return None
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
for provider in self.providers:
|
||||
try:
|
||||
result, error = await provider.get_model_version_info(version_id)
|
||||
if result:
|
||||
return result, error
|
||||
except Exception as e:
|
||||
logger.debug(f"Provider failed for get_model_version_info: {e}")
|
||||
continue
|
||||
return None, "No provider could retrieve the data"
|
||||
|
||||
class ModelMetadataProviderManager:
|
||||
"""Manager for selecting and using model metadata providers"""
|
||||
|
||||
_instance = None
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance of ModelMetadataProviderManager"""
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
self.providers = {}
|
||||
self.default_provider = None
|
||||
|
||||
def register_provider(self, name: str, provider: ModelMetadataProvider, is_default: bool = False):
|
||||
"""Register a metadata provider"""
|
||||
self.providers[name] = provider
|
||||
if is_default or self.default_provider is None:
|
||||
self.default_provider = name
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Find model by hash using specified or default provider"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_by_hash(model_hash)
|
||||
|
||||
async def get_model_versions(self, model_id: str, provider_name: str = None) -> Optional[Dict]:
|
||||
"""Get model versions using specified or default provider"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_versions(model_id)
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None, provider_name: str = None) -> Optional[Dict]:
|
||||
"""Get specific model version using specified or default provider"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_version(model_id, version_id)
|
||||
|
||||
async def get_model_version_info(self, version_id: str, provider_name: str = None) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Fetch model version info using specified or default provider"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_version_info(version_id)
|
||||
|
||||
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
|
||||
"""Get provider by name or default provider"""
|
||||
if provider_name and provider_name in self.providers:
|
||||
return self.providers[provider_name]
|
||||
|
||||
if self.default_provider is None:
|
||||
raise ValueError("No default provider set and no valid provider specified")
|
||||
|
||||
return self.providers[self.default_provider]
|
||||
196
py/services/model_query.py
Normal file
196
py/services/model_query.py
Normal file
@@ -0,0 +1,196 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable
|
||||
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..utils.utils import fuzzy_match as default_fuzzy_match
|
||||
|
||||
|
||||
class SettingsProvider(Protocol):
|
||||
"""Protocol describing the SettingsManager contract used by query helpers."""
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SortParams:
|
||||
"""Normalized representation of sorting instructions."""
|
||||
|
||||
key: str
|
||||
order: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FilterCriteria:
|
||||
"""Container for model list filtering options."""
|
||||
|
||||
folder: Optional[str] = None
|
||||
base_models: Optional[Sequence[str]] = None
|
||||
tags: Optional[Sequence[str]] = None
|
||||
favorites_only: bool = False
|
||||
search_options: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ModelCacheRepository:
|
||||
"""Adapter around scanner cache access and sort normalisation."""
|
||||
|
||||
def __init__(self, scanner) -> None:
|
||||
self._scanner = scanner
|
||||
|
||||
async def get_cache(self):
|
||||
"""Return the underlying cache instance from the scanner."""
|
||||
return await self._scanner.get_cached_data()
|
||||
|
||||
async def fetch_sorted(self, params: SortParams) -> List[Dict[str, Any]]:
|
||||
"""Fetch cached data pre-sorted according to ``params``."""
|
||||
cache = await self.get_cache()
|
||||
return await cache.get_sorted_data(params.key, params.order)
|
||||
|
||||
@staticmethod
|
||||
def parse_sort(sort_by: str) -> SortParams:
|
||||
"""Parse an incoming sort string into key/order primitives."""
|
||||
if not sort_by:
|
||||
return SortParams(key="name", order="asc")
|
||||
|
||||
if ":" in sort_by:
|
||||
raw_key, raw_order = sort_by.split(":", 1)
|
||||
sort_key = raw_key.strip().lower() or "name"
|
||||
order = raw_order.strip().lower()
|
||||
else:
|
||||
sort_key = sort_by.strip().lower() or "name"
|
||||
order = "asc"
|
||||
|
||||
if order not in ("asc", "desc"):
|
||||
order = "asc"
|
||||
|
||||
return SortParams(key=sort_key, order=order)
|
||||
|
||||
|
||||
class ModelFilterSet:
|
||||
"""Applies common filtering rules to the model collection."""
|
||||
|
||||
def __init__(self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None) -> None:
|
||||
self._settings = settings
|
||||
self._nsfw_levels = nsfw_levels or NSFW_LEVELS
|
||||
|
||||
def apply(self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria) -> List[Dict[str, Any]]:
|
||||
"""Return items that satisfy the provided criteria."""
|
||||
items = list(data)
|
||||
|
||||
if self._settings.get("show_only_sfw", False):
|
||||
threshold = self._nsfw_levels.get("R", 0)
|
||||
items = [
|
||||
item for item in items
|
||||
if not item.get("preview_nsfw_level") or item.get("preview_nsfw_level") < threshold
|
||||
]
|
||||
|
||||
if criteria.favorites_only:
|
||||
items = [item for item in items if item.get("favorite", False)]
|
||||
|
||||
folder = criteria.folder
|
||||
options = criteria.search_options or {}
|
||||
recursive = bool(options.get("recursive", True))
|
||||
if folder is not None:
|
||||
if recursive:
|
||||
if folder:
|
||||
folder_with_sep = f"{folder}/"
|
||||
items = [
|
||||
item for item in items
|
||||
if item.get("folder") == folder or item.get("folder", "").startswith(folder_with_sep)
|
||||
]
|
||||
else:
|
||||
items = [item for item in items if item.get("folder") == folder]
|
||||
|
||||
base_models = criteria.base_models or []
|
||||
if base_models:
|
||||
base_model_set = set(base_models)
|
||||
items = [item for item in items if item.get("base_model") in base_model_set]
|
||||
|
||||
tags = criteria.tags or []
|
||||
if tags:
|
||||
tag_set = set(tags)
|
||||
items = [
|
||||
item for item in items
|
||||
if any(tag in tag_set for tag in item.get("tags", []))
|
||||
]
|
||||
|
||||
return items
|
||||
|
||||
|
||||
class SearchStrategy:
|
||||
"""Encapsulates text and fuzzy matching behaviour for model queries."""
|
||||
|
||||
DEFAULT_OPTIONS: Dict[str, Any] = {
|
||||
"filename": True,
|
||||
"modelname": True,
|
||||
"tags": False,
|
||||
"recursive": True,
|
||||
"creator": False,
|
||||
}
|
||||
|
||||
def __init__(self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None) -> None:
|
||||
self._fuzzy_match = fuzzy_matcher or default_fuzzy_match
|
||||
|
||||
def normalize_options(self, options: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Merge provided options with defaults without mutating input."""
|
||||
normalized = dict(self.DEFAULT_OPTIONS)
|
||||
if options:
|
||||
normalized.update(options)
|
||||
return normalized
|
||||
|
||||
def apply(
|
||||
self,
|
||||
data: Iterable[Dict[str, Any]],
|
||||
search_term: str,
|
||||
options: Dict[str, Any],
|
||||
fuzzy: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Return items matching the search term using the configured strategy."""
|
||||
if not search_term:
|
||||
return list(data)
|
||||
|
||||
search_lower = search_term.lower()
|
||||
results: List[Dict[str, Any]] = []
|
||||
|
||||
for item in data:
|
||||
if options.get("filename", True):
|
||||
candidate = item.get("file_name", "")
|
||||
if self._matches(candidate, search_term, search_lower, fuzzy):
|
||||
results.append(item)
|
||||
continue
|
||||
|
||||
if options.get("modelname", True):
|
||||
candidate = item.get("model_name", "")
|
||||
if self._matches(candidate, search_term, search_lower, fuzzy):
|
||||
results.append(item)
|
||||
continue
|
||||
|
||||
if options.get("tags", False):
|
||||
tags = item.get("tags", []) or []
|
||||
if any(self._matches(tag, search_term, search_lower, fuzzy) for tag in tags):
|
||||
results.append(item)
|
||||
continue
|
||||
|
||||
if options.get("creator", False):
|
||||
creator_username = ""
|
||||
civitai = item.get("civitai")
|
||||
if isinstance(civitai, dict):
|
||||
creator = civitai.get("creator")
|
||||
if isinstance(creator, dict):
|
||||
creator_username = creator.get("username", "")
|
||||
if creator_username and self._matches(creator_username, search_term, search_lower, fuzzy):
|
||||
results.append(item)
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
def _matches(self, candidate: str, search_term: str, search_lower: str, fuzzy: bool) -> bool:
|
||||
if not candidate:
|
||||
return False
|
||||
|
||||
candidate_lower = candidate.lower()
|
||||
if fuzzy:
|
||||
return self._fuzzy_match(candidate, search_term)
|
||||
return search_lower in candidate_lower
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user