mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 05:32:12 -03:00
Compare commits
601 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
12c88835f2 | ||
|
|
6f4453aaf3 | ||
|
|
4b4b8fe3c1 | ||
|
|
49e7c2e9f5 | ||
|
|
4653c273e3 | ||
|
|
ae145de2f2 | ||
|
|
dde7cf71c6 | ||
|
|
219cd242db | ||
|
|
e5b712c082 | ||
|
|
4d2c60d59b | ||
|
|
1d2c1b114b | ||
|
|
2bde936d05 | ||
|
|
cd3e32bf4b | ||
|
|
454536d631 | ||
|
|
656f1755fd | ||
|
|
8aa76ce5c1 | ||
|
|
49fa37f00d | ||
|
|
9f83548cf3 | ||
|
|
6054d95e85 | ||
|
|
8c9bb35824 | ||
|
|
3eacf9558a | ||
|
|
fee37172b4 | ||
|
|
e128c80eb1 | ||
|
|
5cc735ed57 | ||
|
|
43fcce6361 | ||
|
|
49b7126278 | ||
|
|
679cfb5c69 | ||
|
|
50616bc680 | ||
|
|
aaad270822 | ||
|
|
bd10280736 | ||
|
|
d477050239 | ||
|
|
85f79cd8d1 | ||
|
|
613cd81152 | ||
|
|
e0aba6c49a | ||
|
|
d78bcf2494 | ||
|
|
f7cffd2eba | ||
|
|
0d0b91aa80 | ||
|
|
42872e6d2d | ||
|
|
b91f06405d | ||
|
|
dac4c688d6 | ||
|
|
097a68ad18 | ||
|
|
4a98710db0 | ||
|
|
d033a374dd | ||
|
|
6aa23fe36a | ||
|
|
3220cfb79c | ||
|
|
b92e7aa446 | ||
|
|
c3b9c73541 | ||
|
|
81c6672880 | ||
|
|
08baf884d3 | ||
|
|
1c4096f3d5 | ||
|
|
66a3f3f59a | ||
|
|
624df1328b | ||
|
|
c063854b51 | ||
|
|
8cf99dd928 | ||
|
|
c07e885725 | ||
|
|
21772feadd | ||
|
|
2d00cfdd31 | ||
|
|
49e03d658b | ||
|
|
fec85bcc08 | ||
|
|
0e93a6bcb0 | ||
|
|
7e20f738fb | ||
|
|
24090e6077 | ||
|
|
1022b07f64 | ||
|
|
4faf912c6f | ||
|
|
56e4b24b07 | ||
|
|
12295d2fdc | ||
|
|
6261f7d18d | ||
|
|
9e1a2e3bb7 | ||
|
|
40cbb2155c | ||
|
|
a8d7070832 | ||
|
|
ab7266f3a4 | ||
|
|
3053b13fcb | ||
|
|
f3544b3471 | ||
|
|
1610048974 | ||
|
|
fc6f1bf95b | ||
|
|
67b274c1b2 | ||
|
|
fb0d6b5641 | ||
|
|
d30fbeb286 | ||
|
|
46e430ebbb | ||
|
|
bc4cd45fcb | ||
|
|
bdc86ddf15 | ||
|
|
ded17c1479 | ||
|
|
933e2fc01d | ||
|
|
1cddeee264 | ||
|
|
183c000080 | ||
|
|
adf7b6d4b2 | ||
|
|
0566d50346 | ||
|
|
4275dc3003 | ||
|
|
30956aeefc | ||
|
|
64e1dd3dd6 | ||
|
|
0dc4b6f728 | ||
|
|
86074c87d7 | ||
|
|
6f9245df01 | ||
|
|
4540e47055 | ||
|
|
4bb8981e78 | ||
|
|
c49be91aa0 | ||
|
|
2b847039d4 | ||
|
|
1147725fd7 | ||
|
|
26891e12a4 | ||
|
|
2f7e44a76f | ||
|
|
9366d3d2d0 | ||
|
|
6b606a5cc8 | ||
|
|
e5339c178a | ||
|
|
1a76f74482 | ||
|
|
13f13eb095 | ||
|
|
125fdecd61 | ||
|
|
d05076d258 | ||
|
|
00b77581fc | ||
|
|
897787d17c | ||
|
|
d5a280cf2b | ||
|
|
a0c2d9b5ad | ||
|
|
e713bd1ca2 | ||
|
|
beb8ff1dd1 | ||
|
|
6a8f0867d9 | ||
|
|
51ad1c9a33 | ||
|
|
34872eb612 | ||
|
|
8b4e3128ff | ||
|
|
c66cbc800b | ||
|
|
21941521a0 | ||
|
|
0d33884052 | ||
|
|
415df49377 | ||
|
|
f5f45002c7 | ||
|
|
1edf7126bb | ||
|
|
a1a55a1002 | ||
|
|
45f5cb46bd | ||
|
|
1b5e608a27 | ||
|
|
a7df8ae15c | ||
|
|
47ce0d0fe2 | ||
|
|
b220e288d0 | ||
|
|
1fc8b45b68 | ||
|
|
62f06302f0 | ||
|
|
3e5cb223f3 | ||
|
|
4ee5b7481c | ||
|
|
e104b78c01 | ||
|
|
ba1ac58721 | ||
|
|
a4fbeb6295 | ||
|
|
68f8871403 | ||
|
|
6fd74952b7 | ||
|
|
1ea468cfc4 | ||
|
|
14721c265f | ||
|
|
821827a375 | ||
|
|
9ba3e2c204 | ||
|
|
d287883671 | ||
|
|
ead34818db | ||
|
|
a060010b96 | ||
|
|
76a92ac847 | ||
|
|
74bc490383 | ||
|
|
510d476323 | ||
|
|
1e7257fd53 | ||
|
|
4ff1f51b1c | ||
|
|
74507cef05 | ||
|
|
c23ab04d90 | ||
|
|
d50dde6cf6 | ||
|
|
fcb1fb39be | ||
|
|
b0ef74f802 | ||
|
|
f332aef41d | ||
|
|
1f91a3da8e | ||
|
|
16840c321d | ||
|
|
c109e392ad | ||
|
|
5e69671366 | ||
|
|
52d23d9b75 | ||
|
|
4c4e6d7a7b | ||
|
|
03b6e78705 | ||
|
|
24c01141d7 | ||
|
|
6dc2811af4 | ||
|
|
e6425dce32 | ||
|
|
95e2ff5f1e | ||
|
|
92ac487128 | ||
|
|
3250fa89cb | ||
|
|
7475de366b | ||
|
|
affb507b37 | ||
|
|
3320b80150 | ||
|
|
fb2b69b787 | ||
|
|
29a05f6533 | ||
|
|
9fa3fac973 | ||
|
|
904b0d104a | ||
|
|
1d31dae110 | ||
|
|
476ecb7423 | ||
|
|
4eb67cf6da | ||
|
|
a5a9f7ed83 | ||
|
|
c0b029e228 | ||
|
|
9bebcc9a4b | ||
|
|
ac7d23011c | ||
|
|
491e09b7b5 | ||
|
|
192bc237bf | ||
|
|
f041f4a114 | ||
|
|
2546580377 | ||
|
|
8fbf2ab56d | ||
|
|
ea727aad2e | ||
|
|
5520aecbba | ||
|
|
6b738a4769 | ||
|
|
903a8050b3 | ||
|
|
31b032429d | ||
|
|
2bcf341f04 | ||
|
|
ca6f45b359 | ||
|
|
2a67cec16b | ||
|
|
1800afe31b | ||
|
|
8c6311355d | ||
|
|
91801dff85 | ||
|
|
be594133f0 | ||
|
|
8a538d117e | ||
|
|
8d9118cbee | ||
|
|
b67464ea13 | ||
|
|
33334da0bb | ||
|
|
40ce2baa7b | ||
|
|
1134466cc0 | ||
|
|
92341111ad | ||
|
|
4956d6781f | ||
|
|
63562240c4 | ||
|
|
84d801cf14 | ||
|
|
b56fe4ca68 | ||
|
|
6c83c65e02 | ||
|
|
a83f020fcc | ||
|
|
7f9a3bf272 | ||
|
|
f80e266d02 | ||
|
|
7bef562541 | ||
|
|
b2428f607c | ||
|
|
8303196b57 | ||
|
|
987b8c8742 | ||
|
|
e60a579b85 | ||
|
|
be8edafed0 | ||
|
|
a258a18fa4 | ||
|
|
59010ca431 | ||
|
|
75f3764e6c | ||
|
|
867ffd1163 | ||
|
|
6acccbbb94 | ||
|
|
b2c4efab45 | ||
|
|
408a435b71 | ||
|
|
36d3cd93d5 | ||
|
|
b36fea002e | ||
|
|
52acbd954a | ||
|
|
f6709a55c3 | ||
|
|
7b374d747b | ||
|
|
fd480a9360 | ||
|
|
ec8b228867 | ||
|
|
401200050b | ||
|
|
29160bd6e5 | ||
|
|
3c9e402bc0 | ||
|
|
ff4d0f0208 | ||
|
|
f82908221c | ||
|
|
4246908f2e | ||
|
|
f64597afd2 | ||
|
|
975ff2672d | ||
|
|
e90ba31784 | ||
|
|
a4074c93bc | ||
|
|
7a8b7598c7 | ||
|
|
cd0d832f14 | ||
|
|
5b0becaaf2 | ||
|
|
9817bac2fe | ||
|
|
f6bd48cfcd | ||
|
|
01843b8f2b | ||
|
|
94ed81de5e | ||
|
|
0700b8f399 | ||
|
|
d62cff9841 | ||
|
|
083f4805b2 | ||
|
|
8e5bfd379e | ||
|
|
2366f143d8 | ||
|
|
e997f5bc1b | ||
|
|
842beec7cc | ||
|
|
d2268fc9e0 | ||
|
|
a98e26139f | ||
|
|
522a3ea88b | ||
|
|
d7949fbc30 | ||
|
|
6df083a1d5 | ||
|
|
4dc80e7f6e | ||
|
|
c2a8508513 | ||
|
|
159193ef43 | ||
|
|
1f37ffb105 | ||
|
|
919fed05c5 | ||
|
|
1814f83bee | ||
|
|
1823840456 | ||
|
|
623c28bfc3 | ||
|
|
3079131337 | ||
|
|
a34ade0120 | ||
|
|
e9ada70088 | ||
|
|
597cc48248 | ||
|
|
ec3f857ef1 | ||
|
|
383b4de539 | ||
|
|
1bf9326604 | ||
|
|
d9f5459d46 | ||
|
|
e45a1b1e19 | ||
|
|
331ad8f644 | ||
|
|
52fa88b04c | ||
|
|
8895a64d24 | ||
|
|
fdec535559 | ||
|
|
6c5559ae2d | ||
|
|
9f54622b17 | ||
|
|
03b6f4b378 | ||
|
|
af4cbe2332 | ||
|
|
141f72963a | ||
|
|
3d3c66e12f | ||
|
|
ee84571bdb | ||
|
|
6500936aad | ||
|
|
32d2b6c013 | ||
|
|
05df40977d | ||
|
|
5d7a1dcde5 | ||
|
|
9c45d9db6c | ||
|
|
ca692ed0f2 | ||
|
|
af499565d3 | ||
|
|
fe2d7e3a9e | ||
|
|
9f69822221 | ||
|
|
bb43f047c2 | ||
|
|
2356662492 | ||
|
|
1624a45093 | ||
|
|
dcb9983786 | ||
|
|
83d1828905 | ||
|
|
6a281cf3ee | ||
|
|
ed1cd39a6c | ||
|
|
dda19b3920 | ||
|
|
25139ca922 | ||
|
|
3cd57a582c | ||
|
|
d3903ac655 | ||
|
|
199e374318 | ||
|
|
8375c1413d | ||
|
|
9e268cf016 | ||
|
|
112b3abc26 | ||
|
|
a8331a2357 | ||
|
|
52e3ad08c1 | ||
|
|
8d01d04ef0 | ||
|
|
a141384907 | ||
|
|
b8aa7184bd | ||
|
|
e4195f874d | ||
|
|
d04deff5ca | ||
|
|
20ce0778a0 | ||
|
|
5a0b3470f1 | ||
|
|
a920921570 | ||
|
|
286f4ff384 | ||
|
|
71ddfafa98 | ||
|
|
b7e3e53697 | ||
|
|
16df548b77 | ||
|
|
425c33ae00 | ||
|
|
c9289ed2dc | ||
|
|
96517cbdef | ||
|
|
b03420faac | ||
|
|
65a1aa7ca2 | ||
|
|
3a92e8eaf9 | ||
|
|
a8dc50d64a | ||
|
|
3397cc7d8d | ||
|
|
c3e8131b24 | ||
|
|
f8ca8584ae | ||
|
|
3050bbe260 | ||
|
|
e1dda2795a | ||
|
|
6d8408e626 | ||
|
|
0906271aa9 | ||
|
|
4c33c9d256 | ||
|
|
fa9c78209f | ||
|
|
6678ec8a60 | ||
|
|
854e467c12 | ||
|
|
e6b94c7b21 | ||
|
|
2c6f9d8602 | ||
|
|
c74033b9c0 | ||
|
|
d2b21d27bb | ||
|
|
215272469f | ||
|
|
f7d05ab0f1 | ||
|
|
6f2ad2be77 | ||
|
|
66575c719a | ||
|
|
677a239d53 | ||
|
|
3b96bfe5af | ||
|
|
83be5cfa64 | ||
|
|
6b834c2362 | ||
|
|
7abfc49e08 | ||
|
|
65d5f50088 | ||
|
|
4f1f4ffe3d | ||
|
|
b0c2027a1c | ||
|
|
33c83358b0 | ||
|
|
31223f0526 | ||
|
|
92daadb92c | ||
|
|
fae2e274fd | ||
|
|
342a722991 | ||
|
|
65ec6aacb7 | ||
|
|
9387470c69 | ||
|
|
31f6edf8f0 | ||
|
|
487b062175 | ||
|
|
d8e13de096 | ||
|
|
e8a30088ef | ||
|
|
bf7b07ba74 | ||
|
|
28fe3e7b7a | ||
|
|
c0eff2bb5e | ||
|
|
848c1741fe | ||
|
|
1370b8e8c1 | ||
|
|
82a068e610 | ||
|
|
32f42bafaa | ||
|
|
4081b7f022 | ||
|
|
a5808193a6 | ||
|
|
854ca322c1 | ||
|
|
c1d9b5137a | ||
|
|
f33d5745b3 | ||
|
|
d89c2ca128 | ||
|
|
835584cc85 | ||
|
|
b2ffbe3a68 | ||
|
|
defcc79e6c | ||
|
|
c06d9f84f0 | ||
|
|
fe57a8e156 | ||
|
|
b77105795a | ||
|
|
e2df5fcf27 | ||
|
|
836a64e728 | ||
|
|
08ba0c9f42 | ||
|
|
6fcc6a5299 | ||
|
|
6dd58248c6 | ||
|
|
2786801b71 | ||
|
|
ea29cbeb7a | ||
|
|
3cf9121a8c | ||
|
|
381bd3938a | ||
|
|
e4ce384023 | ||
|
|
12d1857b13 | ||
|
|
0d9003dea4 | ||
|
|
1a3751acfa | ||
|
|
c5a3af2399 | ||
|
|
ea8a64fafc | ||
|
|
981e367bf1 | ||
|
|
a3d6e62035 | ||
|
|
7f205cdcc8 | ||
|
|
e587189880 | ||
|
|
206c1bd69f | ||
|
|
a7d9255c2c | ||
|
|
08265a85ec | ||
|
|
1ed5630464 | ||
|
|
c784615f11 | ||
|
|
26d51b1190 | ||
|
|
d83fad6abc | ||
|
|
692796db46 | ||
|
|
f15c6f33f9 | ||
|
|
dda9eb4d7c | ||
|
|
6f3aeb61e7 | ||
|
|
d6145e633f | ||
|
|
07014d98ce | ||
|
|
e8ccdabe6c | ||
|
|
cf9fd2d5c2 | ||
|
|
bf9aa9356b | ||
|
|
68d00ce289 | ||
|
|
5288021e4f | ||
|
|
4d38add291 | ||
|
|
804808da4a | ||
|
|
298a95432d | ||
|
|
a834fc4b30 | ||
|
|
2c6c9542dd | ||
|
|
a9a7f4c8ec | ||
|
|
ea9370443d | ||
|
|
c2e00b240e | ||
|
|
a2b81ea099 | ||
|
|
ee609e8eac | ||
|
|
e04ef671e9 | ||
|
|
0184dfd7eb | ||
|
|
eccfa0ca54 | ||
|
|
6d3feb4bef | ||
|
|
29d2b5ee4b | ||
|
|
c82fabb67f | ||
|
|
fcfc868e57 | ||
|
|
67b403f8ca | ||
|
|
de06c6b2f6 | ||
|
|
fa444dfb8a | ||
|
|
124002a472 | ||
|
|
0c883433c1 | ||
|
|
bcf3b2cf55 | ||
|
|
357c4e9c08 | ||
|
|
9edfc68e91 | ||
|
|
8c06cb3e80 | ||
|
|
144fa0a6d4 | ||
|
|
25d5a1541e | ||
|
|
a579d36389 | ||
|
|
d766dac341 | ||
|
|
b15ef1bbc6 | ||
|
|
3e52e00597 | ||
|
|
f749dd0d52 | ||
|
|
48a8a42108 | ||
|
|
db7f57a5a4 | ||
|
|
556381b983 | ||
|
|
158d7d5898 | ||
|
|
18844da95d | ||
|
|
7e0df4d718 | ||
|
|
0dbb76e8c8 | ||
|
|
f73b3422a6 | ||
|
|
bd95e802ec | ||
|
|
5de16a78c5 | ||
|
|
6f8e09fcde | ||
|
|
f54d480f03 | ||
|
|
e68b213fb3 | ||
|
|
132334d500 | ||
|
|
a6f04c6d7e | ||
|
|
854e8bf356 | ||
|
|
6ff883d2d3 | ||
|
|
849b97afba | ||
|
|
1bd2635864 | ||
|
|
79ab0f7b6c | ||
|
|
79011bd257 | ||
|
|
c692713ffb | ||
|
|
df9b554ce1 | ||
|
|
277a8e4682 | ||
|
|
acb52dba09 | ||
|
|
8f10765254 | ||
|
|
0653f59473 | ||
|
|
7a4b5a4667 | ||
|
|
49c4a4068b | ||
|
|
40ad590046 | ||
|
|
30374ae3e6 | ||
|
|
ab22d16bad | ||
|
|
971cd56a4a | ||
|
|
d7cb546c5f | ||
|
|
9d8b7344cd | ||
|
|
2d4f6ae7ce | ||
|
|
d9126807b0 | ||
|
|
cad5fb3fba | ||
|
|
afe23ad6b7 | ||
|
|
fc4327087b | ||
|
|
71762d788f | ||
|
|
6472e00fb0 | ||
|
|
4043846767 | ||
|
|
d3b2bc962c | ||
|
|
54f7b64821 | ||
|
|
82a2a6e669 | ||
|
|
6376d60af5 | ||
|
|
b1e2e3831f | ||
|
|
5de1c8aa82 | ||
|
|
63dc5c2bdb | ||
|
|
7f2d1670a0 | ||
|
|
53c8c337fc | ||
|
|
5b4ec1b2a2 | ||
|
|
64dd2ed141 | ||
|
|
eb57e04e95 | ||
|
|
ae905c8630 | ||
|
|
c157e794f0 | ||
|
|
ed9bae6f6a | ||
|
|
9fe1ce19ad | ||
|
|
6148236cbd | ||
|
|
2471eb518a | ||
|
|
8931b41c76 | ||
|
|
7f523f167d | ||
|
|
446b6d6158 | ||
|
|
2ee057e19b | ||
|
|
afc810f21f | ||
|
|
357052a903 | ||
|
|
39d6d8d04a | ||
|
|
888896c0c0 | ||
|
|
ceee482ecc | ||
|
|
d0ed1213d8 | ||
|
|
f6ef428008 | ||
|
|
e726c4f442 | ||
|
|
402318e586 | ||
|
|
b198cc2a6e | ||
|
|
c3dd4da11b | ||
|
|
ba2e42b06e | ||
|
|
fa0902dc74 | ||
|
|
8fcb6083dc | ||
|
|
1ef88140e3 | ||
|
|
aa34c4c84c | ||
|
|
32d12bb334 | ||
|
|
1b2a02cb1a | ||
|
|
2ff11a16c4 | ||
|
|
441af82dbd | ||
|
|
e09c09af6f | ||
|
|
3721fe226f | ||
|
|
8ace0e11cf | ||
|
|
5e249b0b59 | ||
|
|
4889955ecf | ||
|
|
d840fd53da | ||
|
|
a61819cdb3 | ||
|
|
e986fbb5fb | ||
|
|
8f4d575ec8 | ||
|
|
605a06317b | ||
|
|
a7304ccf47 | ||
|
|
374e2bd4b9 | ||
|
|
09a3246ddb | ||
|
|
a615603866 | ||
|
|
1ca05808e1 | ||
|
|
5febc2a805 | ||
|
|
3c047bee58 | ||
|
|
022c6c157a | ||
|
|
fa587d5678 | ||
|
|
afa5a42f5a | ||
|
|
71df8ba3e2 | ||
|
|
8764998e8c | ||
|
|
2cb4f3aac8 | ||
|
|
1ccaf33aac | ||
|
|
cb0a8e0413 | ||
|
|
8674168df4 | ||
|
|
2221653801 | ||
|
|
78bcdcef5d | ||
|
|
672fbe2ac0 | ||
|
|
56a5970b44 | ||
|
|
a66cef7cfe | ||
|
|
c0b1c2e099 | ||
|
|
9e553bb87b | ||
|
|
f966514bc7 | ||
|
|
dc0a49f96d | ||
|
|
65c783c024 | ||
|
|
6395836fbb | ||
|
|
a7207084ef | ||
|
|
27ef1f1e71 | ||
|
|
68fdb14cd6 | ||
|
|
c2af282a85 | ||
|
|
92d48335cb | ||
|
|
78cac2edc2 | ||
|
|
26d105c439 | ||
|
|
7fec107b98 | ||
|
|
eb01ad3af9 | ||
|
|
e0d9880b32 | ||
|
|
e81e96f0ab | ||
|
|
06d5bd259c | ||
|
|
14238b8d62 | ||
|
|
3b51886927 | ||
|
|
a295ff2e06 |
1
.github/FUNDING.yml
vendored
1
.github/FUNDING.yml
vendored
@@ -1,4 +1,5 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
patreon: PixelPawsAI
|
||||
ko_fi: pixelpawsai
|
||||
custom: ['paypal.me/pixelpawsai']
|
||||
|
||||
1
.github/copilot-instructions.md
vendored
Normal file
1
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1 @@
|
||||
Always use English for comments.
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,6 +1,10 @@
|
||||
__pycache__/
|
||||
settings.json
|
||||
path_mappings.yaml
|
||||
output/*
|
||||
py/run_test.py
|
||||
.vscode/
|
||||
cache/
|
||||
civitai/
|
||||
node_modules/
|
||||
coverage/
|
||||
|
||||
19
AGENTS.md
Normal file
19
AGENTS.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# Repository Guidelines
|
||||
|
||||
## Project Structure & Module Organization
|
||||
ComfyUI LoRA Manager pairs a Python backend with lightweight browser scripts. Backend modules live in `py/`, organized by responsibility: HTTP entry points under `routes/`, feature logic in `services/`, reusable helpers within `utils/`, and custom nodes in `nodes/`. Front-end widgets that extend the ComfyUI interface sit in `web/comfyui/`, while static images and templates are in `static/` and `templates/`. Shared localization files are stored in `locales/`, with workflow examples under `example_workflows/`. Tests currently reside alongside the source (`test_i18n.py`) until a dedicated `tests/` folder is introduced.
|
||||
|
||||
## Build, Test, and Development Commands
|
||||
Install dependencies with `pip install -r requirements.txt` from the repo root. Launch the standalone server for iterative work via `python standalone.py --port 8188`; ComfyUI users can also load the extension directly through ComfyUI's custom node manager. Run backend checks with `python -m pytest test_i18n.py`, and target new test files explicitly (e.g. `python -m pytest tests/test_recipes.py` once added). Use `python scripts/sync_translation_keys.py` to reconcile locale keys after updating UI strings.
|
||||
|
||||
## Coding Style & Naming Conventions
|
||||
Follow PEP 8 with four-space indentation and descriptive snake_case module/function names, mirroring files such as `py/services/settings_manager.py`. Classes remain PascalCase, constants UPPER_SNAKE_CASE, and loggers retrieved via `logging.getLogger(__name__)`. Prefer explicit type hints for new public APIs and docstrings that clarify side effects. JavaScript in `web/comfyui/` is modern ES modules; keep imports relative, favor camelCase functions, and mirror existing file suffixes like `_widget.js` for UI components.
|
||||
|
||||
## Testing Guidelines
|
||||
Extend pytest coverage by co-locating tests near the code under test or in `tests/` with names like `test_<feature>.py`. When introducing new routes or services, add regression cases that mock ComfyUI dependencies (see the standalone mocking helpers in `standalone.py`). Prioritize deterministic fixtures for filesystem interactions and ensure translations include coverage when adding new locale keys. Always run `python -m pytest` before submitting work.
|
||||
|
||||
## Commit & Pull Request Guidelines
|
||||
Commits follow the conventional pattern seen in `git log` (`feat(scope):`, `fix(scope):`, `chore(scope):`). Keep messages imperative and scoped to a single change. Pull requests should summarize the problem, detail the solution, list manual test evidence, and link any GitHub issues. Include UI screenshots or GIFs when front-end behavior changes, and call out migration steps (e.g., settings updates) in the PR description.
|
||||
|
||||
## Configuration & Localization Tips
|
||||
Sample configuration defaults live in `settings.json.example`; copy it to `settings.json` and adjust model directories before running the standalone server. Whenever you add UI text, update `locales/<lang>.json` and run the translation sync script. Store reference assets in `civitai/` or `docs/` rather than mixing them with production templates, keeping the runtime folders (`static/`, `templates/`) deploy-ready.
|
||||
121
README.md
121
README.md
@@ -10,71 +10,77 @@ A comprehensive toolset that streamlines organizing, downloading, and applying L
|
||||
|
||||

|
||||
|
||||
One-click Integration:
|
||||

|
||||
|
||||
## 📺 Tutorial: One-Click LoRA Integration
|
||||
Watch this quick tutorial to learn how to use the new one-click LoRA integration feature:
|
||||
|
||||
[](https://youtu.be/hvKw31YpE-U)
|
||||
[](https://youtu.be/hvKw31YpE-U)
|
||||
|
||||
## 🌐 Browser Extension
|
||||
Enhance your Civitai browsing experience with our companion browser extension! See which models you already have, download new ones with a single click, and manage your downloads efficiently.
|
||||
|
||||

|
||||
|
||||
<div>
|
||||
<a href="https://chromewebstore.google.com/detail/lm-civitai-extension/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb" style="display: inline-block; background-color: #4285F4; color: white; padding: 8px 16px; text-decoration: none; border-radius: 4px; font-weight: bold; margin: 10px 0;">
|
||||
<img src="https://www.google.com/chrome/static/images/chrome-logo.svg" width="20" style="vertical-align: middle; margin-right: 8px;"> Get Extension from Chrome Web Store
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div id="firefox-install" class="install-ok"><a href="https://github.com/willmiao/lm-civitai-extension-firefox/releases/latest/download/extension.xpi">📦 Install Firefox Extension (reviewed and verified by Mozilla)</a></div>
|
||||
|
||||
📚 [Learn More: Complete Tutorial](https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/LoRA-Manager-Civitai-Extension-(Chrome-Extension))
|
||||
|
||||
---
|
||||
|
||||
## Release Notes
|
||||
|
||||
### v0.8.17
|
||||
* **Duplicate Model Detection** - Added "Find Duplicates" functionality for LoRAs and checkpoints using model file hash detection, enabling convenient viewing and batch deletion of duplicate models
|
||||
* **Enhanced URL Recipe Imports** - Optimized import recipe via URL functionality using CivitAI API calls instead of web scraping, now supporting all rated images (including NSFW) for recipe imports
|
||||
* **Improved TriggerWord Control** - Enhanced TriggerWord Toggle node with new default_active switch to set the initial state (active/inactive) when trigger words are added
|
||||
* **Centralized Example Management** - Added "Migrate Existing Example Images" feature to consolidate downloaded example images from model folders into central storage with customizable naming patterns
|
||||
* **Intelligent Word Suggestions** - Implemented smart trigger word suggestions by reading class tokens and tag frequency from safetensors files, displaying recommendations when editing trigger words
|
||||
* **Model Version Management** - Added "Re-link to CivitAI" context menu option for connecting models to different CivitAI versions when needed
|
||||
### v0.9.3
|
||||
* **Metadata Archive Database Support** - Added the ability to download and utilize a metadata archive database, enabling access to metadata for models that have been deleted from CivitAI.
|
||||
* **App-Level Proxy Settings** - Introduced support for configuring a global proxy within the application, making it easier to use the manager behind network restrictions.
|
||||
* **Bug Fixes** - Various bug fixes for improved stability and reliability.
|
||||
|
||||
### v0.8.16
|
||||
* **Dramatic Startup Speed Improvement** - Added cache serialization mechanism for significantly faster loading times, especially beneficial for large model collections
|
||||
* **Enhanced Refresh Options** - Extended functionality with "Full Rebuild (complete)" option alongside "Quick Refresh (incremental)" to fix potential memory cache issues without requiring application restart
|
||||
* **Customizable Display Density** - Replaced compact mode with adjustable display density settings for personalized layout customization
|
||||
* **Model Creator Information** - Added creator details to model information panels for better attribution
|
||||
* **Improved WebP Support** - Enhanced Save Image node with workflow embedding capability for WebP format images
|
||||
* **Direct Example Access** - Added "Open Example Images Folder" button to card interfaces for convenient browsing of downloaded model examples
|
||||
* **Enhanced Compatibility** - Full ComfyUI Desktop support for "Send lora or recipe to workflow" functionality
|
||||
* **Cache Management** - Added settings to clear existing cache files when needed
|
||||
* **Bug Fixes & Stability** - Various improvements for overall reliability and performance
|
||||
### v0.9.2
|
||||
* **Bulk Auto-Organization Action** - Added a new bulk auto-organization feature. You can now select multiple models and automatically organize them according to your current path template settings for streamlined management.
|
||||
* **Bug Fixes** - Addressed several bugs to improve stability and reliability.
|
||||
|
||||
### v0.8.15
|
||||
* **Enhanced One-Click Integration** - Replaced copy button with direct send button allowing LoRAs/recipes to be sent directly to your current ComfyUI workflow without needing to paste
|
||||
* **Flexible Workflow Integration** - Click to append LoRAs/recipes to existing loader nodes or Shift+click to replace content, with additional right-click menu options for "Send to Workflow (Append)" or "Send to Workflow (Replace)"
|
||||
* **Improved LoRA Loader Controls** - Added header drag functionality for proportional strength adjustment of all LoRAs simultaneously (including CLIP strengths when expanded)
|
||||
* **Keyboard Navigation Support** - Implemented Page Up/Down for page scrolling, Home key to jump to top, and End key to jump to bottom for faster browsing through large collections
|
||||
### v0.9.1
|
||||
* **Enhanced Bulk Operations** - Improved bulk operations with Marquee Selection and a bulk operation context menu, providing a more intuitive, desktop-application-like user experience.
|
||||
* **New Bulk Actions** - Added bulk operations for adding tags and setting base models to multiple models simultaneously.
|
||||
|
||||
### v0.8.14
|
||||
* **Virtualized Scrolling** - Completely rebuilt rendering mechanism for smooth browsing with no lag or freezing, now supporting virtually unlimited model collections with optimized layouts for large displays, improving space utilization and user experience
|
||||
* **Compact Display Mode** - Added space-efficient view option that displays more cards per row (7 on 1080p, 8 on 2K, 10 on 4K)
|
||||
* **Enhanced LoRA Node Functionality** - Comprehensive improvements to LoRA loader/stacker nodes including real-time trigger word updates (reflecting any change anywhere in the LoRA chain for precise updates) and expanded context menu with "Copy Notes" and "Copy Trigger Words" options for faster workflow
|
||||
### v0.9.0
|
||||
* **UI Overhaul for Enhanced Navigation** - Replaced the top flat folder tags with a new folder sidebar and breadcrumb navigation system for a more intuitive folder browsing and selection experience.
|
||||
* **Dual-Mode Folder Sidebar** - The new folder sidebar offers two display modes: 'List Mode,' which mirrors the classic folder view, and 'Tree Mode,' which presents a hierarchical folder structure for effortless navigation through nested directories.
|
||||
* **Internationalization Support** - Introduced multi-language support, now available in English, Simplified Chinese, Traditional Chinese, Spanish, Japanese, Korean, French, Russian, and German. Feedback from native speakers is welcome to improve the translations.
|
||||
* **Automatic Filename Conflict Resolution** - Implemented automatic file renaming (`original name + short hash`) to prevent conflicts when downloading or moving models.
|
||||
* **Performance Optimizations & Bug Fixes** - Various performance improvements and bug fixes for a more stable and responsive experience.
|
||||
|
||||
### v0.8.13
|
||||
* **Enhanced Recipe Management** - Added "Find duplicates" feature to identify and batch delete duplicate recipes with duplicate detection notifications during imports
|
||||
* **Improved Source Tracking** - Source URLs are now saved with recipes imported via URL, allowing users to view original content with one click or manually edit links
|
||||
* **Advanced LoRA Control** - Double-click LoRAs in Loader/Stacker nodes to access expanded CLIP strength controls for more precise adjustments of model and CLIP strength separately
|
||||
* **Lycoris Model Support** - Added compatibility with Lycoris models for expanded creative options
|
||||
* **Bug Fixes & UX Improvements** - Resolved various issues and enhanced overall user experience with numerous optimizations
|
||||
### v0.8.30
|
||||
* **Automatic Model Path Correction** - Added auto-correction for model paths in built-in nodes such as Load Checkpoint, Load Diffusion Model, Load LoRA, and other custom nodes with similar functionality. Workflows containing outdated or incorrect model paths will now be automatically updated to reflect the current location of your models.
|
||||
* **Node UI Enhancements** - Improved node interface for a smoother and more intuitive user experience.
|
||||
* **Bug Fixes** - Addressed various bugs to enhance stability and reliability.
|
||||
|
||||
### v0.8.12
|
||||
* **Enhanced Model Discovery** - Added alphabetical navigation bar to LoRAs page for faster browsing through large collections
|
||||
* **Optimized Example Images** - Improved download logic to automatically refresh stale metadata before fetching example images
|
||||
* **Model Exclusion System** - New right-click option to exclude specific LoRAs or checkpoints from management
|
||||
* **Improved Showcase Experience** - Enhanced interaction in LoRA and checkpoint showcase areas for better usability
|
||||
### v0.8.29
|
||||
* **Enhanced Recipe Imports** - Improved recipe importing with new target folder selection, featuring path input autocomplete and interactive folder tree navigation. Added a "Use Default Path" option when downloading missing LoRAs.
|
||||
* **WanVideo Lora Select Node Update** - Updated the WanVideo Lora Select node with a 'merge_loras' option to match the counterpart node in the WanVideoWrapper node package.
|
||||
* **Autocomplete Conflict Resolution** - Resolved an autocomplete feature conflict in LoRA nodes with pysssss autocomplete.
|
||||
* **Improved Download Functionality** - Enhanced download functionality with resumable downloads and improved error handling.
|
||||
* **Bug Fixes** - Addressed several bugs for improved stability and performance.
|
||||
|
||||
### v0.8.11
|
||||
* **Offline Image Support** - Added functionality to download and save all model example images locally, ensuring access even when offline or if images are removed from CivitAI or the site is down
|
||||
* **Resilient Download System** - Implemented pause/resume capability with checkpoint recovery that persists through restarts or unexpected exits
|
||||
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
|
||||
### v0.8.28
|
||||
* **Autocomplete for Node Inputs** - Instantly find and add LoRAs by filename directly in Lora Loader, Lora Stacker, and WanVideo Lora Select nodes. Autocomplete suggestions include preview tooltips and preset weights, allowing you to quickly select LoRAs without opening the LoRA Manager UI.
|
||||
* **Duplicate Notification Control** - Added a switch to duplicates mode, enabling users to turn off duplicate model notifications for a more streamlined experience.
|
||||
* **Download Example Images from Context Menu** - Introduced a new context menu option to download example images for individual models.
|
||||
|
||||
### v0.8.10
|
||||
* **Standalone Mode** - Run LoRA Manager independently from ComfyUI for a lightweight experience that works even with other stable diffusion interfaces
|
||||
* **Portable Edition** - New one-click portable version for easy startup and updates in standalone mode
|
||||
* **Enhanced Metadata Collection** - Added support for SamplerCustomAdvanced node in the metadata collector module
|
||||
* **Improved UI Organization** - Optimized Lora Loader node height to display up to 5 LoRAs at once with scrolling capability for larger collections
|
||||
### v0.8.27
|
||||
* **User Experience Enhancements** - Improved the model download target folder selection with path input autocomplete and interactive folder tree navigation, making it easier and faster to choose where models are saved.
|
||||
* **Default Path Option for Downloads** - Added a "Use Default Path" option when downloading models. When enabled, models are automatically organized and stored according to your configured path template settings.
|
||||
* **Advanced Download Path Templates** - Expanded path template settings, allowing users to set individual templates for LoRA, checkpoint, and embedding models for greater flexibility. Introduced the `{author}` placeholder, enabling automatic organization of model files by creator name.
|
||||
* **Bug Fixes & Stability Improvements** - Addressed various bugs and improved overall stability for a smoother experience.
|
||||
|
||||
### v0.8.26
|
||||
* **Creator Search Option** - Added ability to search models by creator name, making it easier to find models from specific authors.
|
||||
* **Enhanced Node Usability** - Improved user experience for Lora Loader, Lora Stacker, and WanVideo Lora Select nodes by fixing the maximum height of the text input area. Users can now freely and conveniently adjust the LoRA region within these nodes.
|
||||
* **Compatibility Fixes** - Resolved compatibility issues with ComfyUI and certain custom nodes, including ComfyUI-Custom-Scripts, ensuring smoother integration and operation.
|
||||
|
||||
[View Update History](./update_logs.md)
|
||||
|
||||
@@ -93,13 +99,6 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
|
||||
- 🚀 **High Performance**
|
||||
- Fast model loading and browsing
|
||||
- Smooth scrolling through large collections
|
||||
- Real-time updates when files change
|
||||
|
||||
- 📂 **Advanced Organization**
|
||||
- Quick search with fuzzy matching
|
||||
- Folder-based categorization
|
||||
- Move LoRAs between folders
|
||||
- Sort by name or date
|
||||
|
||||
- 🌐 **Rich Model Integration**
|
||||
- Direct download from CivitAI
|
||||
@@ -140,10 +139,11 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
|
||||
|
||||
### Option 2: **Portable Standalone Edition** (No ComfyUI required)
|
||||
|
||||
1. Download the [Portable Package](https://github.com/willmiao/ComfyUI-Lora-Manager/releases/download/v0.8.15/lora_manager_portable.7z)
|
||||
1. Download the [Portable Package](https://github.com/willmiao/ComfyUI-Lora-Manager/releases/download/v0.9.2/lora_manager_portable.7z)
|
||||
2. Copy the provided `settings.json.example` file to create a new file named `settings.json` in `comfyui-lora-manager` folder
|
||||
3. Edit `settings.json` to include your correct model folder paths and CivitAI API key
|
||||
4. Run run.bat
|
||||
- To change the startup port, edit `run.bat` and modify the parameter (e.g. `--port 9001`)
|
||||
|
||||
### Option 3: **Manual Installation**
|
||||
|
||||
@@ -263,6 +263,8 @@ If you find this project helpful, consider supporting its development:
|
||||
|
||||
[](https://ko-fi.com/pixelpawsai)
|
||||
|
||||
[](https://patreon.com/PixelPawsAI)
|
||||
|
||||
WeChat: [Click to view QR code](https://raw.githubusercontent.com/willmiao/ComfyUI-Lora-Manager/main/static/images/wechat-qr.webp)
|
||||
|
||||
## 💬 Community
|
||||
@@ -271,3 +273,6 @@ Join our Discord community for support, discussions, and updates:
|
||||
[Discord Server](https://discord.gg/vcqNrWVFvM)
|
||||
|
||||
---
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#willmiao/ComfyUI-Lora-Manager&Date)
|
||||
|
||||
42
__init__.py
42
__init__.py
@@ -1,18 +1,42 @@
|
||||
from .py.lora_manager import LoraManager
|
||||
from .py.nodes.lora_loader import LoraManagerLoader
|
||||
from .py.nodes.trigger_word_toggle import TriggerWordToggle
|
||||
from .py.nodes.lora_stacker import LoraStacker
|
||||
from .py.nodes.save_image import SaveImage
|
||||
from .py.nodes.debug_metadata import DebugMetadata
|
||||
# Import metadata collector to install hooks on startup
|
||||
from .py.metadata_collector import init as init_metadata_collector
|
||||
try: # pragma: no cover - import fallback for pytest collection
|
||||
from .py.lora_manager import LoraManager
|
||||
from .py.nodes.lora_loader import LoraManagerLoader, LoraManagerTextLoader
|
||||
from .py.nodes.trigger_word_toggle import TriggerWordToggle
|
||||
from .py.nodes.lora_stacker import LoraStacker
|
||||
from .py.nodes.save_image import SaveImage
|
||||
from .py.nodes.debug_metadata import DebugMetadata
|
||||
from .py.nodes.wanvideo_lora_select import WanVideoLoraSelect
|
||||
from .py.nodes.wanvideo_lora_select_from_text import WanVideoLoraSelectFromText
|
||||
from .py.metadata_collector import init as init_metadata_collector
|
||||
except ImportError: # pragma: no cover - allows running under pytest without package install
|
||||
import importlib
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
package_root = pathlib.Path(__file__).resolve().parent
|
||||
if str(package_root) not in sys.path:
|
||||
sys.path.append(str(package_root))
|
||||
|
||||
LoraManager = importlib.import_module("py.lora_manager").LoraManager
|
||||
LoraManagerLoader = importlib.import_module("py.nodes.lora_loader").LoraManagerLoader
|
||||
LoraManagerTextLoader = importlib.import_module("py.nodes.lora_loader").LoraManagerTextLoader
|
||||
TriggerWordToggle = importlib.import_module("py.nodes.trigger_word_toggle").TriggerWordToggle
|
||||
LoraStacker = importlib.import_module("py.nodes.lora_stacker").LoraStacker
|
||||
SaveImage = importlib.import_module("py.nodes.save_image").SaveImage
|
||||
DebugMetadata = importlib.import_module("py.nodes.debug_metadata").DebugMetadata
|
||||
WanVideoLoraSelect = importlib.import_module("py.nodes.wanvideo_lora_select").WanVideoLoraSelect
|
||||
WanVideoLoraSelectFromText = importlib.import_module("py.nodes.wanvideo_lora_select_from_text").WanVideoLoraSelectFromText
|
||||
init_metadata_collector = importlib.import_module("py.metadata_collector").init
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
LoraManagerLoader.NAME: LoraManagerLoader,
|
||||
LoraManagerTextLoader.NAME: LoraManagerTextLoader,
|
||||
TriggerWordToggle.NAME: TriggerWordToggle,
|
||||
LoraStacker.NAME: LoraStacker,
|
||||
SaveImage.NAME: SaveImage,
|
||||
DebugMetadata.NAME: DebugMetadata
|
||||
DebugMetadata.NAME: DebugMetadata,
|
||||
WanVideoLoraSelect.NAME: WanVideoLoraSelect,
|
||||
WanVideoLoraSelectFromText.NAME: WanVideoLoraSelectFromText
|
||||
}
|
||||
|
||||
WEB_DIRECTORY = "./web/comfyui"
|
||||
|
||||
180
docs/LM-Extension-Wiki.md
Normal file
180
docs/LM-Extension-Wiki.md
Normal file
@@ -0,0 +1,180 @@
|
||||
## Overview
|
||||
|
||||
The **LoRA Manager Civitai Extension** is a Browser extension designed to work seamlessly with [LoRA Manager](https://github.com/willmiao/ComfyUI-Lora-Manager) to significantly enhance your browsing experience on [Civitai](https://civitai.com).
|
||||
It also supports browsing on [CivArchive](https://civarchive.com/) (formerly CivitaiArchive).
|
||||
|
||||
With this extension, you can:
|
||||
|
||||
✅ Instantly see which models are already present in your local library
|
||||
✅ Download new models with a single click
|
||||
✅ Manage downloads efficiently with queue and parallel download support
|
||||
✅ Keep your downloaded models automatically organized according to your custom settings
|
||||
|
||||

|
||||

|
||||
|
||||
---
|
||||
|
||||
## Why Are All Features for Supporters Only?
|
||||
|
||||
I love building tools for the Stable Diffusion and ComfyUI communities, and LoRA Manager is a passion project that I've poured countless hours into. When I created this companion extension, my hope was to offer its core features for free, as a thank-you to all of you.
|
||||
|
||||
Unfortunately, I've reached a point where I need to be realistic. The level of support from the free model has been far lower than what's needed to justify the continuous development and maintenance for both projects. It was a difficult decision, but I've chosen to make the extension's features exclusive to supporters.
|
||||
|
||||
This change is crucial for me to be able to continue dedicating my time to improving the free and open-source LoRA Manager, which I'm committed to keeping available for everyone.
|
||||
|
||||
Your support does more than just unlock a few features—it allows me to keep innovating and ensures the core LoRA Manager project thrives. I'm incredibly grateful for your understanding and any support you can offer. ❤️
|
||||
|
||||
(_For those who previously supported me on Ko-fi with a one-time donation, I'll be sending out license keys individually as a thank-you._)
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Installation
|
||||
|
||||
### Supported Browsers & Installation Methods
|
||||
|
||||
| Browser | Installation Method |
|
||||
|--------------------|-------------------------------------------------------------------------------------|
|
||||
| **Google Chrome** | [Chrome Web Store link](https://chromewebstore.google.com/detail/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb) |
|
||||
| **Microsoft Edge** | Install via Chrome Web Store (compatible) |
|
||||
| **Brave Browser** | Install via Chrome Web Store (compatible) |
|
||||
| **Opera** | Install via Chrome Web Store (compatible) |
|
||||
| **Firefox** | <div id="firefox-install" class="install-ok"><a href="https://github.com/willmiao/lm-civitai-extension-firefox/releases/latest/download/extension.xpi">📦 Install Firefox Extension (reviewed and verified by Mozilla)</a></div> |
|
||||
|
||||
For non-Chrome browsers (e.g., Microsoft Edge), you can typically install extensions from the Chrome Web Store by following these steps: open the extension’s Chrome Web Store page, click 'Get extension', then click 'Allow' when prompted to enable installations from other stores, and finally click 'Add extension' to complete the installation.
|
||||
|
||||
---
|
||||
|
||||
## Privacy & Security
|
||||
|
||||
I understand concerns around browser extensions and privacy, and I want to be fully transparent about how the **LM Civitai Extension** works:
|
||||
|
||||
- **Reviewed and Verified**
|
||||
This extension has been **manually reviewed and approved by the Chrome Web Store**. The Firefox version uses the **exact same code** (only the packaging format differs) and has passed **Mozilla’s Add-on review**.
|
||||
|
||||
- **Minimal Network Access**
|
||||
The only external server this extension connects to is:
|
||||
**`https://willmiao.shop`** — used solely for **license validation**.
|
||||
|
||||
It does **not collect, transmit, or store any personal or usage data**.
|
||||
No browsing history, no user IDs, no analytics, no hidden trackers.
|
||||
|
||||
- **Local-Only Model Detection**
|
||||
Model detection and LoRA Manager communication all happen **locally** within your browser, directly interacting with your local LoRA Manager backend.
|
||||
|
||||
I value your trust and are committed to keeping your local setup private and secure. If you have any questions, feel free to reach out!
|
||||
|
||||
---
|
||||
|
||||
## How to Use
|
||||
|
||||
After installing the extension, you'll automatically receive a **7-day trial** to explore all features.
|
||||
|
||||
When the extension is correctly installed and your license is valid:
|
||||
|
||||
- Open **Civitai**, and you'll see visual indicators added by the extension on model cards, showing:
|
||||
- ✅ Models already present in your local library
|
||||
- ⬇️ A download button for models not in your library
|
||||
|
||||
Clicking the download button adds the corresponding model version to the download queue, waiting to be downloaded. You can set up to **5 models to download simultaneously**.
|
||||
|
||||
### Visual Indicators Appear On:
|
||||
|
||||
- **Home Page** — Featured models
|
||||
- **Models Page**
|
||||
- **Creator Profiles** — If the creator has set their models to be visible
|
||||
- **Recommended Resources** — On individual model pages
|
||||
|
||||
### Version Buttons on Model Pages
|
||||
|
||||
On a specific model page, visual indicators also appear on version buttons, showing which versions are already in your local library.
|
||||
|
||||
When switching to a specific version by clicking a version button:
|
||||
|
||||
- Clicking the download button will open a dropdown:
|
||||
- Download via **LoRA Manager**
|
||||
- Download via **Original Download** (browser download)
|
||||
|
||||
You can check **Remember my choice** to set your preferred default. You can change this setting anytime in the extension's settings.
|
||||
|
||||

|
||||
|
||||
### Resources on Image Pages (2025-08-05) — now shows in-library indicators for image resources. ‘Import image as recipe’ coming soon!
|
||||
|
||||

|
||||
|
||||
---
|
||||
|
||||
## Model Download Location & LoRA Manager Settings
|
||||
|
||||
To use the **one-click download function**, you must first set:
|
||||
|
||||
- Your **Default LoRAs Root**
|
||||
- Your **Default Checkpoints Root**
|
||||
|
||||
These are set within LoRA Manager's settings.
|
||||
|
||||
When everything is configured, downloaded model files will be placed in:
|
||||
|
||||
`<Default_Models_Root>/<Base_Model_of_the_Model>/<First_Tag_of_the_Model>`
|
||||
|
||||
|
||||
### Update: Default Path Customization (2025-07-21)
|
||||
|
||||
A new setting to customize the default download path has been added in the nightly version. You can now personalize where models are saved when downloading via the LM Civitai Extension.
|
||||
|
||||

|
||||
|
||||
The previous YAML path mapping file will be deprecated—settings will now be unified in settings.json to simplify configuration.
|
||||
|
||||
---
|
||||
|
||||
## Backend Port Configuration
|
||||
|
||||
If your **ComfyUI** or **LoRA Manager** backend is running on a port **other than the default 8188**, you must configure the backend port in the extension's settings.
|
||||
|
||||
After correctly setting and saving the port, you'll see in the extension's header area:
|
||||
- A **Healthy** status with the tooltip: `Connected to LoRA Manager on port xxxx`
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Connecting to a Remote LoRA Manager
|
||||
|
||||
If your LoRA Manager is running on another computer, you can still connect from your browser using port forwarding.
|
||||
|
||||
> **Why can't you set a remote IP directly?**
|
||||
>
|
||||
> For privacy and security, the extension only requests access to `http://127.0.0.1/*`. Supporting remote IPs would require much broader permissions, which may be rejected by browser stores and could raise user concerns.
|
||||
|
||||
**Solution: Port Forwarding with `socat`**
|
||||
|
||||
On your browser computer, run:
|
||||
|
||||
`socat TCP-LISTEN:8188,bind=127.0.0.1,fork TCP:REMOTE.IP.ADDRESS.HERE:8188`
|
||||
|
||||
- Replace `REMOTE.IP.ADDRESS.HERE` with the IP of the machine running LoRA Manager.
|
||||
- Adjust the port if needed.
|
||||
|
||||
This lets the extension connect to `127.0.0.1:8188` as usual, with traffic forwarded to your remote server.
|
||||
|
||||
_Thanks to user **Temikus** for sharing this solution!_
|
||||
|
||||
---
|
||||
|
||||
## Roadmap
|
||||
|
||||
The extension will evolve alongside **LoRA Manager** improvements. Planned features include:
|
||||
|
||||
- [x] Support for **additional model types** (e.g., embeddings)
|
||||
- [ ] One-click **Recipe Import**
|
||||
- [x] Display of in-library status for all resources in the **Resources Used** section of the image page
|
||||
- [x] One-click **Auto-organize Models**
|
||||
|
||||
**Stay tuned — and thank you for your support!**
|
||||
|
||||
---
|
||||
|
||||
93
docs/architecture/example_images_routes.md
Normal file
93
docs/architecture/example_images_routes.md
Normal file
@@ -0,0 +1,93 @@
|
||||
# Example image route architecture
|
||||
|
||||
The example image routing stack mirrors the layered model route stack described in
|
||||
[`docs/architecture/model_routes.md`](model_routes.md). HTTP wiring, controller setup,
|
||||
handler orchestration, and long-running workflows now live in clearly separated modules so
|
||||
we can extend download/import behaviour without touching the entire feature surface.
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph HTTP
|
||||
A[ExampleImagesRouteRegistrar] -->|binds| B[ExampleImagesRoutes controller]
|
||||
end
|
||||
subgraph Application
|
||||
B --> C[ExampleImagesHandlerSet]
|
||||
C --> D1[Handlers]
|
||||
D1 --> E1[Use cases]
|
||||
E1 --> F1[Download manager / processor / file manager]
|
||||
end
|
||||
subgraph Side Effects
|
||||
F1 --> G1[Filesystem]
|
||||
F1 --> G2[Model metadata]
|
||||
F1 --> G3[WebSocket progress]
|
||||
end
|
||||
```
|
||||
|
||||
## Layer responsibilities
|
||||
|
||||
| Layer | Module(s) | Responsibility |
|
||||
| --- | --- | --- |
|
||||
| Registrar | `py/routes/example_images_route_registrar.py` | Declarative catalogue of every example image endpoint plus helpers that bind them to an `aiohttp` router. Keeps HTTP concerns symmetrical with the model registrar. |
|
||||
| Controller | `py/routes/example_images_routes.py` | Lazily constructs `ExampleImagesHandlerSet`, injects defaults for the download manager, processor, and file manager, and exposes the registrar-ready mapping just like `BaseModelRoutes`. |
|
||||
| Handler set | `py/routes/handlers/example_images_handlers.py` | Groups HTTP adapters by concern (downloads, imports/deletes, filesystem access). Each handler translates domain errors into HTTP responses and defers to a use case or utility service. |
|
||||
| Use cases | `py/services/use_cases/example_images/*.py` | Encapsulate orchestration for downloads and imports. They validate input, translate concurrency/configuration errors, and keep handler logic declarative. |
|
||||
| Supporting services | `py/utils/example_images_download_manager.py`, `py/utils/example_images_processor.py`, `py/utils/example_images_file_manager.py` | Execute long-running work: pull assets from Civitai, persist uploads, clean metadata, expose filesystem actions with guardrails, and broadcast progress snapshots. |
|
||||
|
||||
## Handler responsibilities & invariants
|
||||
|
||||
`ExampleImagesHandlerSet` flattens the handler objects into the `{"handler_name": coroutine}`
|
||||
mapping consumed by the registrar. The table below outlines how each handler collaborates
|
||||
with the use cases and utilities.
|
||||
|
||||
| Handler | Key endpoints | Collaborators | Contracts |
|
||||
| --- | --- | --- | --- |
|
||||
| `ExampleImagesDownloadHandler` | `/api/lm/download-example-images`, `/api/lm/example-images-status`, `/api/lm/pause-example-images`, `/api/lm/resume-example-images`, `/api/lm/force-download-example-images` | `DownloadExampleImagesUseCase`, `DownloadManager` | Delegates payload validation and concurrency checks to the use case; progress/status endpoints expose the same snapshot used for WebSocket broadcasts; pause/resume surface `DownloadNotRunningError` as HTTP 400 instead of 500. |
|
||||
| `ExampleImagesManagementHandler` | `/api/lm/import-example-images`, `/api/lm/delete-example-image` | `ImportExampleImagesUseCase`, `ExampleImagesProcessor` | Multipart uploads are streamed to disk via the use case; validation failures return HTTP 400 with no filesystem side effects; deletion funnels through the processor to prune metadata and cached images consistently. |
|
||||
| `ExampleImagesFileHandler` | `/api/lm/open-example-images-folder`, `/api/lm/example-image-files`, `/api/lm/has-example-images` | `ExampleImagesFileManager` | Centralises filesystem access, enforcing settings-based root paths and returning HTTP 400/404 for missing configuration or folders; responses always include `success`/`has_images` booleans for UI consumption. |
|
||||
|
||||
## Use case boundaries
|
||||
|
||||
| Use case | Entry point | Dependencies | Guarantees |
|
||||
| --- | --- | --- | --- |
|
||||
| `DownloadExampleImagesUseCase` | `execute(payload)` | `DownloadManager.start_download`, download configuration errors | Raises `DownloadExampleImagesInProgressError` when the manager reports an active job, rewraps configuration errors into `DownloadExampleImagesConfigurationError`, and lets `ExampleImagesDownloadError` bubble as 500s so handlers do not duplicate logging. |
|
||||
| `ImportExampleImagesUseCase` | `execute(request)` | `ExampleImagesProcessor.import_images`, temporary file helpers | Supports multipart or JSON payloads, normalises file paths into a single list, cleans up temp files even on failure, and maps validation issues to `ImportExampleImagesValidationError` for HTTP 400 responses. |
|
||||
|
||||
## Maintaining critical invariants
|
||||
|
||||
* **Shared progress snapshots** - The download handler returns the same snapshot built by
|
||||
`DownloadManager`, guaranteeing parity between HTTP polling endpoints and WebSocket
|
||||
progress events.
|
||||
* **Safe filesystem access** - All folder/file actions flow through
|
||||
`ExampleImagesFileManager`, which validates the configured example image root and ensures
|
||||
responses never leak absolute paths outside the allowed directory.
|
||||
* **Metadata hygiene** - Import/delete operations run through `ExampleImagesProcessor`,
|
||||
which updates model metadata via `MetadataManager` and notifies the relevant scanners so
|
||||
cache state stays in sync.
|
||||
|
||||
## Migration notes
|
||||
|
||||
The refactor brings the example image stack in line with the model/recipe stacks:
|
||||
|
||||
1. `ExampleImagesRouteRegistrar` now owns the declarative route list. Downstream projects
|
||||
should rely on `ExampleImagesRoutes.to_route_mapping()` instead of manually wiring
|
||||
handler callables.
|
||||
2. `ExampleImagesRoutes` caches its `ExampleImagesHandlerSet` just like
|
||||
`BaseModelRoutes`. If you previously instantiated handlers directly, inject custom
|
||||
collaborators via the controller constructor (`download_manager`, `processor`,
|
||||
`file_manager`) to keep test seams predictable.
|
||||
3. Tests that mocked `ExampleImagesRoutes.setup_routes` should switch to patching
|
||||
`DownloadExampleImagesUseCase`/`ImportExampleImagesUseCase` at import time. The handlers
|
||||
expect those abstractions to surface validation/concurrency errors, and bypassing them
|
||||
will skip the HTTP-friendly error mapping.
|
||||
|
||||
## Extending the stack
|
||||
|
||||
1. Add the endpoint to `ROUTE_DEFINITIONS` with a unique `handler_name`.
|
||||
2. Expose the coroutine on an existing handler class (or create a new handler and extend
|
||||
`ExampleImagesHandlerSet`).
|
||||
3. Wire additional services or factories inside `_build_handler_set` on
|
||||
`ExampleImagesRoutes`, mirroring how the model stack introduces new use cases.
|
||||
|
||||
`tests/routes/test_example_images_routes.py` exercises registrar binding, download pause
|
||||
flows, and import validations. Use it as a template when introducing new handler
|
||||
collaborators or error mappings.
|
||||
100
docs/architecture/model_routes.md
Normal file
100
docs/architecture/model_routes.md
Normal file
@@ -0,0 +1,100 @@
|
||||
# Base model route architecture
|
||||
|
||||
The model routing stack now splits HTTP wiring, orchestration logic, and
|
||||
business rules into discrete layers. The goal is to make it obvious where a
|
||||
new collaborator should live and which contract it must honour. The diagram
|
||||
below captures the end-to-end flow for a typical request:
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph HTTP
|
||||
A[ModelRouteRegistrar] -->|binds| B[BaseModelRoutes handler proxy]
|
||||
end
|
||||
subgraph Application
|
||||
B --> C[ModelHandlerSet]
|
||||
C --> D1[Handlers]
|
||||
D1 --> E1[Use cases]
|
||||
E1 --> F1[Services / scanners]
|
||||
end
|
||||
subgraph Side Effects
|
||||
F1 --> G1[Cache & metadata]
|
||||
F1 --> G2[Filesystem]
|
||||
F1 --> G3[WebSocket state]
|
||||
end
|
||||
```
|
||||
|
||||
Every box maps to a concrete module:
|
||||
|
||||
| Layer | Module(s) | Responsibility |
|
||||
| --- | --- | --- |
|
||||
| Registrar | `py/routes/model_route_registrar.py` | Declarative list of routes shared by every model type and helper methods for binding them to an `aiohttp` application. |
|
||||
| Route controller | `py/routes/base_model_routes.py` | Constructs the handler graph, injects shared services, exposes proxies that surface `503 Service not ready` when the model service has not been attached. |
|
||||
| Handler set | `py/routes/handlers/model_handlers.py` | Thin HTTP adapters grouped by concern (page rendering, listings, mutations, queries, downloads, CivitAI integration, move operations, auto-organize). |
|
||||
| Use cases | `py/services/use_cases/*.py` | Encapsulate long-running flows (`DownloadModelUseCase`, `BulkMetadataRefreshUseCase`, `AutoOrganizeUseCase`). They normalise validation errors and concurrency constraints before returning control to the handlers. |
|
||||
| Services | `py/services/*.py` | Existing services and scanners that mutate caches, write metadata, move files, and broadcast WebSocket updates. |
|
||||
|
||||
## Handler responsibilities & contracts
|
||||
|
||||
`ModelHandlerSet` flattens the handler objects into the exact callables used by
|
||||
the registrar. The table below highlights the separation of concerns within
|
||||
the set and the invariants that must hold after each handler returns.
|
||||
|
||||
| Handler | Key endpoints | Collaborators | Contracts |
|
||||
| --- | --- | --- | --- |
|
||||
| `ModelPageView` | `/{prefix}` | `SettingsManager`, `server_i18n`, Jinja environment, `service.scanner` | Template is rendered with `is_initializing` flag when caches are cold; i18n filter is registered exactly once per environment instance. |
|
||||
| `ModelListingHandler` | `/api/lm/{prefix}/list` | `service.get_paginated_data`, `service.format_response` | Listings respect pagination query parameters and cap `page_size` at 100; every item is formatted before response. |
|
||||
| `ModelManagementHandler` | Mutations (delete, exclude, metadata, preview, tags, rename, bulk delete, duplicate verification) | `ModelLifecycleService`, `MetadataSyncService`, `PreviewAssetService`, `TagUpdateService`, scanner cache/index | Cache state mirrors filesystem changes: deletes prune cache & hash index, preview replacements synchronise metadata and cache NSFW levels, metadata saves trigger cache resort when names change. |
|
||||
| `ModelQueryHandler` | Read-only queries (top tags, folders, duplicates, metadata, URLs) | Service query helpers & scanner cache | Outputs always wrapped in `{"success": True}` when no error; duplicate/filename grouping omits empty entries; invalid parameters (e.g. missing `model_root`) return HTTP 400. |
|
||||
| `ModelDownloadHandler` | `/api/lm/download-model`, `/download-model-get`, `/download-progress/{id}`, `/cancel-download-get` | `DownloadModelUseCase`, `DownloadCoordinator`, `WebSocketManager` | Payload validation errors become HTTP 400 without mutating download progress cache; early-access failures surface as HTTP 401; successful downloads cache progress snapshots that back both WebSocket broadcasts and polling endpoints. |
|
||||
| `ModelCivitaiHandler` | CivitAI metadata routes | `MetadataSyncService`, metadata provider factory, `BulkMetadataRefreshUseCase` | `fetch_all_civitai` streams progress via `WebSocketBroadcastCallback`; version lookups validate model type before returning; local availability fields derive from hash lookups without mutating cache state. |
|
||||
| `ModelMoveHandler` | `move_model`, `move_models_bulk` | `ModelMoveService` | Moves execute atomically per request; bulk operations aggregate success/failure per file set. |
|
||||
| `ModelAutoOrganizeHandler` | `/api/lm/{prefix}/auto-organize` (GET/POST), `/auto-organize-progress` | `AutoOrganizeUseCase`, `WebSocketProgressCallback`, `WebSocketManager` | Enforces single-flight execution using the shared lock; progress broadcasts remain available to polling clients until explicitly cleared; conflicts return HTTP 409 with a descriptive error. |
|
||||
|
||||
## Use case boundaries
|
||||
|
||||
Each use case exposes a narrow asynchronous API that hides the underlying
|
||||
services. Their error mapping is essential for predictable HTTP responses.
|
||||
|
||||
| Use case | Entry point | Dependencies | Guarantees |
|
||||
| --- | --- | --- | --- |
|
||||
| `DownloadModelUseCase` | `execute(payload)` | `DownloadCoordinator.schedule_download` | Translates `ValueError` into `DownloadModelValidationError` for HTTP 400, recognises early-access errors (`"401"` in message) and surfaces them as `DownloadModelEarlyAccessError`, forwards success dictionaries untouched. |
|
||||
| `AutoOrganizeUseCase` | `execute(file_paths, progress_callback)` | `ModelFileService.auto_organize_models`, `WebSocketManager` lock | Guarded by `ws_manager` lock + status checks; raises `AutoOrganizeInProgressError` before invoking the file service when another run is already active. |
|
||||
| `BulkMetadataRefreshUseCase` | `execute_with_error_handling(progress_callback)` | `MetadataSyncService`, `SettingsManager`, `WebSocketBroadcastCallback` | Iterates through cached models, applies metadata sync, emits progress snapshots that handlers broadcast unchanged. |
|
||||
|
||||
## Maintaining legacy contracts
|
||||
|
||||
The refactor preserves the invariants called out in the previous architecture
|
||||
notes. The most critical ones are reiterated here to emphasise the
|
||||
collaboration points:
|
||||
|
||||
1. **Cache mutations** – Delete, exclude, rename, and bulk delete operations are
|
||||
channelled through `ModelManagementHandler`. The handler delegates to
|
||||
`ModelLifecycleService` or `MetadataSyncService`, and the scanner cache is
|
||||
mutated in-place before the handler returns. The accompanying tests assert
|
||||
that `scanner._cache.raw_data` and `scanner._hash_index` stay in sync after
|
||||
each mutation.
|
||||
2. **Preview updates** – `PreviewAssetService.replace_preview` writes the new
|
||||
asset, `MetadataSyncService` persists the JSON metadata, and
|
||||
`scanner.update_preview_in_cache` mirrors the change. The handler returns
|
||||
the static URL produced by `config.get_preview_static_url`, keeping browser
|
||||
clients in lockstep with disk state.
|
||||
3. **Download progress** – `DownloadCoordinator.schedule_download` generates the
|
||||
download identifier, registers a WebSocket progress callback, and caches the
|
||||
latest numeric progress via `WebSocketManager`. Both `download_model`
|
||||
responses and `/download-progress/{id}` polling read from the same cache to
|
||||
guarantee consistent progress reporting across transports.
|
||||
|
||||
## Extending the stack
|
||||
|
||||
To add a new shared route:
|
||||
|
||||
1. Declare it in `COMMON_ROUTE_DEFINITIONS` using a unique handler name.
|
||||
2. Implement the corresponding coroutine on one of the handlers inside
|
||||
`ModelHandlerSet` (or introduce a new handler class when the concern does not
|
||||
fit existing ones).
|
||||
3. Inject additional dependencies in `BaseModelRoutes._create_handler_set` by
|
||||
wiring services or use cases through the constructor parameters.
|
||||
|
||||
Model-specific routes should continue to be registered inside the subclass
|
||||
implementation of `setup_specific_routes`, reusing the shared registrar where
|
||||
possible.
|
||||
89
docs/architecture/recipe_routes.md
Normal file
89
docs/architecture/recipe_routes.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# Recipe route architecture
|
||||
|
||||
The recipe routing stack now mirrors the modular model route design. HTTP
|
||||
bindings, controller wiring, handler orchestration, and business rules live in
|
||||
separate layers so new behaviours can be added without re-threading the entire
|
||||
feature. The diagram below outlines the flow for a typical request:
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph HTTP
|
||||
A[RecipeRouteRegistrar] -->|binds| B[RecipeRoutes controller]
|
||||
end
|
||||
subgraph Application
|
||||
B --> C[RecipeHandlerSet]
|
||||
C --> D1[Handlers]
|
||||
D1 --> E1[Use cases]
|
||||
E1 --> F1[Services / scanners]
|
||||
end
|
||||
subgraph Side Effects
|
||||
F1 --> G1[Cache & fingerprint index]
|
||||
F1 --> G2[Metadata files]
|
||||
F1 --> G3[Temporary shares]
|
||||
end
|
||||
```
|
||||
|
||||
## Layer responsibilities
|
||||
|
||||
| Layer | Module(s) | Responsibility |
|
||||
| --- | --- | --- |
|
||||
| Registrar | `py/routes/recipe_route_registrar.py` | Declarative list of every recipe endpoint and helper methods that bind them to an `aiohttp` application. |
|
||||
| Controller | `py/routes/base_recipe_routes.py`, `py/routes/recipe_routes.py` | Lazily resolves scanners/clients from the service registry, wires shared templates/i18n, instantiates `RecipeHandlerSet`, and exposes a `{handler_name: coroutine}` mapping for the registrar. |
|
||||
| Handler set | `py/routes/handlers/recipe_handlers.py` | Thin HTTP adapters grouped by concern (page view, listings, queries, mutations, sharing). They normalise responses and translate service exceptions into HTTP status codes. |
|
||||
| Services & scanners | `py/services/recipes/*.py`, `py/services/recipe_scanner.py`, `py/services/service_registry.py` | Concrete business logic: metadata parsing, persistence, sharing, fingerprint/index maintenance, and cache refresh. |
|
||||
|
||||
## Handler responsibilities & invariants
|
||||
|
||||
`RecipeHandlerSet` flattens purpose-built handler objects into the callables the
|
||||
registrar binds. Each handler is responsible for a narrow concern and enforces a
|
||||
set of invariants before returning:
|
||||
|
||||
| Handler | Key endpoints | Collaborators | Contracts |
|
||||
| --- | --- | --- | --- |
|
||||
| `RecipePageView` | `/loras/recipes` | `SettingsManager`, `server_i18n`, Jinja environment, recipe scanner getter | Template rendered with `is_initializing` flag when caches are still warming; i18n filter registered exactly once per environment instance. |
|
||||
| `RecipeListingHandler` | `/api/lm/recipes`, `/api/lm/recipe/{id}` | `recipe_scanner.get_paginated_data`, `recipe_scanner.get_recipe_by_id` | Listings respect pagination and search filters; every item receives a `file_url` fallback even when metadata is incomplete; missing recipes become HTTP 404. |
|
||||
| `RecipeQueryHandler` | Tag/base-model stats, syntax, LoRA lookups | Recipe scanner cache, `format_recipe_file_url` helper | Cache snapshots are reused without forcing refresh; duplicate lookups collapse groups by fingerprint; syntax lookups return helpful errors when LoRAs are absent. |
|
||||
| `RecipeManagementHandler` | Save, update, reconnect, bulk delete, widget ingest | `RecipePersistenceService`, `RecipeAnalysisService`, recipe scanner | Persistence results propagate HTTP status codes; fingerprint/index updates flow through the scanner before returning; validation errors surface as HTTP 400 without touching disk. |
|
||||
| `RecipeAnalysisHandler` | Uploaded/local/remote analysis | `RecipeAnalysisService`, `civitai_client`, recipe scanner | Unsupported content types map to HTTP 400; download errors (`RecipeDownloadError`) are not retried; every response includes a `loras` array for client compatibility. |
|
||||
| `RecipeSharingHandler` | Share + download | `RecipeSharingService`, recipe scanner | Share responses provide a stable download URL and filename; expired shares surface as HTTP 404; downloads stream via `web.FileResponse` with attachment headers. |
|
||||
|
||||
## Use case boundaries
|
||||
|
||||
The dedicated services encapsulate long-running work so handlers stay thin.
|
||||
|
||||
| Use case | Entry point | Dependencies | Guarantees |
|
||||
| --- | --- | --- | --- |
|
||||
| `RecipeAnalysisService` | `analyze_uploaded_image`, `analyze_remote_image`, `analyze_local_image`, `analyze_widget_metadata` | `ExifUtils`, `RecipeParserFactory`, downloader factory, optional metadata collector/processor | Normalises missing/invalid payloads into `RecipeValidationError`; generates consistent fingerprint data to keep duplicate detection stable; temporary files are cleaned up after every analysis path. |
|
||||
| `RecipePersistenceService` | `save_recipe`, `delete_recipe`, `update_recipe`, `reconnect_lora`, `bulk_delete`, `save_recipe_from_widget` | `ExifUtils`, recipe scanner, card preview sizing constants | Writes images/JSON metadata atomically; updates scanner caches and hash indices before returning; recalculates fingerprints whenever LoRA assignments change. |
|
||||
| `RecipeSharingService` | `share_recipe`, `prepare_download` | `tempfile`, recipe scanner | Copies originals to TTL-managed temp files; metadata lookups re-use the scanner; expired shares trigger cleanup and `RecipeNotFoundError`. |
|
||||
|
||||
## Maintaining critical invariants
|
||||
|
||||
* **Cache updates** – Mutations (`save`, `delete`, `bulk_delete`, `update`) call
|
||||
back into the recipe scanner to mutate the in-memory cache and fingerprint
|
||||
index before returning a response. Tests assert that these methods are invoked
|
||||
even when stubbing persistence.
|
||||
* **Fingerprint management** – `RecipePersistenceService` recomputes
|
||||
fingerprints whenever LoRA metadata changes and duplicate lookups use those
|
||||
fingerprints to group recipes. Handlers bubble the resulting IDs so clients
|
||||
can merge duplicates without an extra fetch.
|
||||
* **Metadata synchronisation** – Saving or reconnecting a recipe updates the
|
||||
JSON sidecar, refreshes embedded metadata via `ExifUtils`, and instructs the
|
||||
scanner to resort its cache. Sharing relies on this metadata to generate
|
||||
filenames and ensure downloads stay in sync with on-disk state.
|
||||
|
||||
## Extending the stack
|
||||
|
||||
1. Declare the new endpoint in `ROUTE_DEFINITIONS` with a unique handler name.
|
||||
2. Implement the coroutine on an existing handler or introduce a new handler
|
||||
class inside `py/routes/handlers/recipe_handlers.py` when the concern does
|
||||
not fit existing ones.
|
||||
3. Wire additional collaborators inside
|
||||
`BaseRecipeRoutes._create_handler_set` (inject new services or factories) and
|
||||
expose helper getters on the handler owner if the handler needs to share
|
||||
utilities.
|
||||
|
||||
Integration tests in `tests/routes/test_recipe_routes.py` exercise the listing,
|
||||
mutation, analysis-error, and sharing paths end-to-end, ensuring the controller
|
||||
and handler wiring remains valid as new capabilities are added.
|
||||
|
||||
23
docs/frontend-testing-roadmap.md
Normal file
23
docs/frontend-testing-roadmap.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# Frontend Automation Testing Roadmap
|
||||
|
||||
This roadmap tracks the planned rollout of automated testing for the ComfyUI LoRA Manager frontend. Each phase builds on the infrastructure introduced in this change set and records progress so future contributors can quickly identify the next tasks.
|
||||
|
||||
## Phase Overview
|
||||
|
||||
| Phase | Goal | Primary Focus | Status | Notes |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| Phase 0 | Establish baseline tooling | Add Node test runner, jsdom environment, and seed smoke tests | ✅ Complete | Vitest + jsdom configured, example state tests committed |
|
||||
| Phase 1 | Cover state management logic | Unit test selectors, derived data helpers, and storage utilities under `static/js/state` and `static/js/utils` | ✅ Complete | Storage helpers and state selectors now exercised via deterministic suites |
|
||||
| Phase 2 | Test AppCore orchestration | Simulate page bootstrapping, infinite scroll hooks, and manager registration using JSDOM DOM fixtures | 🟡 In Progress | AppCore initialization specs landed; expand to additional page wiring and scroll hooks |
|
||||
| Phase 3 | Validate page-specific managers | Add focused suites for `loras`, `checkpoints`, `embeddings`, and `recipes` managers covering filtering, sorting, and bulk actions | ⚪ Not Started | Consider shared helpers for mocking API modules and storage |
|
||||
| Phase 4 | Interaction-level regression tests | Exercise template fragments, modals, and menus to ensure UI wiring remains intact | ⚪ Not Started | Evaluate Playwright component testing or happy-path DOM snapshots |
|
||||
| Phase 5 | Continuous integration & coverage | Integrate frontend tests into CI workflow and track coverage metrics | ⚪ Not Started | Align reporting directories with backend coverage for unified reporting |
|
||||
|
||||
## Next Steps Checklist
|
||||
|
||||
- [x] Expand unit tests for `storageHelpers` covering migrations and namespace behavior.
|
||||
- [ ] Document DOM fixture strategy for reproducing template structures in tests.
|
||||
- [x] Prototype AppCore initialization test that verifies manager bootstrapping with stubbed dependencies.
|
||||
- [ ] Evaluate integrating coverage reporting once test surface grows (> 20 specs).
|
||||
|
||||
Maintaining this roadmap alongside code changes will make it easier to append new automated test tasks and update their progress.
|
||||
BIN
example_workflows/nunchaku-flux.1-dev.jpg
Normal file
BIN
example_workflows/nunchaku-flux.1-dev.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 68 KiB |
1
example_workflows/nunchaku-flux.1-dev.json
Normal file
1
example_workflows/nunchaku-flux.1-dev.json
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1242
locales/de.json
Normal file
1242
locales/de.json
Normal file
File diff suppressed because it is too large
Load Diff
1242
locales/en.json
Normal file
1242
locales/en.json
Normal file
File diff suppressed because it is too large
Load Diff
1242
locales/es.json
Normal file
1242
locales/es.json
Normal file
File diff suppressed because it is too large
Load Diff
1242
locales/fr.json
Normal file
1242
locales/fr.json
Normal file
File diff suppressed because it is too large
Load Diff
1242
locales/ja.json
Normal file
1242
locales/ja.json
Normal file
File diff suppressed because it is too large
Load Diff
1242
locales/ko.json
Normal file
1242
locales/ko.json
Normal file
File diff suppressed because it is too large
Load Diff
1242
locales/ru.json
Normal file
1242
locales/ru.json
Normal file
File diff suppressed because it is too large
Load Diff
1242
locales/zh-CN.json
Normal file
1242
locales/zh-CN.json
Normal file
File diff suppressed because it is too large
Load Diff
1242
locales/zh-TW.json
Normal file
1242
locales/zh-TW.json
Normal file
File diff suppressed because it is too large
Load Diff
2572
package-lock.json
generated
Normal file
2572
package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
14
package.json
Normal file
14
package.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"name": "comfyui-lora-manager-frontend",
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest"
|
||||
},
|
||||
"devDependencies": {
|
||||
"jsdom": "^24.0.0",
|
||||
"vitest": "^1.6.0"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
"""Project namespace package."""
|
||||
|
||||
# pytest's internal compatibility layer still imports ``py.path.local`` from the
|
||||
# historical ``py`` dependency. Because this project reuses the ``py`` package
|
||||
# name, we expose a minimal shim so ``py.path.local`` resolves to ``pathlib.Path``
|
||||
# during test runs without pulling in the external dependency.
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
path = SimpleNamespace(local=Path)
|
||||
|
||||
__all__ = ["path"]
|
||||
|
||||
214
py/config.py
214
py/config.py
@@ -3,11 +3,11 @@ import platform
|
||||
import folder_paths # type: ignore
|
||||
from typing import List
|
||||
import logging
|
||||
import sys
|
||||
import json
|
||||
import urllib.parse
|
||||
|
||||
# Check if running in standalone mode
|
||||
standalone_mode = 'nodes' not in sys.modules
|
||||
# Use an environment variable to control standalone mode
|
||||
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,13 +17,18 @@ class Config:
|
||||
def __init__(self):
|
||||
self.templates_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'templates')
|
||||
self.static_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static')
|
||||
# 路径映射字典, target to link mapping
|
||||
self.i18n_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'locales')
|
||||
# Path mapping dictionary, target to link mapping
|
||||
self._path_mappings = {}
|
||||
# 静态路由映射字典, target to route mapping
|
||||
# Static route mapping dictionary, target to route mapping
|
||||
self._route_mappings = {}
|
||||
self.loras_roots = self._init_lora_paths()
|
||||
self.checkpoints_roots = self._init_checkpoint_paths()
|
||||
# 在初始化时扫描符号链接
|
||||
self.checkpoints_roots = None
|
||||
self.unet_roots = None
|
||||
self.embeddings_roots = None
|
||||
self.base_models_roots = self._init_checkpoint_paths()
|
||||
self.embeddings_roots = self._init_embedding_paths()
|
||||
# Scan symbolic links during initialization
|
||||
self._scan_symbolic_links()
|
||||
|
||||
if not standalone_mode:
|
||||
@@ -33,34 +38,37 @@ class Config:
|
||||
def save_folder_paths_to_settings(self):
|
||||
"""Save folder paths to settings.json for standalone mode to use later"""
|
||||
try:
|
||||
# Check if we're running in ComfyUI mode (not standalone)
|
||||
if hasattr(folder_paths, "get_folder_paths") and not isinstance(folder_paths, type):
|
||||
# Get all relevant paths
|
||||
lora_paths = folder_paths.get_folder_paths("loras")
|
||||
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
diffuser_paths = folder_paths.get_folder_paths("diffusers")
|
||||
unet_paths = folder_paths.get_folder_paths("unet")
|
||||
|
||||
# Load existing settings
|
||||
settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
|
||||
settings = {}
|
||||
if os.path.exists(settings_path):
|
||||
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||
settings = json.load(f)
|
||||
|
||||
# Update settings with paths
|
||||
settings['folder_paths'] = {
|
||||
'loras': lora_paths,
|
||||
'checkpoints': checkpoint_paths,
|
||||
'diffusers': diffuser_paths,
|
||||
'unet': unet_paths
|
||||
}
|
||||
|
||||
# Save settings
|
||||
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(settings, f, indent=2)
|
||||
|
||||
logger.info("Saved folder paths to settings.json")
|
||||
# Check if we're running in ComfyUI mode (not standalone)
|
||||
# Load existing settings
|
||||
settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
|
||||
settings = {}
|
||||
if os.path.exists(settings_path):
|
||||
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||
settings = json.load(f)
|
||||
|
||||
# Update settings with paths
|
||||
settings['folder_paths'] = {
|
||||
'loras': self.loras_roots,
|
||||
'checkpoints': self.checkpoints_roots,
|
||||
'unet': self.unet_roots,
|
||||
'embeddings': self.embeddings_roots,
|
||||
}
|
||||
|
||||
# Add default roots if there's only one item and key doesn't exist
|
||||
if len(self.loras_roots) == 1 and "default_lora_root" not in settings:
|
||||
settings["default_lora_root"] = self.loras_roots[0]
|
||||
|
||||
if self.checkpoints_roots and len(self.checkpoints_roots) == 1 and "default_checkpoint_root" not in settings:
|
||||
settings["default_checkpoint_root"] = self.checkpoints_roots[0]
|
||||
|
||||
if self.embeddings_roots and len(self.embeddings_roots) == 1 and "default_embedding_root" not in settings:
|
||||
settings["default_embedding_root"] = self.embeddings_roots[0]
|
||||
|
||||
# Save settings
|
||||
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(settings, f, indent=2)
|
||||
|
||||
logger.info("Saved folder paths to settings.json")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save folder paths: {e}")
|
||||
|
||||
@@ -82,15 +90,18 @@ class Config:
|
||||
return False
|
||||
|
||||
def _scan_symbolic_links(self):
|
||||
"""扫描所有 LoRA 和 Checkpoint 根目录中的符号链接"""
|
||||
"""Scan all symbolic links in LoRA, Checkpoint, and Embedding root directories"""
|
||||
for root in self.loras_roots:
|
||||
self._scan_directory_links(root)
|
||||
|
||||
for root in self.checkpoints_roots:
|
||||
for root in self.base_models_roots:
|
||||
self._scan_directory_links(root)
|
||||
|
||||
for root in self.embeddings_roots:
|
||||
self._scan_directory_links(root)
|
||||
|
||||
def _scan_directory_links(self, root: str):
|
||||
"""递归扫描目录中的符号链接"""
|
||||
"""Recursively scan symbolic links in a directory"""
|
||||
try:
|
||||
with os.scandir(root) as it:
|
||||
for entry in it:
|
||||
@@ -105,40 +116,40 @@ class Config:
|
||||
logger.error(f"Error scanning links in {root}: {e}")
|
||||
|
||||
def add_path_mapping(self, link_path: str, target_path: str):
|
||||
"""添加符号链接路径映射
|
||||
target_path: 实际目标路径
|
||||
link_path: 符号链接路径
|
||||
"""Add a symbolic link path mapping
|
||||
target_path: actual target path
|
||||
link_path: symbolic link path
|
||||
"""
|
||||
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
|
||||
normalized_target = os.path.normpath(target_path).replace(os.sep, '/')
|
||||
# 保持原有的映射关系:目标路径 -> 链接路径
|
||||
# Keep the original mapping: target path -> link path
|
||||
self._path_mappings[normalized_target] = normalized_link
|
||||
logger.info(f"Added path mapping: {normalized_target} -> {normalized_link}")
|
||||
|
||||
def add_route_mapping(self, path: str, route: str):
|
||||
"""添加静态路由映射"""
|
||||
"""Add a static route mapping"""
|
||||
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
||||
self._route_mappings[normalized_path] = route
|
||||
# logger.info(f"Added route mapping: {normalized_path} -> {route}")
|
||||
|
||||
def map_path_to_link(self, path: str) -> str:
|
||||
"""将目标路径映射回符号链接路径"""
|
||||
"""Map a target path back to its symbolic link path"""
|
||||
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
||||
# 检查路径是否包含在任何映射的目标路径中
|
||||
# Check if the path is contained in any mapped target path
|
||||
for target_path, link_path in self._path_mappings.items():
|
||||
if normalized_path.startswith(target_path):
|
||||
# 如果路径以目标路径开头,则替换为链接路径
|
||||
# If the path starts with the target path, replace with link path
|
||||
mapped_path = normalized_path.replace(target_path, link_path, 1)
|
||||
return mapped_path
|
||||
return path
|
||||
|
||||
def map_link_to_path(self, link_path: str) -> str:
|
||||
"""将符号链接路径映射回实际路径"""
|
||||
"""Map a symbolic link path back to the actual path"""
|
||||
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
|
||||
# 检查路径是否包含在任何映射的目标路径中
|
||||
# Check if the path is contained in any mapped target path
|
||||
for target_path, link_path in self._path_mappings.items():
|
||||
if normalized_link.startswith(target_path):
|
||||
# 如果路径以目标路径开头,则替换为实际路径
|
||||
# If the path starts with the target path, replace with actual path
|
||||
mapped_path = normalized_link.replace(target_path, link_path, 1)
|
||||
return mapped_path
|
||||
return link_path
|
||||
@@ -177,47 +188,106 @@ class Config:
|
||||
"""Initialize and validate checkpoint paths from ComfyUI settings"""
|
||||
try:
|
||||
# Get checkpoint paths from folder_paths
|
||||
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
diffusion_paths = folder_paths.get_folder_paths("diffusers")
|
||||
unet_paths = folder_paths.get_folder_paths("unet")
|
||||
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
raw_unet_paths = folder_paths.get_folder_paths("unet")
|
||||
|
||||
# Combine all checkpoint-related paths
|
||||
all_paths = checkpoint_paths + diffusion_paths + unet_paths
|
||||
# Normalize and resolve symlinks for checkpoints, store mapping from resolved -> original
|
||||
checkpoint_map = {}
|
||||
for path in raw_checkpoint_paths:
|
||||
if os.path.exists(path):
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
checkpoint_map[real_path] = checkpoint_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
|
||||
|
||||
# Filter and normalize paths
|
||||
paths = sorted(set(path.replace(os.sep, "/")
|
||||
for path in all_paths
|
||||
if os.path.exists(path)), key=lambda p: p.lower())
|
||||
# Normalize and resolve symlinks for unet, store mapping from resolved -> original
|
||||
unet_map = {}
|
||||
for path in raw_unet_paths:
|
||||
if os.path.exists(path):
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
unet_map[real_path] = unet_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
|
||||
|
||||
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(paths) if paths else "[]"))
|
||||
# Merge both maps and deduplicate by real path
|
||||
merged_map = {}
|
||||
for real_path, orig_path in {**checkpoint_map, **unet_map}.items():
|
||||
if real_path not in merged_map:
|
||||
merged_map[real_path] = orig_path
|
||||
|
||||
# Now sort and use only the deduplicated real paths
|
||||
unique_paths = sorted(merged_map.values(), key=lambda p: p.lower())
|
||||
|
||||
if not paths:
|
||||
# Split back into checkpoints and unet roots for class properties
|
||||
self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_map.values()]
|
||||
self.unet_roots = [p for p in unique_paths if p in unet_map.values()]
|
||||
|
||||
all_paths = unique_paths
|
||||
|
||||
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(all_paths) if all_paths else "[]"))
|
||||
|
||||
if not all_paths:
|
||||
logger.warning("No valid checkpoint folders found in ComfyUI configuration")
|
||||
return []
|
||||
|
||||
# 初始化路径映射,与 LoRA 路径处理方式相同
|
||||
for path in paths:
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
if real_path != path:
|
||||
self.add_path_mapping(path, real_path)
|
||||
# Initialize path mappings
|
||||
for original_path in all_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return paths
|
||||
return all_paths
|
||||
except Exception as e:
|
||||
logger.warning(f"Error initializing checkpoint paths: {e}")
|
||||
return []
|
||||
|
||||
def _init_embedding_paths(self) -> List[str]:
|
||||
"""Initialize and validate embedding paths from ComfyUI settings"""
|
||||
try:
|
||||
raw_paths = folder_paths.get_folder_paths("embeddings")
|
||||
|
||||
# Normalize and resolve symlinks, store mapping from resolved -> original
|
||||
path_map = {}
|
||||
for path in raw_paths:
|
||||
if os.path.exists(path):
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
path_map[real_path] = path_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
|
||||
|
||||
# Now sort and use only the deduplicated real paths
|
||||
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
|
||||
logger.info("Found embedding roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
|
||||
|
||||
if not unique_paths:
|
||||
logger.warning("No valid embeddings folders found in ComfyUI configuration")
|
||||
return []
|
||||
|
||||
for original_path in unique_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return unique_paths
|
||||
except Exception as e:
|
||||
logger.warning(f"Error initializing embedding paths: {e}")
|
||||
return []
|
||||
|
||||
def get_preview_static_url(self, preview_path: str) -> str:
|
||||
"""Convert local preview path to static URL"""
|
||||
if not preview_path:
|
||||
return ""
|
||||
|
||||
real_path = os.path.realpath(preview_path).replace(os.sep, '/')
|
||||
|
||||
|
||||
# Find longest matching path (most specific match)
|
||||
best_match = ""
|
||||
best_route = ""
|
||||
|
||||
for path, route in self._route_mappings.items():
|
||||
if real_path.startswith(path):
|
||||
relative_path = os.path.relpath(real_path, path)
|
||||
return f'{route}/{relative_path.replace(os.sep, "/")}'
|
||||
|
||||
if real_path.startswith(path) and len(path) > len(best_match):
|
||||
best_match = path
|
||||
best_route = route
|
||||
|
||||
if best_match:
|
||||
relative_path = os.path.relpath(real_path, best_match).replace(os.sep, '/')
|
||||
safe_parts = [urllib.parse.quote(part) for part in relative_path.split('/')]
|
||||
safe_path = '/'.join(safe_parts)
|
||||
return f'{best_route}/{safe_path}'
|
||||
|
||||
return ""
|
||||
|
||||
# Global config instance
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from server import PromptServer # type: ignore
|
||||
|
||||
from .config import config
|
||||
from .routes.lora_routes import LoraRoutes
|
||||
from .routes.api_routes import ApiRoutes
|
||||
from .services.model_service_factory import ModelServiceFactory, register_default_model_types
|
||||
from .routes.recipe_routes import RecipeRoutes
|
||||
from .routes.checkpoints_routes import CheckpointsRoutes
|
||||
from .routes.stats_routes import StatsRoutes
|
||||
from .routes.update_routes import UpdateRoutes
|
||||
from .routes.misc_routes import MiscRoutes
|
||||
from .routes.example_images_routes import ExampleImagesRoutes
|
||||
from .services.service_registry import ServiceRegistry
|
||||
from .services.settings_manager import settings
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
from .utils.example_images_migration import ExampleImagesMigration
|
||||
from .services.websocket_manager import ws_manager
|
||||
from .services.example_images_cleanup_service import ExampleImagesCleanupService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,12 +28,28 @@ class LoraManager:
|
||||
|
||||
@classmethod
|
||||
def add_routes(cls):
|
||||
"""Initialize and register all routes"""
|
||||
"""Initialize and register all routes using the new refactored architecture"""
|
||||
app = PromptServer.instance.app
|
||||
|
||||
# Configure aiohttp access logger to be less verbose
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# Add specific suppression for connection reset errors
|
||||
class ConnectionResetFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
# Filter out connection reset errors that are not critical
|
||||
if "ConnectionResetError" in str(record.getMessage()):
|
||||
return False
|
||||
if "_call_connection_lost" in str(record.getMessage()):
|
||||
return False
|
||||
if "WinError 10054" in str(record.getMessage()):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Apply the filter to asyncio logger
|
||||
asyncio_logger = logging.getLogger("asyncio")
|
||||
asyncio_logger.addFilter(ConnectionResetFilter())
|
||||
|
||||
added_targets = set() # Track already added target paths
|
||||
|
||||
# Add static route for example images if the path exists in settings
|
||||
@@ -58,7 +78,7 @@ class LoraManager:
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Add static routes for each checkpoint root
|
||||
for idx, root in enumerate(config.checkpoints_roots, start=1):
|
||||
for idx, root in enumerate(config.base_models_roots, start=1):
|
||||
preview_path = f'/checkpoints_static/root{idx}/preview'
|
||||
|
||||
real_root = root
|
||||
@@ -75,127 +95,300 @@ class LoraManager:
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Add static routes for each embedding root
|
||||
for idx, root in enumerate(config.embeddings_roots, start=1):
|
||||
preview_path = f'/embeddings_static/root{idx}/preview'
|
||||
|
||||
real_root = root
|
||||
if root in config._path_mappings.values():
|
||||
for target, link in config._path_mappings.items():
|
||||
if link == root:
|
||||
real_root = target
|
||||
break
|
||||
# Add static route for original path
|
||||
app.router.add_static(preview_path, real_root)
|
||||
logger.info(f"Added static route {preview_path} -> {real_root}")
|
||||
|
||||
# Record route mapping
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Add static routes for symlink target paths
|
||||
link_idx = {
|
||||
'lora': 1,
|
||||
'checkpoint': 1
|
||||
'checkpoint': 1,
|
||||
'embedding': 1
|
||||
}
|
||||
|
||||
for target_path, link_path in config._path_mappings.items():
|
||||
if target_path not in added_targets:
|
||||
# Determine if this is a checkpoint or lora link based on path
|
||||
is_checkpoint = any(cp_root in link_path for cp_root in config.checkpoints_roots)
|
||||
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.checkpoints_roots)
|
||||
# Determine if this is a checkpoint, lora, or embedding link based on path
|
||||
is_checkpoint = any(cp_root in link_path for cp_root in config.base_models_roots)
|
||||
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.base_models_roots)
|
||||
is_embedding = any(emb_root in link_path for emb_root in config.embeddings_roots)
|
||||
is_embedding = is_embedding or any(emb_root in target_path for emb_root in config.embeddings_roots)
|
||||
|
||||
if is_checkpoint:
|
||||
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
|
||||
link_idx["checkpoint"] += 1
|
||||
elif is_embedding:
|
||||
route_path = f'/embeddings_static/link_{link_idx["embedding"]}/preview'
|
||||
link_idx["embedding"] += 1
|
||||
else:
|
||||
route_path = f'/loras_static/link_{link_idx["lora"]}/preview'
|
||||
link_idx["lora"] += 1
|
||||
|
||||
app.router.add_static(route_path, target_path)
|
||||
logger.info(f"Added static route for link target {route_path} -> {target_path}")
|
||||
config.add_route_mapping(target_path, route_path)
|
||||
added_targets.add(target_path)
|
||||
|
||||
try:
|
||||
app.router.add_static(route_path, Path(target_path).resolve(strict=False))
|
||||
logger.info(f"Added static route for link target {route_path} -> {target_path}")
|
||||
config.add_route_mapping(target_path, route_path)
|
||||
added_targets.add(target_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to add static route on initialization for {target_path}: {e}")
|
||||
continue
|
||||
|
||||
# Add static route for locales JSON files
|
||||
if os.path.exists(config.i18n_path):
|
||||
app.router.add_static('/locales', config.i18n_path)
|
||||
logger.info(f"Added static route for locales: /locales -> {config.i18n_path}")
|
||||
|
||||
# Add static route for plugin assets
|
||||
app.router.add_static('/loras_static', config.static_path)
|
||||
|
||||
# Setup feature routes
|
||||
lora_routes = LoraRoutes()
|
||||
checkpoints_routes = CheckpointsRoutes()
|
||||
# Register default model types with the factory
|
||||
register_default_model_types()
|
||||
|
||||
# Initialize routes
|
||||
lora_routes.setup_routes(app)
|
||||
checkpoints_routes.setup_routes(app)
|
||||
ApiRoutes.setup_routes(app)
|
||||
# Setup all model routes using the factory
|
||||
ModelServiceFactory.setup_all_routes(app)
|
||||
|
||||
# Setup non-model-specific routes
|
||||
stats_routes = StatsRoutes()
|
||||
stats_routes.setup_routes(app)
|
||||
RecipeRoutes.setup_routes(app)
|
||||
UpdateRoutes.setup_routes(app)
|
||||
MiscRoutes.setup_routes(app) # Register miscellaneous routes
|
||||
ExampleImagesRoutes.setup_routes(app) # Register example images routes
|
||||
MiscRoutes.setup_routes(app)
|
||||
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
|
||||
|
||||
# Setup WebSocket routes that are shared across all model types
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
|
||||
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
|
||||
|
||||
# Schedule service initialization
|
||||
app.on_startup.append(lambda app: cls._initialize_services())
|
||||
|
||||
# Add cleanup
|
||||
app.on_shutdown.append(cls._cleanup)
|
||||
app.on_shutdown.append(ApiRoutes.cleanup)
|
||||
|
||||
logger.info(f"LoRA Manager: Set up routes for {len(ModelServiceFactory.get_registered_types())} model types: {', '.join(ModelServiceFactory.get_registered_types())}")
|
||||
|
||||
@classmethod
|
||||
async def _initialize_services(cls):
|
||||
"""Initialize all services using the ServiceRegistry"""
|
||||
try:
|
||||
# Ensure aiohttp access logger is configured with reduced verbosity
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# Initialize CivitaiClient first to ensure it's ready for other services
|
||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Get file monitors through ServiceRegistry
|
||||
lora_monitor = await ServiceRegistry.get_lora_monitor()
|
||||
checkpoint_monitor = await ServiceRegistry.get_checkpoint_monitor()
|
||||
|
||||
# Start monitors
|
||||
lora_monitor.start()
|
||||
logger.debug("Lora monitor started")
|
||||
|
||||
# Make sure checkpoint monitor has paths before starting
|
||||
await checkpoint_monitor.initialize_paths()
|
||||
checkpoint_monitor.start()
|
||||
logger.debug("Checkpoint monitor started")
|
||||
await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Register DownloadManager with ServiceRegistry
|
||||
download_manager = await ServiceRegistry.get_download_manager()
|
||||
await ServiceRegistry.get_download_manager()
|
||||
|
||||
from .services.metadata_service import initialize_metadata_providers
|
||||
await initialize_metadata_providers()
|
||||
|
||||
# Initialize WebSocket manager
|
||||
ws_manager = await ServiceRegistry.get_websocket_manager()
|
||||
await ServiceRegistry.get_websocket_manager()
|
||||
|
||||
# Initialize scanners in background
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
# Initialize recipe scanner if needed
|
||||
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
|
||||
# Initialize metadata collector if not in standalone mode
|
||||
if not STANDALONE_MODE:
|
||||
from .metadata_collector import init as init_metadata
|
||||
init_metadata()
|
||||
logger.debug("Metadata collector initialized")
|
||||
|
||||
# Create low-priority initialization tasks
|
||||
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
|
||||
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
|
||||
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
|
||||
init_tasks = [
|
||||
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init'),
|
||||
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init'),
|
||||
asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init'),
|
||||
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
|
||||
]
|
||||
|
||||
await ExampleImagesMigration.check_and_run_migrations()
|
||||
|
||||
logger.info("LoRA Manager: All services initialized and background tasks scheduled")
|
||||
# Schedule post-initialization tasks to run after scanners complete
|
||||
asyncio.create_task(
|
||||
cls._run_post_initialization_tasks(init_tasks),
|
||||
name='post_init_tasks'
|
||||
)
|
||||
|
||||
logger.debug("LoRA Manager: All services initialized and background tasks scheduled")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
async def _run_post_initialization_tasks(cls, init_tasks):
|
||||
"""Run post-initialization tasks after all scanners complete"""
|
||||
try:
|
||||
logger.debug("LoRA Manager: Waiting for scanner initialization to complete...")
|
||||
|
||||
# Wait for all scanner initialization tasks to complete
|
||||
await asyncio.gather(*init_tasks, return_exceptions=True)
|
||||
|
||||
logger.debug("LoRA Manager: Scanner initialization completed, starting post-initialization tasks...")
|
||||
|
||||
# Run post-initialization tasks
|
||||
post_tasks = [
|
||||
asyncio.create_task(cls._cleanup_backup_files(), name='cleanup_bak_files'),
|
||||
# Add more post-initialization tasks here as needed
|
||||
# asyncio.create_task(cls._another_post_task(), name='another_task'),
|
||||
]
|
||||
|
||||
# Run all post-initialization tasks
|
||||
results = await asyncio.gather(*post_tasks, return_exceptions=True)
|
||||
|
||||
# Log results
|
||||
for i, result in enumerate(results):
|
||||
task_name = post_tasks[i].get_name()
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Post-initialization task '{task_name}' failed: {result}")
|
||||
else:
|
||||
logger.debug(f"Post-initialization task '{task_name}' completed successfully")
|
||||
|
||||
logger.debug("LoRA Manager: All post-initialization tasks completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LoRA Manager: Error in post-initialization tasks: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
async def _cleanup_backup_files(cls):
|
||||
"""Clean up .bak files in all model roots"""
|
||||
try:
|
||||
logger.debug("Starting cleanup of .bak files in model directories...")
|
||||
|
||||
# Collect all model roots
|
||||
all_roots = set()
|
||||
all_roots.update(config.loras_roots)
|
||||
all_roots.update(config.base_models_roots)
|
||||
all_roots.update(config.embeddings_roots)
|
||||
|
||||
total_deleted = 0
|
||||
total_size_freed = 0
|
||||
|
||||
for root_path in all_roots:
|
||||
if not os.path.exists(root_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
deleted_count, size_freed = await cls._cleanup_backup_files_in_directory(root_path)
|
||||
total_deleted += deleted_count
|
||||
total_size_freed += size_freed
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"Cleaned up {deleted_count} .bak files in {root_path} (freed {size_freed / (1024*1024):.2f} MB)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up .bak files in {root_path}: {e}")
|
||||
|
||||
# Yield control periodically
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
if total_deleted > 0:
|
||||
logger.debug(f"Backup cleanup completed: removed {total_deleted} .bak files, freed {total_size_freed / (1024*1024):.2f} MB total")
|
||||
else:
|
||||
logger.debug("Backup cleanup completed: no .bak files found")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during backup file cleanup: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
async def _cleanup_backup_files_in_directory(cls, directory_path: str):
|
||||
"""Clean up .bak files in a specific directory recursively
|
||||
|
||||
Args:
|
||||
directory_path: Path to the directory to clean
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: (number of files deleted, total size freed in bytes)
|
||||
"""
|
||||
deleted_count = 0
|
||||
size_freed = 0
|
||||
visited_paths = set()
|
||||
|
||||
def cleanup_recursive(path):
|
||||
nonlocal deleted_count, size_freed
|
||||
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
for entry in it:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True) and entry.name.endswith('.bak'):
|
||||
file_size = entry.stat().st_size
|
||||
os.remove(entry.path)
|
||||
deleted_count += 1
|
||||
size_freed += file_size
|
||||
logger.debug(f"Deleted .bak file: {entry.path}")
|
||||
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
cleanup_recursive(entry.path)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not delete .bak file {entry.path}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning directory {path} for .bak files: {e}")
|
||||
|
||||
# Run the recursive cleanup in a thread pool to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, cleanup_recursive, directory_path)
|
||||
|
||||
return deleted_count, size_freed
|
||||
|
||||
@classmethod
|
||||
async def _cleanup_example_images_folders(cls):
|
||||
"""Invoke the example images cleanup service for manual execution."""
|
||||
try:
|
||||
service = ExampleImagesCleanupService()
|
||||
result = await service.cleanup_example_image_folders()
|
||||
|
||||
if result.get('success'):
|
||||
logger.debug(
|
||||
"Manual example images cleanup completed: moved=%s",
|
||||
result.get('moved_total'),
|
||||
)
|
||||
elif result.get('partial_success'):
|
||||
logger.warning(
|
||||
"Manual example images cleanup partially succeeded: moved=%s failures=%s",
|
||||
result.get('moved_total'),
|
||||
result.get('move_failures'),
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Manual example images cleanup skipped or failed: %s",
|
||||
result.get('error', 'no changes'),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e: # pragma: no cover - defensive guard
|
||||
logger.error(f"Error during example images cleanup: {e}", exc_info=True)
|
||||
return {
|
||||
'success': False,
|
||||
'error': str(e),
|
||||
'error_code': 'unexpected_error',
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def _cleanup(cls, app):
|
||||
"""Cleanup resources using ServiceRegistry"""
|
||||
try:
|
||||
logger.info("LoRA Manager: Cleaning up services")
|
||||
|
||||
# Get monitors from ServiceRegistry
|
||||
lora_monitor = await ServiceRegistry.get_service("lora_monitor")
|
||||
if lora_monitor:
|
||||
lora_monitor.stop()
|
||||
logger.info("Stopped LoRA monitor")
|
||||
|
||||
checkpoint_monitor = await ServiceRegistry.get_service("checkpoint_monitor")
|
||||
if checkpoint_monitor:
|
||||
checkpoint_monitor.stop()
|
||||
logger.info("Stopped checkpoint monitor")
|
||||
|
||||
# Close CivitaiClient gracefully
|
||||
civitai_client = await ServiceRegistry.get_service("civitai_client")
|
||||
if civitai_client:
|
||||
await civitai_client.close()
|
||||
logger.info("Closed CivitaiClient connection")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during cleanup: {e}", exc_info=True)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import os
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
# Check if running in standalone mode
|
||||
standalone_mode = 'nodes' not in sys.modules
|
||||
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
|
||||
if not standalone_mode:
|
||||
from .metadata_hook import MetadataHook
|
||||
|
||||
@@ -26,98 +26,179 @@ class MetadataHook:
|
||||
print("Could not locate ComfyUI execution module, metadata collection disabled")
|
||||
return
|
||||
|
||||
# Store the original _map_node_over_list function
|
||||
original_map_node_over_list = execution._map_node_over_list
|
||||
# Detect whether we're using the new async version of ComfyUI
|
||||
is_async = False
|
||||
map_node_func_name = '_map_node_over_list'
|
||||
|
||||
# Define the wrapped _map_node_over_list function
|
||||
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
# Only collect metadata when calling the main function of nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
prompt_id = registry.current_prompt_id
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Unique ID might be available through the obj if it has a unique_id field
|
||||
node_id = getattr(obj, 'unique_id', None)
|
||||
if node_id is None and pre_execute_cb:
|
||||
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if 'unique_id' in frame.f_locals:
|
||||
node_id = frame.f_locals['unique_id']
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
# Record inputs before execution
|
||||
if node_id is not None:
|
||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
|
||||
# Execute the original function
|
||||
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
|
||||
|
||||
# After execution, collect outputs for relevant nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
prompt_id = registry.current_prompt_id
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Unique ID might be available through the obj if it has a unique_id field
|
||||
node_id = getattr(obj, 'unique_id', None)
|
||||
if node_id is None and pre_execute_cb:
|
||||
# Try to extract node_id through reflection
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if 'unique_id' in frame.f_locals:
|
||||
node_id = frame.f_locals['unique_id']
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
# Record outputs after execution
|
||||
if node_id is not None:
|
||||
registry.update_node_execution(node_id, class_type, results)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
# Also hook the execute function to track the current prompt_id
|
||||
original_execute = execution.execute
|
||||
if hasattr(execution, '_async_map_node_over_list'):
|
||||
is_async = inspect.iscoroutinefunction(execution._async_map_node_over_list)
|
||||
map_node_func_name = '_async_map_node_over_list'
|
||||
elif hasattr(execution, '_map_node_over_list'):
|
||||
is_async = inspect.iscoroutinefunction(execution._map_node_over_list)
|
||||
|
||||
def execute_with_prompt_tracking(*args, **kwargs):
|
||||
if len(args) >= 7: # Check if we have enough arguments
|
||||
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||
registry = MetadataRegistry()
|
||||
|
||||
# Start collection if this is a new prompt
|
||||
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||
registry.start_collection(prompt_id)
|
||||
|
||||
# Store the dynprompt reference for node lookups
|
||||
if hasattr(prompt, 'original_prompt'):
|
||||
registry.set_current_prompt(prompt)
|
||||
|
||||
# Execute the original function
|
||||
return original_execute(*args, **kwargs)
|
||||
|
||||
# Replace the functions
|
||||
execution._map_node_over_list = map_node_over_list_with_metadata
|
||||
execution.execute = execute_with_prompt_tracking
|
||||
# Make map_node_over_list public to avoid it being hidden by hooks
|
||||
execution.map_node_over_list = original_map_node_over_list
|
||||
if is_async:
|
||||
print("Detected async ComfyUI execution, installing async metadata hooks")
|
||||
MetadataHook._install_async_hooks(execution, map_node_func_name)
|
||||
else:
|
||||
print("Detected sync ComfyUI execution, installing sync metadata hooks")
|
||||
MetadataHook._install_sync_hooks(execution)
|
||||
|
||||
print("Metadata collection hooks installed for runtime values")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error installing metadata hooks: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _install_sync_hooks(execution):
|
||||
"""Install hooks for synchronous execution model"""
|
||||
# Store the original _map_node_over_list function
|
||||
original_map_node_over_list = execution._map_node_over_list
|
||||
|
||||
# Define the wrapped _map_node_over_list function
|
||||
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
# Only collect metadata when calling the main function of nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
prompt_id = registry.current_prompt_id
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Unique ID might be available through the obj if it has a unique_id field
|
||||
node_id = getattr(obj, 'unique_id', None)
|
||||
if node_id is None and pre_execute_cb:
|
||||
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if 'unique_id' in frame.f_locals:
|
||||
node_id = frame.f_locals['unique_id']
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
# Record inputs before execution
|
||||
if node_id is not None:
|
||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
|
||||
# Execute the original function
|
||||
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
|
||||
|
||||
# After execution, collect outputs for relevant nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
prompt_id = registry.current_prompt_id
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Unique ID might be available through the obj if it has a unique_id field
|
||||
node_id = getattr(obj, 'unique_id', None)
|
||||
if node_id is None and pre_execute_cb:
|
||||
# Try to extract node_id through reflection
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if 'unique_id' in frame.f_locals:
|
||||
node_id = frame.f_locals['unique_id']
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
# Record outputs after execution
|
||||
if node_id is not None:
|
||||
registry.update_node_execution(node_id, class_type, results)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
# Also hook the execute function to track the current prompt_id
|
||||
original_execute = execution.execute
|
||||
|
||||
def execute_with_prompt_tracking(*args, **kwargs):
|
||||
if len(args) >= 7: # Check if we have enough arguments
|
||||
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||
registry = MetadataRegistry()
|
||||
|
||||
# Start collection if this is a new prompt
|
||||
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||
registry.start_collection(prompt_id)
|
||||
|
||||
# Store the dynprompt reference for node lookups
|
||||
if hasattr(prompt, 'original_prompt'):
|
||||
registry.set_current_prompt(prompt)
|
||||
|
||||
# Execute the original function
|
||||
return original_execute(*args, **kwargs)
|
||||
|
||||
# Replace the functions
|
||||
execution._map_node_over_list = map_node_over_list_with_metadata
|
||||
execution.execute = execute_with_prompt_tracking
|
||||
|
||||
@staticmethod
|
||||
def _install_async_hooks(execution, map_node_func_name='_async_map_node_over_list'):
|
||||
"""Install hooks for asynchronous execution model"""
|
||||
# Store the original _async_map_node_over_list function
|
||||
original_map_node_over_list = getattr(execution, map_node_func_name)
|
||||
|
||||
# Wrapped async function, compatible with both stable and nightly
|
||||
async def async_map_node_over_list_with_metadata(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, *args, **kwargs):
|
||||
hidden_inputs = kwargs.get('hidden_inputs', None)
|
||||
# Only collect metadata when calling the main function of nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
registry = MetadataRegistry()
|
||||
if prompt_id is not None:
|
||||
class_type = obj.__class__.__name__
|
||||
node_id = unique_id
|
||||
if node_id is not None:
|
||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
|
||||
# Call original function with all args/kwargs
|
||||
results = await original_map_node_over_list(
|
||||
prompt_id, unique_id, obj, input_data_all, func,
|
||||
allow_interrupt, execution_block_cb, pre_execute_cb, *args, **kwargs
|
||||
)
|
||||
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
registry = MetadataRegistry()
|
||||
if prompt_id is not None:
|
||||
class_type = obj.__class__.__name__
|
||||
node_id = unique_id
|
||||
if node_id is not None:
|
||||
registry.update_node_execution(node_id, class_type, results)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
# Also hook the execute function to track the current prompt_id
|
||||
original_execute = execution.execute
|
||||
|
||||
async def async_execute_with_prompt_tracking(*args, **kwargs):
|
||||
if len(args) >= 7: # Check if we have enough arguments
|
||||
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||
registry = MetadataRegistry()
|
||||
|
||||
# Start collection if this is a new prompt
|
||||
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||
registry.start_collection(prompt_id)
|
||||
|
||||
# Store the dynprompt reference for node lookups
|
||||
if hasattr(prompt, 'original_prompt'):
|
||||
registry.set_current_prompt(prompt)
|
||||
|
||||
# Execute the original function
|
||||
return await original_execute(*args, **kwargs)
|
||||
|
||||
# Replace the functions with async versions
|
||||
setattr(execution, map_node_func_name, async_map_node_over_list_with_metadata)
|
||||
execution.execute = async_execute_with_prompt_tracking
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from .constants import IMAGES
|
||||
|
||||
# Check if running in standalone mode
|
||||
standalone_mode = 'nodes' not in sys.modules
|
||||
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
|
||||
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IS_SAMPLER
|
||||
|
||||
@@ -18,6 +19,10 @@ class MetadataProcessor:
|
||||
- metadata: The workflow metadata
|
||||
- downstream_id: Optional ID of a downstream node to help identify the specific primary sampler
|
||||
"""
|
||||
if downstream_id is None:
|
||||
if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
|
||||
downstream_id = metadata[IMAGES]["first_decode"]["node_id"]
|
||||
|
||||
# If we have a downstream_id and execution_order, use it to narrow down potential samplers
|
||||
if downstream_id and "execution_order" in metadata:
|
||||
execution_order = metadata["execution_order"]
|
||||
@@ -209,6 +214,72 @@ class MetadataProcessor:
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def match_conditioning_to_prompts(metadata, sampler_id):
|
||||
"""
|
||||
Match conditioning objects from a sampler to prompts in metadata
|
||||
|
||||
Parameters:
|
||||
- metadata: The workflow metadata
|
||||
- sampler_id: ID of the sampler node to match
|
||||
|
||||
Returns:
|
||||
- Dictionary with 'prompt' and 'negative_prompt' if found
|
||||
"""
|
||||
result = {
|
||||
"prompt": "",
|
||||
"negative_prompt": ""
|
||||
}
|
||||
|
||||
# Check if we have stored conditioning objects for this sampler
|
||||
if sampler_id in metadata.get(PROMPTS, {}) and (
|
||||
"pos_conditioning" in metadata[PROMPTS][sampler_id] or
|
||||
"neg_conditioning" in metadata[PROMPTS][sampler_id]):
|
||||
|
||||
pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning")
|
||||
neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning")
|
||||
|
||||
# Helper function to recursively find prompt text for a conditioning object
|
||||
def find_prompt_text_for_conditioning(conditioning_obj, is_positive=True):
|
||||
if conditioning_obj is None:
|
||||
return ""
|
||||
|
||||
# Try to match conditioning objects with those stored by extractors
|
||||
for prompt_node_id, prompt_data in metadata[PROMPTS].items():
|
||||
# For nodes with single conditioning output
|
||||
if "conditioning" in prompt_data:
|
||||
if id(prompt_data["conditioning"]) == id(conditioning_obj):
|
||||
return prompt_data.get("text", "")
|
||||
|
||||
# For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader)
|
||||
if is_positive and "positive_encoded" in prompt_data:
|
||||
if id(prompt_data["positive_encoded"]) == id(conditioning_obj):
|
||||
if "positive_text" in prompt_data:
|
||||
return prompt_data["positive_text"]
|
||||
else:
|
||||
orig_conditioning = prompt_data.get("orig_pos_cond", None)
|
||||
if orig_conditioning is not None:
|
||||
# Recursively find the prompt text for the original conditioning
|
||||
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=True)
|
||||
|
||||
if not is_positive and "negative_encoded" in prompt_data:
|
||||
if id(prompt_data["negative_encoded"]) == id(conditioning_obj):
|
||||
if "negative_text" in prompt_data:
|
||||
return prompt_data["negative_text"]
|
||||
else:
|
||||
orig_conditioning = prompt_data.get("orig_neg_cond", None)
|
||||
if orig_conditioning is not None:
|
||||
# Recursively find the prompt text for the original conditioning
|
||||
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=False)
|
||||
|
||||
return ""
|
||||
|
||||
# Find prompt texts using the helper function
|
||||
result["prompt"] = find_prompt_text_for_conditioning(pos_conditioning, is_positive=True)
|
||||
result["negative_prompt"] = find_prompt_text_for_conditioning(neg_conditioning, is_positive=False)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def extract_generation_params(metadata, id=None):
|
||||
"""
|
||||
@@ -224,7 +295,7 @@ class MetadataProcessor:
|
||||
"seed": None,
|
||||
"steps": None,
|
||||
"cfg_scale": None,
|
||||
"guidance": None, # Add guidance parameter
|
||||
# "guidance": None, # Add guidance parameter
|
||||
"sampler": None,
|
||||
"scheduler": None,
|
||||
"checkpoint": None,
|
||||
@@ -261,7 +332,6 @@ class MetadataProcessor:
|
||||
params["sampler"] = sampling_params.get("sampler_name")
|
||||
params["scheduler"] = sampling_params.get("scheduler")
|
||||
|
||||
# Trace connections from the primary sampler
|
||||
if prompt and primary_sampler_id:
|
||||
# Check if this is a SamplerCustomAdvanced node
|
||||
is_custom_advanced = False
|
||||
@@ -269,69 +339,43 @@ class MetadataProcessor:
|
||||
is_custom_advanced = prompt.original_prompt[primary_sampler_id].get("class_type") == "SamplerCustomAdvanced"
|
||||
|
||||
if is_custom_advanced:
|
||||
# For SamplerCustomAdvanced, trace specific inputs
|
||||
|
||||
# 1. Trace sigmas input to find BasicScheduler
|
||||
scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", "BasicScheduler", max_depth=5)
|
||||
if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}):
|
||||
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
|
||||
params["steps"] = scheduler_params.get("steps")
|
||||
params["scheduler"] = scheduler_params.get("scheduler")
|
||||
|
||||
# 2. Trace sampler input to find KSamplerSelect
|
||||
sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
|
||||
if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}):
|
||||
sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
|
||||
params["sampler"] = sampler_params.get("sampler_name")
|
||||
|
||||
# 3. Trace guider input for CFGGuider and CLIPTextEncode
|
||||
guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
|
||||
if guider_node_id and guider_node_id in prompt.original_prompt:
|
||||
# Check if the guider node is a CFGGuider
|
||||
if prompt.original_prompt[guider_node_id].get("class_type") == "CFGGuider":
|
||||
# Extract cfg value from the CFGGuider
|
||||
if guider_node_id in metadata.get(SAMPLING, {}):
|
||||
cfg_params = metadata[SAMPLING][guider_node_id].get("parameters", {})
|
||||
params["cfg_scale"] = cfg_params.get("cfg")
|
||||
|
||||
# Find CLIPTextEncode for positive prompt
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", "CLIPTextEncode", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
|
||||
# Find CLIPTextEncode for negative prompt
|
||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", "CLIPTextEncode", max_depth=10)
|
||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||
else:
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
# For SamplerCustomAdvanced, use the new handler method
|
||||
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
|
||||
|
||||
else:
|
||||
# Original tracing for standard samplers
|
||||
# Trace positive prompt - look specifically for CLIPTextEncode
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
else:
|
||||
# If CLIPTextEncode is not found, try to find CLIPTextEncodeFlux
|
||||
positive_flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncodeFlux", max_depth=10)
|
||||
if positive_flux_node_id and positive_flux_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_flux_node_id].get("text", "")
|
||||
# For standard samplers, match conditioning objects to prompts
|
||||
prompt_results = MetadataProcessor.match_conditioning_to_prompts(metadata, primary_sampler_id)
|
||||
params["prompt"] = prompt_results["prompt"]
|
||||
params["negative_prompt"] = prompt_results["negative_prompt"]
|
||||
|
||||
# If prompts were still not found, fall back to tracing connections
|
||||
if not params["prompt"]:
|
||||
# Original tracing for standard samplers
|
||||
# Trace positive prompt - look specifically for CLIPTextEncode
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
else:
|
||||
# If CLIPTextEncode is not found, try to find CLIPTextEncodeFlux
|
||||
positive_flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncodeFlux", max_depth=10)
|
||||
if positive_flux_node_id and positive_flux_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_flux_node_id].get("text", "")
|
||||
|
||||
# Trace negative prompt - look specifically for CLIPTextEncode
|
||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", max_depth=10)
|
||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||
|
||||
# Trace negative prompt - look specifically for CLIPTextEncode
|
||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", max_depth=10)
|
||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||
# For SamplerCustom, handle any additional parameters
|
||||
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
|
||||
|
||||
# Size extraction is same for all sampler types
|
||||
# Check if the sampler itself has size information (from latent_image)
|
||||
if primary_sampler_id in metadata.get(SIZE, {}):
|
||||
width = metadata[SIZE][primary_sampler_id].get("width")
|
||||
height = metadata[SIZE][primary_sampler_id].get("height")
|
||||
if width and height:
|
||||
params["size"] = f"{width}x{height}"
|
||||
# Size extraction is same for all sampler types
|
||||
# Check if the sampler itself has size information (from latent_image)
|
||||
if primary_sampler_id in metadata.get(SIZE, {}):
|
||||
width = metadata[SIZE][primary_sampler_id].get("width")
|
||||
height = metadata[SIZE][primary_sampler_id].get("height")
|
||||
if width and height:
|
||||
params["size"] = f"{width}x{height}"
|
||||
|
||||
# Extract LoRAs using the standardized format
|
||||
lora_parts = []
|
||||
@@ -377,3 +421,59 @@ class MetadataProcessor:
|
||||
"""Convert metadata to JSON string"""
|
||||
params = MetadataProcessor.to_dict(metadata, id)
|
||||
return json.dumps(params, indent=4)
|
||||
|
||||
@staticmethod
|
||||
def handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params):
|
||||
"""
|
||||
Handle parameter extraction for SamplerCustomAdvanced nodes
|
||||
|
||||
Parameters:
|
||||
- metadata: The workflow metadata
|
||||
- prompt: The prompt object containing node connections
|
||||
- primary_sampler_id: ID of the SamplerCustomAdvanced node
|
||||
- params: Parameters dictionary to update
|
||||
"""
|
||||
if not prompt.original_prompt or primary_sampler_id not in prompt.original_prompt:
|
||||
return
|
||||
|
||||
sampler_inputs = prompt.original_prompt[primary_sampler_id].get("inputs", {})
|
||||
|
||||
# 1. Trace sigmas input to find BasicScheduler (only if sigmas input exists)
|
||||
if "sigmas" in sampler_inputs:
|
||||
scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", None, max_depth=5)
|
||||
if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}):
|
||||
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
|
||||
params["steps"] = scheduler_params.get("steps")
|
||||
params["scheduler"] = scheduler_params.get("scheduler")
|
||||
|
||||
# 2. Trace sampler input to find KSamplerSelect (only if sampler input exists)
|
||||
if "sampler" in sampler_inputs:
|
||||
sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
|
||||
if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}):
|
||||
sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
|
||||
params["sampler"] = sampler_params.get("sampler_name")
|
||||
|
||||
# 3. Trace guider input for CFGGuider and CLIPTextEncode
|
||||
if "guider" in sampler_inputs:
|
||||
guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
|
||||
if guider_node_id and guider_node_id in prompt.original_prompt:
|
||||
# Check if the guider node is a CFGGuider
|
||||
if prompt.original_prompt[guider_node_id].get("class_type") == "CFGGuider":
|
||||
# Extract cfg value from the CFGGuider
|
||||
if guider_node_id in metadata.get(SAMPLING, {}):
|
||||
cfg_params = metadata[SAMPLING][guider_node_id].get("parameters", {})
|
||||
params["cfg_scale"] = cfg_params.get("cfg")
|
||||
|
||||
# Find CLIPTextEncode for positive prompt
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", "CLIPTextEncode", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
|
||||
# Find CLIPTextEncode for negative prompt
|
||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", "CLIPTextEncode", max_depth=10)
|
||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||
else:
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
|
||||
@@ -35,7 +35,70 @@ class CheckpointLoaderExtractor(NodeMetadataExtractor):
|
||||
"type": "checkpoint",
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class TSCCheckpointLoaderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "ckpt_name" not in inputs:
|
||||
return
|
||||
|
||||
model_name = inputs.get("ckpt_name")
|
||||
if model_name:
|
||||
metadata[MODELS][node_id] = {
|
||||
"name": model_name,
|
||||
"type": "checkpoint",
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
# For loader node has lora_stack input, like Efficient Loader from Efficient Nodes
|
||||
active_loras = []
|
||||
|
||||
# Process lora_stack if available
|
||||
if "lora_stack" in inputs:
|
||||
lora_stack = inputs.get("lora_stack", [])
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Extract lora name from path (following the format in lora_loader.py)
|
||||
lora_name = os.path.splitext(os.path.basename(lora_path))[0]
|
||||
active_loras.append({
|
||||
"name": lora_name,
|
||||
"strength": model_strength
|
||||
})
|
||||
|
||||
if active_loras:
|
||||
metadata[LORAS][node_id] = {
|
||||
"lora_list": active_loras,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
# Extract positive and negative prompt text if available
|
||||
positive_text = inputs.get("positive", "")
|
||||
negative_text = inputs.get("negative", "")
|
||||
|
||||
if positive_text or negative_text:
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
# Store both positive and negative text
|
||||
metadata[PROMPTS][node_id]["positive_text"] = positive_text
|
||||
metadata[PROMPTS][node_id]["negative_text"] = negative_text
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
# Handle conditioning outputs from TSC_EfficientLoader
|
||||
# outputs is a list with [(model, positive_encoded, negative_encoded, {"samples":latent}, vae, clip, dependencies,)]
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, tuple) and len(first_output) >= 3:
|
||||
positive_conditioning = first_output[1]
|
||||
negative_conditioning = first_output[2]
|
||||
|
||||
# Save both conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["positive_encoded"] = positive_conditioning
|
||||
metadata[PROMPTS][node_id]["negative_encoded"] = negative_conditioning
|
||||
|
||||
class CLIPTextEncodeExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
@@ -47,53 +110,22 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor):
|
||||
"text": text,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class SamplerExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
sampling_params = {}
|
||||
for key in ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]:
|
||||
if key in inputs:
|
||||
sampling_params[key] = inputs[key]
|
||||
|
||||
metadata[SAMPLING][node_id] = {
|
||||
"parameters": sampling_params,
|
||||
"node_id": node_id,
|
||||
IS_SAMPLER: True # Add sampler flag
|
||||
}
|
||||
|
||||
# Extract latent image dimensions if available
|
||||
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
||||
latent = inputs["latent_image"]
|
||||
if isinstance(latent, dict) and "samples" in latent:
|
||||
# Extract dimensions from latent tensor
|
||||
samples = latent["samples"]
|
||||
if hasattr(samples, "shape") and len(samples.shape) >= 3:
|
||||
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
|
||||
# Multiply by 8 to get actual pixel dimensions
|
||||
height = int(samples.shape[2] * 8)
|
||||
width = int(samples.shape[3] * 8)
|
||||
|
||||
if SIZE not in metadata:
|
||||
metadata[SIZE] = {}
|
||||
|
||||
metadata[SIZE][node_id] = {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class KSamplerAdvancedExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
def update(node_id, outputs, metadata):
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
if isinstance(outputs[0], tuple) and len(outputs[0]) > 0:
|
||||
conditioning = outputs[0][0]
|
||||
metadata[PROMPTS][node_id]["conditioning"] = conditioning
|
||||
|
||||
# Base Sampler Extractor to reduce code redundancy
|
||||
class BaseSamplerExtractor(NodeMetadataExtractor):
|
||||
"""Base extractor for sampler nodes with common functionality"""
|
||||
@staticmethod
|
||||
def extract_sampling_params(node_id, inputs, metadata, param_keys):
|
||||
"""Extract sampling parameters from inputs"""
|
||||
sampling_params = {}
|
||||
for key in ["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]:
|
||||
for key in param_keys:
|
||||
if key in inputs:
|
||||
sampling_params[key] = inputs[key]
|
||||
|
||||
@@ -102,7 +134,25 @@ class KSamplerAdvancedExtractor(NodeMetadataExtractor):
|
||||
"node_id": node_id,
|
||||
IS_SAMPLER: True # Add sampler flag
|
||||
}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def extract_conditioning(node_id, inputs, metadata):
|
||||
"""Extract conditioning objects from inputs"""
|
||||
# Store the conditioning objects directly in metadata for later matching
|
||||
pos_conditioning = inputs.get("positive", None)
|
||||
neg_conditioning = inputs.get("negative", None)
|
||||
|
||||
# Save conditioning objects in metadata for later matching
|
||||
if pos_conditioning is not None or neg_conditioning is not None:
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
|
||||
@staticmethod
|
||||
def extract_latent_dimensions(node_id, inputs, metadata):
|
||||
"""Extract dimensions from latent image"""
|
||||
# Extract latent image dimensions if available
|
||||
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
||||
latent = inputs["latent_image"]
|
||||
@@ -123,6 +173,176 @@ class KSamplerAdvancedExtractor(NodeMetadataExtractor):
|
||||
"height": height,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class SamplerExtractor(BaseSamplerExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]
|
||||
)
|
||||
|
||||
# Extract conditioning objects
|
||||
BaseSamplerExtractor.extract_conditioning(node_id, inputs, metadata)
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class KSamplerAdvancedExtractor(BaseSamplerExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]
|
||||
)
|
||||
|
||||
# Extract conditioning objects
|
||||
BaseSamplerExtractor.extract_conditioning(node_id, inputs, metadata)
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class KSamplerBasicPipeExtractor(BaseSamplerExtractor):
|
||||
"""Extractor for KSamplerBasicPipe and KSampler_inspire_pipe nodes"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]
|
||||
)
|
||||
|
||||
# Extract conditioning objects from basic_pipe
|
||||
if "basic_pipe" in inputs and inputs["basic_pipe"] is not None:
|
||||
basic_pipe = inputs["basic_pipe"]
|
||||
# Typically, basic_pipe structure is (model, clip, vae, positive, negative)
|
||||
if isinstance(basic_pipe, tuple) and len(basic_pipe) >= 5:
|
||||
pos_conditioning = basic_pipe[3] # positive is at index 3
|
||||
neg_conditioning = basic_pipe[4] # negative is at index 4
|
||||
|
||||
# Save conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class KSamplerAdvancedBasicPipeExtractor(BaseSamplerExtractor):
|
||||
"""Extractor for KSamplerAdvancedBasicPipe nodes"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]
|
||||
)
|
||||
|
||||
# Extract conditioning objects from basic_pipe
|
||||
if "basic_pipe" in inputs and inputs["basic_pipe"] is not None:
|
||||
basic_pipe = inputs["basic_pipe"]
|
||||
# Typically, basic_pipe structure is (model, clip, vae, positive, negative)
|
||||
if isinstance(basic_pipe, tuple) and len(basic_pipe) >= 5:
|
||||
pos_conditioning = basic_pipe[3] # positive is at index 3
|
||||
neg_conditioning = basic_pipe[4] # negative is at index 4
|
||||
|
||||
# Save conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class TSCSamplerBaseExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
# Store vae_decode setting for later use in update
|
||||
if inputs and "vae_decode" in inputs:
|
||||
if SAMPLING not in metadata:
|
||||
metadata[SAMPLING] = {}
|
||||
|
||||
if node_id not in metadata[SAMPLING]:
|
||||
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
|
||||
|
||||
# Store the vae_decode setting
|
||||
metadata[SAMPLING][node_id]["vae_decode"] = inputs["vae_decode"]
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
# Check if vae_decode was set to "true"
|
||||
should_save_image = True
|
||||
if SAMPLING in metadata and node_id in metadata[SAMPLING]:
|
||||
vae_decode = metadata[SAMPLING][node_id].get("vae_decode")
|
||||
if vae_decode is not None:
|
||||
should_save_image = (vae_decode == "true")
|
||||
|
||||
# Skip image saving if vae_decode isn't "true"
|
||||
if not should_save_image:
|
||||
return
|
||||
|
||||
# Ensure IMAGES category exists
|
||||
if IMAGES not in metadata:
|
||||
metadata[IMAGES] = {}
|
||||
|
||||
# Extract output_images from the TSC sampler format
|
||||
# outputs = [{"ui": {"images": preview_images}, "result": result}]
|
||||
# where result = (original_model, original_positive, original_negative, latent_list, optional_vae, output_images,)
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
# Get the first item in the list
|
||||
output_item = outputs[0]
|
||||
if isinstance(output_item, dict) and "result" in output_item:
|
||||
result = output_item["result"]
|
||||
if isinstance(result, tuple) and len(result) >= 6:
|
||||
# The output_images is the last element in the result tuple
|
||||
output_images = (result[5],)
|
||||
|
||||
# Save image data under node ID index to be captured by caching mechanism
|
||||
metadata[IMAGES][node_id] = {
|
||||
"node_id": node_id,
|
||||
"image": output_images
|
||||
}
|
||||
|
||||
# Only set first_decode if it hasn't been recorded yet
|
||||
if "first_decode" not in metadata[IMAGES]:
|
||||
metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id]
|
||||
|
||||
class TSCKSamplerExtractor(SamplerExtractor, TSCSamplerBaseExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
# Call parent extract methods
|
||||
SamplerExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
TSCSamplerBaseExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
|
||||
# Update method is inherited from TSCSamplerBaseExtractor
|
||||
|
||||
|
||||
class TSCKSamplerAdvancedExtractor(KSamplerAdvancedExtractor, TSCSamplerBaseExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
# Call parent extract methods
|
||||
KSamplerAdvancedExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
TSCSamplerBaseExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
|
||||
# Update method is inherited from TSCSamplerBaseExtractor
|
||||
|
||||
class LoraLoaderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
@@ -292,7 +512,7 @@ class BasicSchedulerExtractor(NodeMetadataExtractor):
|
||||
IS_SAMPLER: False # Mark as non-primary sampler
|
||||
}
|
||||
|
||||
class SamplerCustomAdvancedExtractor(NodeMetadataExtractor):
|
||||
class SamplerCustomAdvancedExtractor(BaseSamplerExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
@@ -311,26 +531,8 @@ class SamplerCustomAdvancedExtractor(NodeMetadataExtractor):
|
||||
IS_SAMPLER: True # Add sampler flag
|
||||
}
|
||||
|
||||
# Extract latent image dimensions if available
|
||||
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
||||
latent = inputs["latent_image"]
|
||||
if isinstance(latent, dict) and "samples" in latent:
|
||||
# Extract dimensions from latent tensor
|
||||
samples = latent["samples"]
|
||||
if hasattr(samples, "shape") and len(samples.shape) >= 3:
|
||||
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
|
||||
# Multiply by 8 to get actual pixel dimensions
|
||||
height = int(samples.shape[2] * 8)
|
||||
width = int(samples.shape[3] * 8)
|
||||
|
||||
if SIZE not in metadata:
|
||||
metadata[SIZE] = {}
|
||||
|
||||
metadata[SIZE][node_id] = {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"node_id": node_id
|
||||
}
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
import json
|
||||
|
||||
@@ -376,6 +578,13 @@ class CLIPTextEncodeFluxExtractor(NodeMetadataExtractor):
|
||||
|
||||
metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
if isinstance(outputs[0], tuple) and len(outputs[0]) > 0:
|
||||
conditioning = outputs[0][0]
|
||||
metadata[PROMPTS][node_id]["conditioning"] = conditioning
|
||||
|
||||
class CFGGuiderExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
@@ -393,17 +602,64 @@ class CFGGuiderExtractor(NodeMetadataExtractor):
|
||||
|
||||
metadata[SAMPLING][node_id]["parameters"]["cfg"] = cfg_value
|
||||
|
||||
class CR_ApplyControlNetStackExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Save the original conditioning inputs
|
||||
base_positive = inputs.get("base_positive")
|
||||
base_negative = inputs.get("base_negative")
|
||||
|
||||
if base_positive is not None or base_negative is not None:
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["orig_pos_cond"] = base_positive
|
||||
metadata[PROMPTS][node_id]["orig_neg_cond"] = base_negative
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
# Extract transformed conditionings from outputs
|
||||
# outputs structure: [(base_positive, base_negative, show_help, )]
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, tuple) and len(first_output) >= 2:
|
||||
transformed_positive = first_output[0]
|
||||
transformed_negative = first_output[1]
|
||||
|
||||
# Save transformed conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["positive_encoded"] = transformed_positive
|
||||
metadata[PROMPTS][node_id]["negative_encoded"] = transformed_negative
|
||||
|
||||
# Registry of node-specific extractors
|
||||
# Keys are node class names
|
||||
NODE_EXTRACTORS = {
|
||||
# Sampling
|
||||
"KSampler": SamplerExtractor,
|
||||
"KSamplerAdvanced": KSamplerAdvancedExtractor,
|
||||
"SamplerCustomAdvanced": SamplerCustomAdvancedExtractor, # Updated to use dedicated extractor
|
||||
"SamplerCustom": KSamplerAdvancedExtractor,
|
||||
"SamplerCustomAdvanced": SamplerCustomAdvancedExtractor,
|
||||
"ClownsharKSampler_Beta": SamplerExtractor,
|
||||
"TSC_KSampler": TSCKSamplerExtractor, # Efficient Nodes
|
||||
"TSC_KSamplerAdvanced": TSCKSamplerAdvancedExtractor, # Efficient Nodes
|
||||
"KSamplerBasicPipe": KSamplerBasicPipeExtractor, # comfyui-impact-pack
|
||||
"KSamplerAdvancedBasicPipe": KSamplerAdvancedBasicPipeExtractor, # comfyui-impact-pack
|
||||
"KSampler_inspire_pipe": KSamplerBasicPipeExtractor, # comfyui-inspire-pack
|
||||
"KSamplerAdvanced_inspire_pipe": KSamplerAdvancedBasicPipeExtractor, # comfyui-inspire-pack
|
||||
# Sampling Selectors
|
||||
"KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect
|
||||
"BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler
|
||||
"AlignYourStepsScheduler": BasicSchedulerExtractor, # Add AlignYourStepsScheduler
|
||||
# Loaders
|
||||
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
|
||||
"comfyLoader": CheckpointLoaderExtractor, # easy comfyLoader
|
||||
"CheckpointLoaderSimpleWithImages": CheckpointLoaderExtractor, # CheckpointLoader|pysssss
|
||||
"TSC_EfficientLoader": TSCCheckpointLoaderExtractor, # Efficient Nodes
|
||||
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
"LoraLoader": LoraLoaderExtractor,
|
||||
@@ -412,6 +668,10 @@ NODE_EXTRACTORS = {
|
||||
"CLIPTextEncode": CLIPTextEncodeExtractor,
|
||||
"CLIPTextEncodeFlux": CLIPTextEncodeFluxExtractor, # Add CLIPTextEncodeFlux
|
||||
"WAS_Text_to_Conditioning": CLIPTextEncodeExtractor,
|
||||
"AdvancedCLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/BlenderNeko/ComfyUI_ADV_CLIP_emb
|
||||
"smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes
|
||||
"CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack
|
||||
"PCTextEncode": CLIPTextEncodeExtractor, # From https://github.com/asagi4/comfyui-prompt-control
|
||||
# Latent
|
||||
"EmptyLatentImage": ImageSizeExtractor,
|
||||
# Flux
|
||||
|
||||
1
py/middleware/__init__.py
Normal file
1
py/middleware/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Server middleware modules"""
|
||||
53
py/middleware/cache_middleware.py
Normal file
53
py/middleware/cache_middleware.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Cache control middleware for ComfyUI server"""
|
||||
|
||||
from aiohttp import web
|
||||
from typing import Callable, Awaitable
|
||||
|
||||
# Time in seconds
|
||||
ONE_HOUR: int = 3600
|
||||
ONE_DAY: int = 86400
|
||||
IMG_EXTENSIONS = (
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".ppm",
|
||||
".bmp",
|
||||
".pgm",
|
||||
".tif",
|
||||
".tiff",
|
||||
".webp",
|
||||
".mp4"
|
||||
)
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def cache_control(
|
||||
request: web.Request, handler: Callable[[web.Request], Awaitable[web.Response]]
|
||||
) -> web.Response:
|
||||
"""Cache control middleware that sets appropriate cache headers based on file type and response status"""
|
||||
response: web.Response = await handler(request)
|
||||
|
||||
if (
|
||||
request.path.endswith(".js")
|
||||
or request.path.endswith(".css")
|
||||
or request.path.endswith("index.json")
|
||||
):
|
||||
response.headers.setdefault("Cache-Control", "no-cache")
|
||||
return response
|
||||
|
||||
# Early return for non-image files - no cache headers needed
|
||||
if not request.path.lower().endswith(IMG_EXTENSIONS):
|
||||
return response
|
||||
|
||||
# Handle image files
|
||||
if response.status == 404:
|
||||
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_HOUR}")
|
||||
elif response.status in (200, 201, 202, 203, 204, 205, 206, 301, 308):
|
||||
# Success responses and permanent redirects - cache for 1 day
|
||||
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_DAY}")
|
||||
elif response.status in (302, 303, 307):
|
||||
# Temporary redirects - no cache
|
||||
response.headers.setdefault("Cache-Control", "no-cache")
|
||||
# Note: 304 Not Modified falls through - no cache headers set
|
||||
|
||||
return response
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from server import PromptServer # type: ignore
|
||||
from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -7,6 +8,7 @@ class DebugMetadata:
|
||||
NAME = "Debug Metadata (LoraManager)"
|
||||
CATEGORY = "Lora Manager/utils"
|
||||
DESCRIPTION = "Debug node to verify metadata_processor functionality"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
@@ -19,8 +21,7 @@ class DebugMetadata:
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
RETURN_NAMES = ("metadata_json",)
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "process_metadata"
|
||||
|
||||
def process_metadata(self, images, id):
|
||||
@@ -32,7 +33,13 @@ class DebugMetadata:
|
||||
# Use the MetadataProcessor to convert it to JSON string
|
||||
metadata_json = MetadataProcessor.to_json(metadata, id)
|
||||
|
||||
return (metadata_json,)
|
||||
# Send metadata to frontend for display
|
||||
PromptServer.instance.send_sync("metadata_update", {
|
||||
"id": id,
|
||||
"metadata": metadata_json
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing metadata: {e}")
|
||||
return ("{}",) # Return empty JSON object in case of error
|
||||
|
||||
return ()
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import logging
|
||||
import re
|
||||
from nodes import LoraLoader
|
||||
from comfy.comfy_types import IO # type: ignore
|
||||
import asyncio
|
||||
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraManagerLoader:
|
||||
NAME = "Lora Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
@@ -17,7 +18,8 @@ class LoraManagerLoader:
|
||||
"model": ("MODEL",),
|
||||
# "clip": ("CLIP",),
|
||||
"text": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"multiline": True,
|
||||
"pysssss.autocomplete": False,
|
||||
"dynamicPrompts": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
"placeholder": "LoRA syntax input: <lora:name:strength>"
|
||||
@@ -37,19 +39,39 @@ class LoraManagerLoader:
|
||||
|
||||
clip = kwargs.get('clip', None)
|
||||
lora_stack = kwargs.get('lora_stack', None)
|
||||
|
||||
# Check if model is a Nunchaku Flux model - simplified approach
|
||||
is_nunchaku_model = False
|
||||
|
||||
try:
|
||||
model_wrapper = model.model.diffusion_model
|
||||
# Check if model is a Nunchaku Flux model using only class name
|
||||
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
||||
is_nunchaku_model = True
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
except (AttributeError, TypeError):
|
||||
# Not a model with the expected structure
|
||||
pass
|
||||
|
||||
# First process lora_stack if available
|
||||
if lora_stack:
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Apply the LoRA using the provided path and strengths
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# Use our custom function for Flux models
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged for Nunchaku models
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Extract lora name for trigger words lookup
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
_, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
all_trigger_words.extend(trigger_words)
|
||||
# Add clip strength to output if different from model strength
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
@@ -66,13 +88,19 @@ class LoraManagerLoader:
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Apply the LoRA using the resolved path with separate strengths
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# For Nunchaku models, use our custom function
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Include clip strength in output if different from model strength
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
# Include clip strength in output if different from model strength and not a Nunchaku model
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
@@ -83,6 +111,144 @@ class LoraManagerLoader:
|
||||
# use ',, ' to separate trigger words for group mode
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Format loaded_loras with support for both formats
|
||||
formatted_loras = []
|
||||
for item in loaded_loras:
|
||||
parts = item.split(":")
|
||||
lora_name = parts[0]
|
||||
strength_parts = parts[1].strip().split(",")
|
||||
|
||||
if len(strength_parts) > 1:
|
||||
# Different model and clip strengths
|
||||
model_str = strength_parts[0].strip()
|
||||
clip_str = strength_parts[1].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}:{clip_str}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
model_str = strength_parts[0].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}>")
|
||||
|
||||
formatted_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (model, clip, trigger_words_text, formatted_loras_text)
|
||||
|
||||
class LoraManagerTextLoader:
|
||||
NAME = "LoRA Text Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"lora_syntax": (IO.STRING, {
|
||||
"defaultInput": True,
|
||||
"forceInput": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation"
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"clip": ("CLIP",),
|
||||
"lora_stack": ("LORA_STACK",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING)
|
||||
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
|
||||
FUNCTION = "load_loras_from_text"
|
||||
|
||||
def parse_lora_syntax(self, text):
|
||||
"""Parse LoRA syntax from text input."""
|
||||
# Pattern to match <lora:name:strength> or <lora:name:model_strength:clip_strength>
|
||||
pattern = r'<lora:([^:>]+):([^:>]+)(?::([^:>]+))?>'
|
||||
matches = re.findall(pattern, text, re.IGNORECASE)
|
||||
|
||||
loras = []
|
||||
for match in matches:
|
||||
lora_name = match[0]
|
||||
model_strength = float(match[1])
|
||||
clip_strength = float(match[2]) if match[2] else model_strength
|
||||
|
||||
loras.append({
|
||||
'name': lora_name,
|
||||
'model_strength': model_strength,
|
||||
'clip_strength': clip_strength
|
||||
})
|
||||
|
||||
return loras
|
||||
|
||||
def load_loras_from_text(self, model, lora_syntax, clip=None, lora_stack=None):
|
||||
"""Load LoRAs based on text syntax input."""
|
||||
loaded_loras = []
|
||||
all_trigger_words = []
|
||||
|
||||
# Check if model is a Nunchaku Flux model - simplified approach
|
||||
is_nunchaku_model = False
|
||||
|
||||
try:
|
||||
model_wrapper = model.model.diffusion_model
|
||||
# Check if model is a Nunchaku Flux model using only class name
|
||||
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
||||
is_nunchaku_model = True
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
except (AttributeError, TypeError):
|
||||
# Not a model with the expected structure
|
||||
pass
|
||||
|
||||
# First process lora_stack if available
|
||||
if lora_stack:
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# Use our custom function for Flux models
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged for Nunchaku models
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Extract lora name for trigger words lookup
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
all_trigger_words.extend(trigger_words)
|
||||
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Parse and process LoRAs from text syntax
|
||||
parsed_loras = self.parse_lora_syntax(lora_syntax)
|
||||
for lora in parsed_loras:
|
||||
lora_name = lora['name']
|
||||
model_strength = lora['model_strength']
|
||||
clip_strength = lora['clip_strength']
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# For Nunchaku models, use our custom function
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Include clip strength in output if different from model strength and not a Nunchaku model
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# use ',, ' to separate trigger words for group mode
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Format loaded_loras with support for both formats
|
||||
formatted_loras = []
|
||||
for item in loaded_loras:
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from comfy.comfy_types import IO # type: ignore
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
import asyncio
|
||||
import os
|
||||
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -18,6 +17,7 @@ class LoraStacker:
|
||||
"required": {
|
||||
"text": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"pysssss.autocomplete": False,
|
||||
"dynamicPrompts": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
"placeholder": "LoRA syntax input: <lora:name:strength>"
|
||||
@@ -43,7 +43,7 @@ class LoraStacker:
|
||||
# Get trigger words from existing stack entries
|
||||
for lora_path, _, _ in lora_stack:
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
_, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Process loras from kwargs with support for both old and new formats
|
||||
@@ -58,7 +58,7 @@ class LoraStacker:
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Add to stack without loading
|
||||
# replace '/' with os.sep to avoid different OS path format
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
import re
|
||||
import numpy as np
|
||||
import folder_paths # type: ignore
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..services.checkpoint_scanner import CheckpointScanner
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||
from ..metadata_collector import get_metadata
|
||||
from PIL import Image, PngImagePlugin
|
||||
@@ -71,25 +69,20 @@ class SaveImage:
|
||||
FUNCTION = "process_image"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
async def get_lora_hash(self, lora_name):
|
||||
def get_lora_hash(self, lora_name):
|
||||
"""Get the lora hash from cache"""
|
||||
scanner = await LoraScanner.get_instance()
|
||||
scanner = ServiceRegistry.get_service_sync("lora_scanner")
|
||||
|
||||
# Use the new direct filename lookup method
|
||||
hash_value = scanner.get_hash_by_filename(lora_name)
|
||||
if hash_value:
|
||||
return hash_value
|
||||
|
||||
# Fallback to old method for compatibility
|
||||
cache = await scanner.get_cached_data()
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
return item.get('sha256')
|
||||
return None
|
||||
|
||||
async def get_checkpoint_hash(self, checkpoint_path):
|
||||
def get_checkpoint_hash(self, checkpoint_path):
|
||||
"""Get the checkpoint hash from cache"""
|
||||
scanner = await CheckpointScanner.get_instance()
|
||||
scanner = ServiceRegistry.get_service_sync("checkpoint_scanner")
|
||||
|
||||
if not checkpoint_path:
|
||||
return None
|
||||
@@ -102,18 +95,10 @@ class SaveImage:
|
||||
hash_value = scanner.get_hash_by_filename(checkpoint_name)
|
||||
if hash_value:
|
||||
return hash_value
|
||||
|
||||
# Fallback to old method for compatibility
|
||||
cache = await scanner.get_cached_data()
|
||||
normalized_path = checkpoint_path.replace('\\', '/')
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == checkpoint_name and item.get('file_path').endswith(normalized_path):
|
||||
return item.get('sha256')
|
||||
|
||||
return None
|
||||
|
||||
async def format_metadata(self, metadata_dict):
|
||||
def format_metadata(self, metadata_dict):
|
||||
"""Format metadata in the requested format similar to userComment example"""
|
||||
if not metadata_dict:
|
||||
return ""
|
||||
@@ -140,7 +125,7 @@ class SaveImage:
|
||||
|
||||
# Get hash for each lora
|
||||
for lora_name, strength in lora_matches:
|
||||
hash_value = await self.get_lora_hash(lora_name)
|
||||
hash_value = self.get_lora_hash(lora_name)
|
||||
if hash_value:
|
||||
lora_hashes[lora_name] = hash_value
|
||||
else:
|
||||
@@ -226,7 +211,7 @@ class SaveImage:
|
||||
checkpoint = metadata_dict.get('checkpoint')
|
||||
if checkpoint is not None:
|
||||
# Get model hash
|
||||
model_hash = await self.get_checkpoint_hash(checkpoint)
|
||||
model_hash = self.get_checkpoint_hash(checkpoint)
|
||||
|
||||
# Extract basename without path
|
||||
checkpoint_name = os.path.basename(checkpoint)
|
||||
@@ -329,8 +314,7 @@ class SaveImage:
|
||||
raw_metadata = get_metadata()
|
||||
metadata_dict = MetadataProcessor.to_dict(raw_metadata, id)
|
||||
|
||||
# Get or create metadata asynchronously
|
||||
metadata = asyncio.run(self.format_metadata(metadata_dict))
|
||||
metadata = self.format_metadata(metadata_dict)
|
||||
|
||||
# Process filename_prefix with pattern substitution
|
||||
filename_prefix = self.format_filename(filename_prefix, metadata_dict)
|
||||
@@ -434,11 +418,15 @@ class SaveImage:
|
||||
# Make sure the output directory exists
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
|
||||
# Ensure images is always a list of images
|
||||
if len(images.shape) == 3: # Single image (height, width, channels)
|
||||
images = [images]
|
||||
else: # Multiple images (batch, height, width, channels)
|
||||
images = [img for img in images]
|
||||
# If images is already a list or array of images, do nothing; otherwise, convert to list
|
||||
if isinstance(images, (list, np.ndarray)):
|
||||
pass
|
||||
else:
|
||||
# Ensure images is always a list of images
|
||||
if len(images.shape) == 3: # Single image (height, width, channels)
|
||||
images = [images]
|
||||
else: # Multiple images (batch, height, width, channels)
|
||||
images = [img for img in images]
|
||||
|
||||
# Save all images
|
||||
results = self.save_images(
|
||||
|
||||
@@ -50,15 +50,9 @@ class TriggerWordToggle:
|
||||
|
||||
def process_trigger_words(self, id, group_mode, default_active, **kwargs):
|
||||
# Handle both old and new formats for trigger_words
|
||||
trigger_words_data = self._get_toggle_data(kwargs, 'trigger_words')
|
||||
trigger_words_data = self._get_toggle_data(kwargs, 'orinalMessage')
|
||||
trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else ""
|
||||
|
||||
# Send trigger words to frontend
|
||||
# PromptServer.instance.send_sync("trigger_word_update", {
|
||||
# "id": id,
|
||||
# "message": trigger_words
|
||||
# })
|
||||
|
||||
filtered_triggers = trigger_words
|
||||
|
||||
# Get toggle data with support for both formats
|
||||
|
||||
@@ -35,31 +35,11 @@ any_type = AnyType("*")
|
||||
# Common methods extracted from lora_loader.py and lora_stacker.py
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
import copy
|
||||
import folder_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def get_lora_info(lora_name):
|
||||
"""Get the lora path and trigger words from cache"""
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
file_path = item.get('file_path')
|
||||
if file_path:
|
||||
for root in config.loras_roots:
|
||||
root = root.replace(os.sep, '/')
|
||||
if file_path.startswith(root):
|
||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||
# Get trigger words from civitai metadata
|
||||
civitai = item.get('civitai', {})
|
||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||
return relative_path, trigger_words
|
||||
return lora_name, [] # Fallback if not found
|
||||
|
||||
def extract_lora_name(lora_path):
|
||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||
# Get the basename without extension
|
||||
@@ -81,4 +61,70 @@ def get_loras_list(kwargs):
|
||||
# Unexpected format
|
||||
else:
|
||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||
return []
|
||||
return []
|
||||
|
||||
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
|
||||
"""Simplified version of load_state_dict_in_safetensors that just loads from a local path"""
|
||||
import safetensors.torch
|
||||
|
||||
state_dict = {}
|
||||
with safetensors.torch.safe_open(path, framework="pt", device=device) as f:
|
||||
for k in f.keys():
|
||||
if filter_prefix and not k.startswith(filter_prefix):
|
||||
continue
|
||||
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
|
||||
return state_dict
|
||||
|
||||
def to_diffusers(input_lora):
|
||||
"""Simplified version of to_diffusers for Flux LoRA conversion"""
|
||||
import torch
|
||||
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
|
||||
from diffusers.loaders import FluxLoraLoaderMixin
|
||||
|
||||
if isinstance(input_lora, str):
|
||||
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
|
||||
else:
|
||||
tensors = {k: v for k, v in input_lora.items()}
|
||||
|
||||
# Convert FP8 tensors to BF16
|
||||
for k, v in tensors.items():
|
||||
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
|
||||
tensors[k] = v.to(torch.bfloat16)
|
||||
|
||||
new_tensors = FluxLoraLoaderMixin.lora_state_dict(tensors)
|
||||
new_tensors = convert_unet_state_dict_to_peft(new_tensors)
|
||||
|
||||
return new_tensors
|
||||
|
||||
def nunchaku_load_lora(model, lora_name, lora_strength):
|
||||
"""Load a Flux LoRA for Nunchaku model"""
|
||||
model_wrapper = model.model.diffusion_model
|
||||
transformer = model_wrapper.model
|
||||
|
||||
# Save the transformer temporarily
|
||||
model_wrapper.model = None
|
||||
ret_model = copy.deepcopy(model) # copy everything except the model
|
||||
ret_model_wrapper = ret_model.model.diffusion_model
|
||||
|
||||
# Restore the model and set it for the copy
|
||||
model_wrapper.model = transformer
|
||||
ret_model_wrapper.model = transformer
|
||||
|
||||
# Get full path to the LoRA file
|
||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||
ret_model_wrapper.loras.append((lora_path, lora_strength))
|
||||
|
||||
# Convert the LoRA to diffusers format
|
||||
sd = to_diffusers(lora_path)
|
||||
|
||||
# Handle embedding adjustment if needed
|
||||
if "transformer.x_embedder.lora_A.weight" in sd:
|
||||
new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1]
|
||||
assert new_in_channels % 4 == 0
|
||||
new_in_channels = new_in_channels // 4
|
||||
|
||||
old_in_channels = ret_model.model.model_config.unet_config["in_channels"]
|
||||
if old_in_channels < new_in_channels:
|
||||
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
|
||||
|
||||
return ret_model
|
||||
98
py/nodes/wanvideo_lora_select.py
Normal file
98
py/nodes/wanvideo_lora_select.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from comfy.comfy_types import IO # type: ignore
|
||||
import folder_paths # type: ignore
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class WanVideoLoraSelect:
|
||||
NAME = "WanVideo Lora Select (LoraManager)"
|
||||
CATEGORY = "Lora Manager/stackers"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load LORA models with less VRAM usage, slower loading. This affects ALL LoRAs, not just the current ones. No effect if merge_loras is False"}),
|
||||
"merge_loras": ("BOOLEAN", {"default": True, "tooltip": "Merge LoRAs into the model, otherwise they are loaded on the fly. Always disabled for GGUF and scaled fp8 models. This affects ALL LoRAs, not just the current one"}),
|
||||
"text": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"pysssss.autocomplete": False,
|
||||
"dynamicPrompts": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
"placeholder": "LoRA syntax input: <lora:name:strength>"
|
||||
}),
|
||||
},
|
||||
"optional": FlexibleOptionalInputType(any_type),
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("WANVIDLORA", IO.STRING, IO.STRING)
|
||||
RETURN_NAMES = ("lora", "trigger_words", "active_loras")
|
||||
FUNCTION = "process_loras"
|
||||
|
||||
def process_loras(self, text, low_mem_load=False, merge_loras=True, **kwargs):
|
||||
loras_list = []
|
||||
all_trigger_words = []
|
||||
active_loras = []
|
||||
|
||||
# Process existing prev_lora if available
|
||||
prev_lora = kwargs.get('prev_lora', None)
|
||||
if prev_lora is not None:
|
||||
loras_list.extend(prev_lora)
|
||||
|
||||
if not merge_loras:
|
||||
low_mem_load = False # Unmerged LoRAs don't need low_mem_load
|
||||
|
||||
# Get blocks if available
|
||||
blocks = kwargs.get('blocks', {})
|
||||
selected_blocks = blocks.get("selected_blocks", {})
|
||||
layer_filter = blocks.get("layer_filter", "")
|
||||
|
||||
# Process loras from kwargs with support for both old and new formats
|
||||
loras_from_widget = get_loras_list(kwargs)
|
||||
for lora in loras_from_widget:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
|
||||
lora_name = lora['name']
|
||||
model_strength = float(lora['strength'])
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Create lora item for WanVideo format
|
||||
lora_item = {
|
||||
"path": folder_paths.get_full_path("loras", lora_path),
|
||||
"strength": model_strength,
|
||||
"name": lora_path.split(".")[0],
|
||||
"blocks": selected_blocks,
|
||||
"layer_filter": layer_filter,
|
||||
"low_mem_load": low_mem_load,
|
||||
"merge_loras": merge_loras,
|
||||
}
|
||||
|
||||
# Add to list and collect active loras
|
||||
loras_list.append(lora_item)
|
||||
active_loras.append((lora_name, model_strength, clip_strength))
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Format trigger_words for output
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Format active_loras for output
|
||||
formatted_loras = []
|
||||
for name, model_strength, clip_strength in active_loras:
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
# Different model and clip strengths
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")
|
||||
|
||||
active_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (loras_list, trigger_words_text, active_loras_text)
|
||||
127
py/nodes/wanvideo_lora_select_from_text.py
Normal file
127
py/nodes/wanvideo_lora_select_from_text.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from comfy.comfy_types import IO
|
||||
import folder_paths
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import any_type
|
||||
import logging
|
||||
|
||||
# 初始化日志记录器
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 定义新节点的类
|
||||
class WanVideoLoraSelectFromText:
|
||||
# 节点在UI中显示的名称
|
||||
NAME = "WanVideo Lora Select From Text (LoraManager)"
|
||||
# 节点所属的分类
|
||||
CATEGORY = "Lora Manager/stackers"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load LORA models with less VRAM usage, slower loading. This affects ALL LoRAs, not just the current ones. No effect if merge_loras is False"}),
|
||||
"merge_lora": ("BOOLEAN", {"default": True, "tooltip": "Merge LoRAs into the model, otherwise they are loaded on the fly. Always disabled for GGUF and scaled fp8 models. This affects ALL LoRAs, not just the current one"}),
|
||||
"lora_syntax": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"defaultInput": True,
|
||||
"forceInput": True,
|
||||
"tooltip": "Connect a TEXT output for LoRA syntax: <lora:name:strength>"
|
||||
}),
|
||||
},
|
||||
|
||||
"optional": {
|
||||
"prev_lora": ("WANVIDLORA",),
|
||||
"blocks": ("BLOCKS",)
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("WANVIDLORA", IO.STRING, IO.STRING)
|
||||
RETURN_NAMES = ("lora", "trigger_words", "active_loras")
|
||||
|
||||
FUNCTION = "process_loras_from_syntax"
|
||||
|
||||
def process_loras_from_syntax(self, lora_syntax, low_mem_load=False, merge_lora=True, **kwargs):
|
||||
text_to_process = lora_syntax
|
||||
|
||||
blocks = kwargs.get('blocks', {})
|
||||
selected_blocks = blocks.get("selected_blocks", {})
|
||||
layer_filter = blocks.get("layer_filter", "")
|
||||
|
||||
loras_list = []
|
||||
all_trigger_words = []
|
||||
active_loras = []
|
||||
|
||||
prev_lora = kwargs.get('prev_lora', None)
|
||||
if prev_lora is not None:
|
||||
loras_list.extend(prev_lora)
|
||||
|
||||
if not merge_lora:
|
||||
low_mem_load = False
|
||||
|
||||
parts = text_to_process.split('<lora:')
|
||||
for part in parts[1:]:
|
||||
end_index = part.find('>')
|
||||
if end_index == -1:
|
||||
continue
|
||||
|
||||
content = part[:end_index]
|
||||
lora_parts = content.split(':')
|
||||
|
||||
lora_name_raw = ""
|
||||
model_strength = 1.0
|
||||
clip_strength = 1.0
|
||||
|
||||
if len(lora_parts) == 2:
|
||||
lora_name_raw = lora_parts[0].strip()
|
||||
try:
|
||||
model_strength = float(lora_parts[1])
|
||||
clip_strength = model_strength
|
||||
except (ValueError, IndexError):
|
||||
logger.warning(f"Invalid strength for LoRA '{lora_name_raw}'. Skipping.")
|
||||
continue
|
||||
elif len(lora_parts) >= 3:
|
||||
lora_name_raw = lora_parts[0].strip()
|
||||
try:
|
||||
model_strength = float(lora_parts[1])
|
||||
clip_strength = float(lora_parts[2])
|
||||
except (ValueError, IndexError):
|
||||
logger.warning(f"Invalid strengths for LoRA '{lora_name_raw}'. Skipping.")
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
lora_path, trigger_words = get_lora_info(lora_name_raw)
|
||||
|
||||
lora_item = {
|
||||
"path": folder_paths.get_full_path("loras", lora_path),
|
||||
"strength": model_strength,
|
||||
"name": lora_path.split(".")[0],
|
||||
"blocks": selected_blocks,
|
||||
"layer_filter": layer_filter,
|
||||
"low_mem_load": low_mem_load,
|
||||
"merge_loras": merge_lora,
|
||||
}
|
||||
|
||||
loras_list.append(lora_item)
|
||||
active_loras.append((lora_name_raw, model_strength, clip_strength))
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
formatted_loras = []
|
||||
for name, model_strength, clip_strength in active_loras:
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
|
||||
else:
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")
|
||||
|
||||
active_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (loras_list, trigger_words_text, active_loras_text)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"WanVideoLoraSelectFromText": WanVideoLoraSelectFromText
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"WanVideoLoraSelectFromText": "WanVideo Lora Select From Text (LoraManager)"
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import re
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
from ..config import config
|
||||
from .constants import VALID_LORA_TYPES
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -55,7 +55,7 @@ class RecipeMetadataParser(ABC):
|
||||
# Unpack the tuple to get the actual data
|
||||
civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
||||
|
||||
if not civitai_info or civitai_info.get("error") == "Model not found":
|
||||
if not civitai_info or error_msg == "Model not found":
|
||||
# Model not found or deleted
|
||||
lora_entry['isDeleted'] = True
|
||||
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'
|
||||
@@ -79,6 +79,9 @@ class RecipeMetadataParser(ABC):
|
||||
if 'model' in civitai_info and 'name' in civitai_info['model']:
|
||||
lora_entry['name'] = civitai_info['model']['name']
|
||||
|
||||
lora_entry['id'] = civitai_info.get('id')
|
||||
lora_entry['modelId'] = civitai_info.get('modelId')
|
||||
|
||||
# Update version if available
|
||||
if 'name' in civitai_info:
|
||||
lora_entry['version'] = civitai_info.get('name', '')
|
||||
@@ -116,10 +119,10 @@ class RecipeMetadataParser(ABC):
|
||||
# Check if exists locally
|
||||
if recipe_scanner and lora_entry['hash']:
|
||||
lora_scanner = recipe_scanner._lora_scanner
|
||||
exists_locally = lora_scanner.has_lora_hash(lora_entry['hash'])
|
||||
exists_locally = lora_scanner.has_hash(lora_entry['hash'])
|
||||
if exists_locally:
|
||||
try:
|
||||
local_path = lora_scanner.get_lora_path_by_hash(lora_entry['hash'])
|
||||
local_path = lora_scanner.get_path_by_hash(lora_entry['hash'])
|
||||
lora_entry['existsLocally'] = True
|
||||
lora_entry['localPath'] = local_path
|
||||
lora_entry['file_name'] = os.path.splitext(os.path.basename(local_path))[0]
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""Constants used across recipe parsers."""
|
||||
|
||||
# Import VALID_LORA_TYPES from utils.constants
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
|
||||
# Constants for generation parameters
|
||||
GEN_PARAM_KEYS = [
|
||||
'prompt',
|
||||
@@ -11,6 +14,3 @@ GEN_PARAM_KEYS = [
|
||||
'size',
|
||||
'clip_skip',
|
||||
]
|
||||
|
||||
# Valid Lora types
|
||||
VALID_LORA_TYPES = ['lora', 'locon']
|
||||
|
||||
@@ -6,6 +6,7 @@ import logging
|
||||
from typing import Dict, Any
|
||||
from ..base import RecipeMetadataParser
|
||||
from ..constants import GEN_PARAM_KEYS
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -19,7 +20,7 @@ class AutomaticMetadataParser(RecipeMetadataParser):
|
||||
LORA_HASHES_REGEX = r', Lora hashes:\s*"([^"]+)"'
|
||||
CIVITAI_RESOURCES_REGEX = r', Civitai resources:\s*(\[\{.*?\}\])'
|
||||
CIVITAI_METADATA_REGEX = r', Civitai metadata:\s*(\{.*?\})'
|
||||
EXTRANETS_REGEX = r'<(lora|hypernet):([a-zA-Z0-9_\.\-]+):([0-9.]+)>'
|
||||
EXTRANETS_REGEX = r'<(lora|hypernet):([^:]+):(-?[0-9.]+)>'
|
||||
MODEL_HASH_PATTERN = r'Model hash: ([a-zA-Z0-9]+)'
|
||||
VAE_HASH_PATTERN = r'VAE hash: ([a-zA-Z0-9]+)'
|
||||
|
||||
@@ -30,6 +31,9 @@ class AutomaticMetadataParser(RecipeMetadataParser):
|
||||
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""Parse metadata from Automatic1111 format"""
|
||||
try:
|
||||
# Get metadata provider instead of using civitai_client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
# Split on Negative prompt if it exists
|
||||
if "Negative prompt:" in user_comment:
|
||||
parts = user_comment.split('Negative prompt:', 1)
|
||||
@@ -181,13 +185,30 @@ class AutomaticMetadataParser(RecipeMetadataParser):
|
||||
# First use Civitai resources if available (more reliable source)
|
||||
if metadata.get("civitai_resources"):
|
||||
for resource in metadata.get("civitai_resources", []):
|
||||
# --- Added: Parse 'air' field if present ---
|
||||
air = resource.get("air")
|
||||
if air:
|
||||
# Format: urn:air:sdxl:lora:civitai:1221007@1375651
|
||||
# Or: urn:air:sdxl:checkpoint:civitai:623891@2019115
|
||||
air_pattern = r"urn:air:[^:]+:(?P<type>[^:]+):civitai:(?P<modelId>\d+)@(?P<modelVersionId>\d+)"
|
||||
air_match = re.match(air_pattern, air)
|
||||
if air_match:
|
||||
air_type = air_match.group("type")
|
||||
air_modelId = int(air_match.group("modelId"))
|
||||
air_modelVersionId = int(air_match.group("modelVersionId"))
|
||||
# checkpoint/lycoris/lora/hypernet
|
||||
resource["type"] = air_type
|
||||
resource["modelId"] = air_modelId
|
||||
resource["modelVersionId"] = air_modelVersionId
|
||||
# --- End added ---
|
||||
|
||||
if resource.get("type") in ["lora", "lycoris", "hypernet"] and resource.get("modelVersionId"):
|
||||
# Initialize lora entry
|
||||
lora_entry = {
|
||||
'id': str(resource.get("modelVersionId")),
|
||||
'modelId': str(resource.get("modelId")) if resource.get("modelId") else None,
|
||||
'id': resource.get("modelVersionId", 0),
|
||||
'modelId': resource.get("modelId", 0),
|
||||
'name': resource.get("modelName", "Unknown LoRA"),
|
||||
'version': resource.get("modelVersionName", ""),
|
||||
'version': resource.get("modelVersionName", resource.get("versionName", "")),
|
||||
'type': resource.get("type", "lora"),
|
||||
'weight': round(float(resource.get("weight", 1.0)), 2),
|
||||
'existsLocally': False,
|
||||
@@ -199,9 +220,9 @@ class AutomaticMetadataParser(RecipeMetadataParser):
|
||||
}
|
||||
|
||||
# Get additional info from Civitai
|
||||
if civitai_client:
|
||||
if metadata_provider:
|
||||
try:
|
||||
civitai_info = await civitai_client.get_model_version_info(resource.get("modelVersionId"))
|
||||
civitai_info = await metadata_provider.get_model_version_info(resource.get("modelVersionId"))
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
@@ -254,11 +275,11 @@ class AutomaticMetadataParser(RecipeMetadataParser):
|
||||
}
|
||||
|
||||
# Try to get info from Civitai
|
||||
if civitai_client:
|
||||
if metadata_provider:
|
||||
try:
|
||||
if lora_hash:
|
||||
# If we have hash, use it for lookup
|
||||
civitai_info = await civitai_client.get_model_by_hash(lora_hash)
|
||||
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
|
||||
else:
|
||||
civitai_info = None
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import logging
|
||||
from typing import Dict, Any, Union
|
||||
from ..base import RecipeMetadataParser
|
||||
from ..constants import GEN_PARAM_KEYS
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,12 +37,15 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
Args:
|
||||
metadata: The metadata from the image (dict)
|
||||
recipe_scanner: Optional recipe scanner service
|
||||
civitai_client: Optional Civitai API client
|
||||
civitai_client: Optional Civitai API client (deprecated, use metadata_provider instead)
|
||||
|
||||
Returns:
|
||||
Dict containing parsed recipe data
|
||||
"""
|
||||
try:
|
||||
# Get metadata provider instead of using civitai_client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
# Initialize result structure
|
||||
result = {
|
||||
'base_model': None,
|
||||
@@ -50,6 +54,17 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
'from_civitai_image': True
|
||||
}
|
||||
|
||||
# Track already added LoRAs to prevent duplicates
|
||||
added_loras = {} # key: model_version_id or hash, value: index in result["loras"]
|
||||
|
||||
# Extract hash information from hashes field for LoRA matching
|
||||
lora_hashes = {}
|
||||
if "hashes" in metadata and isinstance(metadata["hashes"], dict):
|
||||
for key, hash_value in metadata["hashes"].items():
|
||||
if key.startswith("LORA:"):
|
||||
lora_name = key.replace("LORA:", "")
|
||||
lora_hashes[lora_name] = hash_value
|
||||
|
||||
# Extract prompt and negative prompt
|
||||
if "prompt" in metadata:
|
||||
result["gen_params"]["prompt"] = metadata["prompt"]
|
||||
@@ -74,9 +89,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
# Extract base model information - directly if available
|
||||
if "baseModel" in metadata:
|
||||
result["base_model"] = metadata["baseModel"]
|
||||
elif "Model hash" in metadata and civitai_client:
|
||||
elif "Model hash" in metadata and metadata_provider:
|
||||
model_hash = metadata["Model hash"]
|
||||
model_info = await civitai_client.get_model_by_hash(model_hash)
|
||||
model_info, error = await metadata_provider.get_model_by_hash(model_hash)
|
||||
if model_info:
|
||||
result["base_model"] = model_info.get("baseModel", "")
|
||||
elif "Model" in metadata and isinstance(metadata.get("resources"), list):
|
||||
@@ -84,8 +99,8 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
for resource in metadata.get("resources", []):
|
||||
if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"):
|
||||
# This is likely the checkpoint model
|
||||
if civitai_client and resource.get("hash"):
|
||||
model_info = await civitai_client.get_model_by_hash(resource.get("hash"))
|
||||
if metadata_provider and resource.get("hash"):
|
||||
model_info, error = await metadata_provider.get_model_by_hash(resource.get("hash"))
|
||||
if model_info:
|
||||
result["base_model"] = model_info.get("baseModel", "")
|
||||
|
||||
@@ -96,11 +111,26 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
for resource in metadata["resources"]:
|
||||
# Modified to process resources without a type field as potential LoRAs
|
||||
if resource.get("type", "lora") == "lora":
|
||||
lora_hash = resource.get("hash", "")
|
||||
|
||||
# Try to get hash from the hashes field if not present in resource
|
||||
if not lora_hash and resource.get("name"):
|
||||
lora_hash = lora_hashes.get(resource["name"], "")
|
||||
|
||||
# Skip LoRAs without proper identification (hash or modelVersionId)
|
||||
if not lora_hash and not resource.get("modelVersionId"):
|
||||
logger.debug(f"Skipping LoRA resource '{resource.get('name', 'Unknown')}' - no hash or modelVersionId")
|
||||
continue
|
||||
|
||||
# Skip if we've already added this LoRA by hash
|
||||
if lora_hash and lora_hash in added_loras:
|
||||
continue
|
||||
|
||||
lora_entry = {
|
||||
'name': resource.get("name", "Unknown LoRA"),
|
||||
'type': "lora",
|
||||
'weight': float(resource.get("weight", 1.0)),
|
||||
'hash': resource.get("hash", ""),
|
||||
'hash': lora_hash,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': resource.get("name", "Unknown"),
|
||||
@@ -112,10 +142,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
}
|
||||
|
||||
# Try to get info from Civitai if hash is available
|
||||
if lora_entry['hash'] and civitai_client:
|
||||
if lora_entry['hash'] and metadata_provider:
|
||||
try:
|
||||
lora_hash = lora_entry['hash']
|
||||
civitai_info = await civitai_client.get_model_by_hash(lora_hash)
|
||||
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
@@ -129,114 +158,193 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
# If we have a version ID from Civitai, track it for deduplication
|
||||
if 'id' in lora_entry and lora_entry['id']:
|
||||
added_loras[str(lora_entry['id'])] = len(result["loras"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
|
||||
|
||||
# Track by hash if we have it
|
||||
if lora_hash:
|
||||
added_loras[lora_hash] = len(result["loras"])
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# Process civitaiResources array
|
||||
if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list):
|
||||
for resource in metadata["civitaiResources"]:
|
||||
# Modified to process resources without a type field as potential LoRAs
|
||||
if resource.get("type") in ["lora", "lycoris"] or "type" not in resource:
|
||||
# Initialize lora entry with the same structure as in automatic.py
|
||||
lora_entry = {
|
||||
'id': str(resource.get("modelVersionId")),
|
||||
'modelId': str(resource.get("modelId")) if resource.get("modelId") else None,
|
||||
'name': resource.get("modelName", "Unknown LoRA"),
|
||||
'version': resource.get("modelVersionName", ""),
|
||||
'type': resource.get("type", "lora"),
|
||||
'weight': round(float(resource.get("weight", 1.0)), 2),
|
||||
'existsLocally': False,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
# Get unique identifier for deduplication
|
||||
version_id = str(resource.get("modelVersionId", ""))
|
||||
|
||||
# Skip if we've already added this LoRA
|
||||
if version_id and version_id in added_loras:
|
||||
continue
|
||||
|
||||
# Initialize lora entry
|
||||
lora_entry = {
|
||||
'id': resource.get("modelVersionId", 0),
|
||||
'modelId': resource.get("modelId", 0),
|
||||
'name': resource.get("modelName", "Unknown LoRA"),
|
||||
'version': resource.get("modelVersionName", ""),
|
||||
'type': resource.get("type", "lora"),
|
||||
'weight': round(float(resource.get("weight", 1.0)), 2),
|
||||
'existsLocally': False,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Try to get info from Civitai if modelVersionId is available
|
||||
if version_id and metadata_provider:
|
||||
try:
|
||||
# Use get_model_version_info instead of get_model_version
|
||||
civitai_info = await metadata_provider.get_model_version_info(version_id)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model version {version_id}: {e}")
|
||||
|
||||
# Track this LoRA in our deduplication dict
|
||||
if version_id:
|
||||
added_loras[version_id] = len(result["loras"])
|
||||
|
||||
# Try to get info from Civitai if modelVersionId is available
|
||||
if resource.get('modelVersionId') and civitai_client:
|
||||
try:
|
||||
version_id = str(resource.get('modelVersionId'))
|
||||
# Use get_model_version_info instead of get_model_version
|
||||
civitai_info, error = await civitai_client.get_model_version_info(version_id)
|
||||
|
||||
if error:
|
||||
logger.warning(f"Error getting model version info: {error}")
|
||||
continue
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model version {resource.get('modelVersionId')}: {e}")
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# Process additionalResources array
|
||||
if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list):
|
||||
for resource in metadata["additionalResources"]:
|
||||
# Modified to process resources without a type field as potential LoRAs
|
||||
if resource.get("type") in ["lora", "lycoris"] or "type" not in resource:
|
||||
lora_type = resource.get("type", "lora")
|
||||
name = resource.get("name", "")
|
||||
# Skip resources that aren't LoRAs or LyCORIS
|
||||
if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource:
|
||||
continue
|
||||
|
||||
# Extract ID from URN format if available
|
||||
model_id = None
|
||||
if name and "civitai:" in name:
|
||||
parts = name.split("@")
|
||||
if len(parts) > 1:
|
||||
model_id = parts[1]
|
||||
|
||||
lora_entry = {
|
||||
'name': name,
|
||||
'type': lora_type,
|
||||
'weight': float(resource.get("strength", 1.0)),
|
||||
'hash': "",
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# If we have a model ID and civitai client, try to get more info
|
||||
if model_id and civitai_client:
|
||||
try:
|
||||
# Use get_model_version_info with the model ID
|
||||
civitai_info, error = await civitai_client.get_model_version_info(model_id)
|
||||
lora_type = resource.get("type", "lora")
|
||||
name = resource.get("name", "")
|
||||
|
||||
# Extract ID from URN format if available
|
||||
version_id = None
|
||||
if name and "civitai:" in name:
|
||||
parts = name.split("@")
|
||||
if len(parts) > 1:
|
||||
version_id = parts[1]
|
||||
|
||||
# Skip if we've already added this LoRA
|
||||
if version_id in added_loras:
|
||||
continue
|
||||
|
||||
lora_entry = {
|
||||
'name': name,
|
||||
'type': lora_type,
|
||||
'weight': float(resource.get("strength", 1.0)),
|
||||
'hash': "",
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# If we have a version ID and metadata provider, try to get more info
|
||||
if version_id and metadata_provider:
|
||||
try:
|
||||
# Use get_model_version_info with the version ID
|
||||
civitai_info = await metadata_provider.get_model_version_info(version_id)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
if error:
|
||||
logger.warning(f"Error getting model version info: {error}")
|
||||
else:
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model ID {model_id}: {e}")
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
# Track this LoRA for deduplication
|
||||
if version_id:
|
||||
added_loras[version_id] = len(result["loras"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}")
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# Check for LoRA info in the format "Lora_0 Model hash", "Lora_0 Model name", etc.
|
||||
lora_index = 0
|
||||
while f"Lora_{lora_index} Model hash" in metadata and f"Lora_{lora_index} Model name" in metadata:
|
||||
lora_hash = metadata[f"Lora_{lora_index} Model hash"]
|
||||
lora_name = metadata[f"Lora_{lora_index} Model name"]
|
||||
lora_strength_model = float(metadata.get(f"Lora_{lora_index} Strength model", 1.0))
|
||||
|
||||
# Skip if we've already added this LoRA by hash
|
||||
if lora_hash and lora_hash in added_loras:
|
||||
lora_index += 1
|
||||
continue
|
||||
|
||||
lora_entry = {
|
||||
'name': lora_name,
|
||||
'type': "lora",
|
||||
'weight': lora_strength_model,
|
||||
'hash': lora_hash,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': lora_name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Try to get info from Civitai if hash is available
|
||||
if lora_entry['hash'] and metadata_provider:
|
||||
try:
|
||||
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts,
|
||||
lora_hash
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
lora_index += 1
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
# If we have a version ID from Civitai, track it for deduplication
|
||||
if 'id' in lora_entry and lora_entry['id']:
|
||||
added_loras[str(lora_entry['id'])] = len(result["loras"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
|
||||
|
||||
# Track by hash if we have it
|
||||
if lora_hash:
|
||||
added_loras[lora_hash] = len(result["loras"])
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
lora_index += 1
|
||||
|
||||
# If base model wasn't found earlier, use the most common one from LoRAs
|
||||
if not result["base_model"] and base_model_counts:
|
||||
result["base_model"] = max(base_model_counts.items(), key=lambda x: x[1])[0]
|
||||
|
||||
@@ -6,6 +6,7 @@ import logging
|
||||
from typing import Dict, Any
|
||||
from ..base import RecipeMetadataParser
|
||||
from ..constants import GEN_PARAM_KEYS
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,6 +27,9 @@ class ComfyMetadataParser(RecipeMetadataParser):
|
||||
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""Parse metadata from Civitai ComfyUI metadata format"""
|
||||
try:
|
||||
# Get metadata provider instead of using civitai_client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
data = json.loads(user_comment)
|
||||
loras = []
|
||||
|
||||
@@ -73,10 +77,10 @@ class ComfyMetadataParser(RecipeMetadataParser):
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Get additional info from Civitai if client is available
|
||||
if civitai_client:
|
||||
# Get additional info from Civitai if metadata provider is available
|
||||
if metadata_provider:
|
||||
try:
|
||||
civitai_info_tuple = await civitai_client.get_model_version_info(model_version_id)
|
||||
civitai_info_tuple = await metadata_provider.get_model_version_info(model_version_id)
|
||||
# Populate lora entry with Civitai info
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
@@ -116,9 +120,9 @@ class ComfyMetadataParser(RecipeMetadataParser):
|
||||
}
|
||||
|
||||
# Get additional checkpoint info from Civitai
|
||||
if civitai_client:
|
||||
if metadata_provider:
|
||||
try:
|
||||
civitai_info_tuple = await civitai_client.get_model_version_info(checkpoint_version_id)
|
||||
civitai_info_tuple = await metadata_provider.get_model_version_info(checkpoint_version_id)
|
||||
civitai_info, _ = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
||||
# Populate checkpoint with Civitai info
|
||||
checkpoint = await self.populate_checkpoint_from_civitai(checkpoint, civitai_info)
|
||||
|
||||
@@ -5,6 +5,7 @@ import logging
|
||||
from typing import Dict, Any
|
||||
from ..base import RecipeMetadataParser
|
||||
from ..constants import GEN_PARAM_KEYS
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -18,8 +19,11 @@ class MetaFormatParser(RecipeMetadataParser):
|
||||
return re.search(self.METADATA_MARKER, user_comment, re.IGNORECASE | re.DOTALL) is not None
|
||||
|
||||
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""Parse metadata from images with meta format metadata"""
|
||||
"""Parse metadata from images with meta format metadata (Lora_N Model hash format)"""
|
||||
try:
|
||||
# Get metadata provider instead of using civitai_client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
# Extract prompt and negative prompt
|
||||
parts = user_comment.split('Negative prompt:', 1)
|
||||
prompt = parts[0].strip()
|
||||
@@ -122,9 +126,9 @@ class MetaFormatParser(RecipeMetadataParser):
|
||||
}
|
||||
|
||||
# Get info from Civitai by hash if available
|
||||
if civitai_client and hash_value:
|
||||
if metadata_provider and hash_value:
|
||||
try:
|
||||
civitai_info = await civitai_client.get_model_by_hash(hash_value)
|
||||
civitai_info = await metadata_provider.get_model_by_hash(hash_value)
|
||||
# Populate lora entry with Civitai info
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Dict, Any
|
||||
from ...config import config
|
||||
from ..base import RecipeMetadataParser
|
||||
from ..constants import GEN_PARAM_KEYS
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,6 +24,9 @@ class RecipeFormatParser(RecipeMetadataParser):
|
||||
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
|
||||
"""Parse metadata from images with dedicated recipe metadata format"""
|
||||
try:
|
||||
# Get metadata provider instead of using civitai_client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
# Extract recipe metadata from user comment
|
||||
try:
|
||||
# Look for recipe metadata section
|
||||
@@ -43,7 +47,7 @@ class RecipeFormatParser(RecipeMetadataParser):
|
||||
for lora in recipe_metadata.get('loras', []):
|
||||
# Convert recipe lora format to frontend format
|
||||
lora_entry = {
|
||||
'id': lora.get('modelVersionId', ''),
|
||||
'id': int(lora.get('modelVersionId', 0)),
|
||||
'name': lora.get('modelName', ''),
|
||||
'version': lora.get('modelVersionName', ''),
|
||||
'type': 'lora',
|
||||
@@ -55,7 +59,7 @@ class RecipeFormatParser(RecipeMetadataParser):
|
||||
# Check if this LoRA exists locally by SHA256 hash
|
||||
if lora.get('hash') and recipe_scanner:
|
||||
lora_scanner = recipe_scanner._lora_scanner
|
||||
exists_locally = lora_scanner.has_lora_hash(lora['hash'])
|
||||
exists_locally = lora_scanner.has_hash(lora['hash'])
|
||||
if exists_locally:
|
||||
lora_cache = await lora_scanner.get_cached_data()
|
||||
lora_item = next((item for item in lora_cache.raw_data if item['sha256'].lower() == lora['hash'].lower()), None)
|
||||
@@ -71,9 +75,9 @@ class RecipeFormatParser(RecipeMetadataParser):
|
||||
lora_entry['localPath'] = None
|
||||
|
||||
# Try to get additional info from Civitai if we have a model version ID
|
||||
if lora.get('modelVersionId') and civitai_client:
|
||||
if lora.get('modelVersionId') and metadata_provider:
|
||||
try:
|
||||
civitai_info_tuple = await civitai_client.get_model_version_info(lora['modelVersionId'])
|
||||
civitai_info_tuple = await metadata_provider.get_model_version_info(lora['modelVersionId'])
|
||||
# Populate lora entry with Civitai info
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
275
py/routes/base_model_routes.py
Normal file
275
py/routes/base_model_routes.py
Normal file
@@ -0,0 +1,275 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Dict, Mapping
|
||||
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
|
||||
from ..config import config
|
||||
from ..services.download_coordinator import DownloadCoordinator
|
||||
from ..services.downloader import get_downloader
|
||||
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
||||
from ..services.metadata_sync_service import MetadataSyncService
|
||||
from ..services.model_file_service import ModelFileService, ModelMoveService
|
||||
from ..services.model_lifecycle_service import ModelLifecycleService
|
||||
from ..services.preview_asset_service import PreviewAssetService
|
||||
from ..services.server_i18n import server_i18n as default_server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..services.settings_manager import settings as default_settings
|
||||
from ..services.tag_update_service import TagUpdateService
|
||||
from ..services.websocket_manager import ws_manager as default_ws_manager
|
||||
from ..services.use_cases import (
|
||||
AutoOrganizeUseCase,
|
||||
BulkMetadataRefreshUseCase,
|
||||
DownloadModelUseCase,
|
||||
)
|
||||
from ..services.websocket_progress_callback import (
|
||||
WebSocketBroadcastCallback,
|
||||
WebSocketProgressCallback,
|
||||
)
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar
|
||||
from .handlers.model_handlers import (
|
||||
ModelAutoOrganizeHandler,
|
||||
ModelCivitaiHandler,
|
||||
ModelDownloadHandler,
|
||||
ModelHandlerSet,
|
||||
ModelListingHandler,
|
||||
ModelManagementHandler,
|
||||
ModelMoveHandler,
|
||||
ModelPageView,
|
||||
ModelQueryHandler,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseModelRoutes(ABC):
|
||||
"""Base route controller for all model types."""
|
||||
|
||||
template_name: str | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
service=None,
|
||||
*,
|
||||
settings_service=default_settings,
|
||||
ws_manager=default_ws_manager,
|
||||
server_i18n=default_server_i18n,
|
||||
metadata_provider_factory=get_default_metadata_provider,
|
||||
) -> None:
|
||||
self.service = None
|
||||
self.model_type = ""
|
||||
self._settings = settings_service
|
||||
self._ws_manager = ws_manager
|
||||
self._server_i18n = server_i18n
|
||||
self._metadata_provider_factory = metadata_provider_factory
|
||||
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True,
|
||||
)
|
||||
|
||||
self.model_file_service: ModelFileService | None = None
|
||||
self.model_move_service: ModelMoveService | None = None
|
||||
self.model_lifecycle_service: ModelLifecycleService | None = None
|
||||
self.websocket_progress_callback = WebSocketProgressCallback()
|
||||
self.metadata_progress_callback = WebSocketBroadcastCallback()
|
||||
|
||||
self._handler_set: ModelHandlerSet | None = None
|
||||
self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None
|
||||
|
||||
self._preview_service = PreviewAssetService(
|
||||
metadata_manager=MetadataManager,
|
||||
downloader_factory=get_downloader,
|
||||
exif_utils=ExifUtils,
|
||||
)
|
||||
self._metadata_sync_service = MetadataSyncService(
|
||||
metadata_manager=MetadataManager,
|
||||
preview_service=self._preview_service,
|
||||
settings=settings_service,
|
||||
default_metadata_provider_factory=metadata_provider_factory,
|
||||
metadata_provider_selector=get_metadata_provider,
|
||||
)
|
||||
self._tag_update_service = TagUpdateService(metadata_manager=MetadataManager)
|
||||
self._download_coordinator = DownloadCoordinator(
|
||||
ws_manager=self._ws_manager,
|
||||
download_manager_factory=ServiceRegistry.get_download_manager,
|
||||
)
|
||||
|
||||
if service is not None:
|
||||
self.attach_service(service)
|
||||
|
||||
def attach_service(self, service) -> None:
|
||||
"""Attach a model service and rebuild handler dependencies."""
|
||||
self.service = service
|
||||
self.model_type = service.model_type
|
||||
self.model_file_service = ModelFileService(service.scanner, service.model_type)
|
||||
self.model_move_service = ModelMoveService(service.scanner)
|
||||
self.model_lifecycle_service = ModelLifecycleService(
|
||||
scanner=service.scanner,
|
||||
metadata_manager=MetadataManager,
|
||||
metadata_loader=self._metadata_sync_service.load_local_metadata,
|
||||
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
|
||||
)
|
||||
self._handler_set = None
|
||||
self._handler_mapping = None
|
||||
|
||||
def _ensure_handler_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
|
||||
if self._handler_mapping is None:
|
||||
handler_set = self._create_handler_set()
|
||||
self._handler_set = handler_set
|
||||
self._handler_mapping = handler_set.to_route_mapping()
|
||||
return self._handler_mapping
|
||||
|
||||
def _create_handler_set(self) -> ModelHandlerSet:
|
||||
service = self._ensure_service()
|
||||
page_view = ModelPageView(
|
||||
template_env=self.template_env,
|
||||
template_name=self.template_name or "",
|
||||
service=service,
|
||||
settings_service=self._settings,
|
||||
server_i18n=self._server_i18n,
|
||||
logger=logger,
|
||||
)
|
||||
listing = ModelListingHandler(
|
||||
service=service,
|
||||
parse_specific_params=self._parse_specific_params,
|
||||
logger=logger,
|
||||
)
|
||||
management = ModelManagementHandler(
|
||||
service=service,
|
||||
logger=logger,
|
||||
metadata_sync=self._metadata_sync_service,
|
||||
preview_service=self._preview_service,
|
||||
tag_update_service=self._tag_update_service,
|
||||
lifecycle_service=self._ensure_lifecycle_service(),
|
||||
)
|
||||
query = ModelQueryHandler(service=service, logger=logger)
|
||||
download_use_case = DownloadModelUseCase(download_coordinator=self._download_coordinator)
|
||||
download = ModelDownloadHandler(
|
||||
ws_manager=self._ws_manager,
|
||||
logger=logger,
|
||||
download_use_case=download_use_case,
|
||||
download_coordinator=self._download_coordinator,
|
||||
)
|
||||
metadata_refresh_use_case = BulkMetadataRefreshUseCase(
|
||||
service=service,
|
||||
metadata_sync=self._metadata_sync_service,
|
||||
settings_service=self._settings,
|
||||
logger=logger,
|
||||
)
|
||||
civitai = ModelCivitaiHandler(
|
||||
service=service,
|
||||
settings_service=self._settings,
|
||||
ws_manager=self._ws_manager,
|
||||
logger=logger,
|
||||
metadata_provider_factory=self._metadata_provider_factory,
|
||||
validate_model_type=self._validate_civitai_model_type,
|
||||
expected_model_types=self._get_expected_model_types,
|
||||
find_model_file=self._find_model_file,
|
||||
metadata_sync=self._metadata_sync_service,
|
||||
metadata_refresh_use_case=metadata_refresh_use_case,
|
||||
metadata_progress_callback=self.metadata_progress_callback,
|
||||
)
|
||||
move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger)
|
||||
auto_organize_use_case = AutoOrganizeUseCase(
|
||||
file_service=self._ensure_file_service(),
|
||||
lock_provider=self._ws_manager,
|
||||
)
|
||||
auto_organize = ModelAutoOrganizeHandler(
|
||||
use_case=auto_organize_use_case,
|
||||
progress_callback=self.websocket_progress_callback,
|
||||
ws_manager=self._ws_manager,
|
||||
logger=logger,
|
||||
)
|
||||
return ModelHandlerSet(
|
||||
page_view=page_view,
|
||||
listing=listing,
|
||||
management=management,
|
||||
query=query,
|
||||
download=download,
|
||||
civitai=civitai,
|
||||
move=move,
|
||||
auto_organize=auto_organize,
|
||||
)
|
||||
|
||||
@property
|
||||
def route_handlers(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
|
||||
return self._ensure_handler_mapping()
|
||||
|
||||
def setup_routes(self, app: web.Application, prefix: str) -> None:
|
||||
registrar = ModelRouteRegistrar(app)
|
||||
handler_lookup = {
|
||||
definition.handler_name: self._make_handler_proxy(definition.handler_name)
|
||||
for definition in COMMON_ROUTE_DEFINITIONS
|
||||
}
|
||||
registrar.register_common_routes(prefix, handler_lookup)
|
||||
self.setup_specific_routes(registrar, prefix)
|
||||
|
||||
@abstractmethod
|
||||
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str) -> None:
|
||||
"""Setup model-specific routes."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||
"""Parse model-specific parameters - to be overridden by subclasses."""
|
||||
return {}
|
||||
|
||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||
"""Validate CivitAI model type - to be overridden by subclasses."""
|
||||
return True
|
||||
|
||||
def _get_expected_model_types(self) -> str:
|
||||
"""Get expected model types string for error messages - to be overridden by subclasses."""
|
||||
return "any model type"
|
||||
|
||||
def _find_model_file(self, files):
|
||||
"""Find the appropriate model file from the files list - can be overridden by subclasses."""
|
||||
return next((file for file in files if file.get("type") == "Model" and file.get("primary") is True), None)
|
||||
|
||||
def get_handler(self, name: str) -> Callable[[web.Request], web.StreamResponse]:
|
||||
"""Expose handlers for subclasses or tests."""
|
||||
return self._ensure_handler_mapping()[name]
|
||||
|
||||
def _ensure_service(self):
|
||||
if self.service is None:
|
||||
raise RuntimeError("Model service has not been attached")
|
||||
return self.service
|
||||
|
||||
def _ensure_file_service(self) -> ModelFileService:
|
||||
if self.model_file_service is None:
|
||||
service = self._ensure_service()
|
||||
self.model_file_service = ModelFileService(service.scanner, service.model_type)
|
||||
return self.model_file_service
|
||||
|
||||
def _ensure_move_service(self) -> ModelMoveService:
|
||||
if self.model_move_service is None:
|
||||
service = self._ensure_service()
|
||||
self.model_move_service = ModelMoveService(service.scanner)
|
||||
return self.model_move_service
|
||||
|
||||
def _ensure_lifecycle_service(self) -> ModelLifecycleService:
|
||||
if self.model_lifecycle_service is None:
|
||||
service = self._ensure_service()
|
||||
self.model_lifecycle_service = ModelLifecycleService(
|
||||
scanner=service.scanner,
|
||||
metadata_manager=MetadataManager,
|
||||
metadata_loader=self._metadata_sync_service.load_local_metadata,
|
||||
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
|
||||
)
|
||||
return self.model_lifecycle_service
|
||||
|
||||
def _make_handler_proxy(self, name: str) -> Callable[[web.Request], web.StreamResponse]:
|
||||
async def proxy(request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
handler = self.get_handler(name)
|
||||
except RuntimeError:
|
||||
return web.json_response({"success": False, "error": "Service not ready"}, status=503)
|
||||
return await handler(request)
|
||||
|
||||
return proxy
|
||||
|
||||
217
py/routes/base_recipe_routes.py
Normal file
217
py/routes/base_recipe_routes.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""Base infrastructure shared across recipe routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, Mapping
|
||||
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
|
||||
from ..config import config
|
||||
from ..recipes import RecipeParserFactory
|
||||
from ..services.downloader import get_downloader
|
||||
from ..services.recipes import (
|
||||
RecipeAnalysisService,
|
||||
RecipePersistenceService,
|
||||
RecipeSharingService,
|
||||
)
|
||||
from ..services.server_i18n import server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..services.settings_manager import settings
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from .handlers.recipe_handlers import (
|
||||
RecipeAnalysisHandler,
|
||||
RecipeHandlerSet,
|
||||
RecipeListingHandler,
|
||||
RecipeManagementHandler,
|
||||
RecipePageView,
|
||||
RecipeQueryHandler,
|
||||
RecipeSharingHandler,
|
||||
)
|
||||
from .recipe_route_registrar import ROUTE_DEFINITIONS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseRecipeRoutes:
|
||||
"""Common dependency and startup wiring for recipe routes."""
|
||||
|
||||
_HANDLER_NAMES: tuple[str, ...] = tuple(
|
||||
definition.handler_name for definition in ROUTE_DEFINITIONS
|
||||
)
|
||||
|
||||
template_name: str = "recipes.html"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.recipe_scanner = None
|
||||
self.lora_scanner = None
|
||||
self.civitai_client = None
|
||||
self.settings = settings
|
||||
self.server_i18n = server_i18n
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True,
|
||||
)
|
||||
|
||||
self._i18n_registered = False
|
||||
self._startup_hooks_registered = False
|
||||
self._handler_set: RecipeHandlerSet | None = None
|
||||
self._handler_mapping: dict[str, Callable] | None = None
|
||||
|
||||
async def attach_dependencies(self, app: web.Application | None = None) -> None:
|
||||
"""Resolve shared services from the registry."""
|
||||
|
||||
await self._ensure_services()
|
||||
self._ensure_i18n_filter()
|
||||
|
||||
async def ensure_dependencies_ready(self) -> None:
|
||||
"""Ensure dependencies are available for request handlers."""
|
||||
|
||||
if self.recipe_scanner is None or self.civitai_client is None:
|
||||
await self.attach_dependencies()
|
||||
|
||||
def register_startup_hooks(self, app: web.Application) -> None:
|
||||
"""Register startup hooks once for dependency wiring."""
|
||||
|
||||
if self._startup_hooks_registered:
|
||||
return
|
||||
|
||||
app.on_startup.append(self.attach_dependencies)
|
||||
app.on_startup.append(self.prewarm_cache)
|
||||
self._startup_hooks_registered = True
|
||||
|
||||
async def prewarm_cache(self, app: web.Application | None = None) -> None:
|
||||
"""Pre-load recipe and LoRA caches on startup."""
|
||||
|
||||
try:
|
||||
await self.attach_dependencies(app)
|
||||
|
||||
if self.lora_scanner is not None:
|
||||
await self.lora_scanner.get_cached_data()
|
||||
hash_index = getattr(self.lora_scanner, "_hash_index", None)
|
||||
if hash_index is not None and hasattr(hash_index, "_hash_to_path"):
|
||||
_ = len(hash_index._hash_to_path)
|
||||
|
||||
if self.recipe_scanner is not None:
|
||||
await self.recipe_scanner.get_cached_data(force_refresh=True)
|
||||
except Exception as exc:
|
||||
logger.error("Error pre-warming recipe cache: %s", exc, exc_info=True)
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable]:
|
||||
"""Return a mapping of handler name to coroutine for registrar binding."""
|
||||
|
||||
if self._handler_mapping is None:
|
||||
handler_set = self._create_handler_set()
|
||||
self._handler_set = handler_set
|
||||
self._handler_mapping = handler_set.to_route_mapping()
|
||||
return self._handler_mapping
|
||||
|
||||
# Internal helpers -------------------------------------------------
|
||||
|
||||
async def _ensure_services(self) -> None:
|
||||
if self.recipe_scanner is None:
|
||||
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
self.lora_scanner = getattr(self.recipe_scanner, "_lora_scanner", None)
|
||||
|
||||
if self.civitai_client is None:
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
def _ensure_i18n_filter(self) -> None:
|
||||
if not self._i18n_registered:
|
||||
self.template_env.filters["t"] = self.server_i18n.create_template_filter()
|
||||
self._i18n_registered = True
|
||||
|
||||
def get_handler_owner(self):
|
||||
"""Return the object supplying bound handler coroutines."""
|
||||
|
||||
if self._handler_set is None:
|
||||
self._handler_set = self._create_handler_set()
|
||||
return self._handler_set
|
||||
|
||||
def _create_handler_set(self) -> RecipeHandlerSet:
|
||||
recipe_scanner_getter = lambda: self.recipe_scanner
|
||||
civitai_client_getter = lambda: self.civitai_client
|
||||
|
||||
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
if not standalone_mode:
|
||||
from ..metadata_collector import get_metadata # type: ignore[import-not-found]
|
||||
from ..metadata_collector.metadata_processor import ( # type: ignore[import-not-found]
|
||||
MetadataProcessor,
|
||||
)
|
||||
from ..metadata_collector.metadata_registry import ( # type: ignore[import-not-found]
|
||||
MetadataRegistry,
|
||||
)
|
||||
else: # pragma: no cover - optional dependency path
|
||||
get_metadata = None # type: ignore[assignment]
|
||||
MetadataProcessor = None # type: ignore[assignment]
|
||||
MetadataRegistry = None # type: ignore[assignment]
|
||||
|
||||
analysis_service = RecipeAnalysisService(
|
||||
exif_utils=ExifUtils,
|
||||
recipe_parser_factory=RecipeParserFactory,
|
||||
downloader_factory=get_downloader,
|
||||
metadata_collector=get_metadata,
|
||||
metadata_processor_cls=MetadataProcessor,
|
||||
metadata_registry_cls=MetadataRegistry,
|
||||
standalone_mode=standalone_mode,
|
||||
logger=logger,
|
||||
)
|
||||
persistence_service = RecipePersistenceService(
|
||||
exif_utils=ExifUtils,
|
||||
card_preview_width=CARD_PREVIEW_WIDTH,
|
||||
logger=logger,
|
||||
)
|
||||
sharing_service = RecipeSharingService(logger=logger)
|
||||
|
||||
page_view = RecipePageView(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
settings_service=self.settings,
|
||||
server_i18n=self.server_i18n,
|
||||
template_env=self.template_env,
|
||||
template_name=self.template_name,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
logger=logger,
|
||||
)
|
||||
listing = RecipeListingHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
logger=logger,
|
||||
)
|
||||
query = RecipeQueryHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
format_recipe_file_url=listing.format_recipe_file_url,
|
||||
logger=logger,
|
||||
)
|
||||
management = RecipeManagementHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
logger=logger,
|
||||
persistence_service=persistence_service,
|
||||
analysis_service=analysis_service,
|
||||
)
|
||||
analysis = RecipeAnalysisHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
logger=logger,
|
||||
analysis_service=analysis_service,
|
||||
)
|
||||
sharing = RecipeSharingHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
logger=logger,
|
||||
sharing_service=sharing_service,
|
||||
)
|
||||
|
||||
return RecipeHandlerSet(
|
||||
page_view=page_view,
|
||||
listing=listing,
|
||||
query=query,
|
||||
management=management,
|
||||
analysis=analysis,
|
||||
sharing=sharing,
|
||||
)
|
||||
|
||||
96
py/routes/checkpoint_routes.py
Normal file
96
py/routes/checkpoint_routes.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import logging
|
||||
from aiohttp import web
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from .model_route_registrar import ModelRouteRegistrar
|
||||
from ..services.checkpoint_service import CheckpointService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointRoutes(BaseModelRoutes):
|
||||
"""Checkpoint-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Checkpoint routes with Checkpoint service"""
|
||||
super().__init__()
|
||||
self.template_name = "checkpoints.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.service = CheckpointService(checkpoint_scanner)
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Setup Checkpoint routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Setup common routes with 'checkpoints' prefix (includes page route)
|
||||
super().setup_routes(app, 'checkpoints')
|
||||
|
||||
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
|
||||
"""Setup Checkpoint-specific routes"""
|
||||
# Checkpoint info by name
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/info/{name}', prefix, self.get_checkpoint_info)
|
||||
|
||||
# Checkpoint roots and Unet roots
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/checkpoints_roots', prefix, self.get_checkpoints_roots)
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/unet_roots', prefix, self.get_unet_roots)
|
||||
|
||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||
"""Validate CivitAI model type for Checkpoint"""
|
||||
return model_type.lower() == 'checkpoint'
|
||||
|
||||
def _get_expected_model_types(self) -> str:
|
||||
"""Get expected model types string for error messages"""
|
||||
return "Checkpoint"
|
||||
|
||||
async def get_checkpoint_info(self, request: web.Request) -> web.Response:
|
||||
"""Get detailed information for a specific checkpoint by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
checkpoint_info = await self.service.get_model_info_by_name(name)
|
||||
|
||||
if checkpoint_info:
|
||||
return web.json_response(checkpoint_info)
|
||||
else:
|
||||
return web.json_response({"error": "Checkpoint not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_checkpoints_roots(self, request: web.Request) -> web.Response:
|
||||
"""Return the list of checkpoint roots from config"""
|
||||
try:
|
||||
roots = config.checkpoints_roots
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_unet_roots(self, request: web.Request) -> web.Response:
|
||||
"""Return the list of unet roots from config"""
|
||||
try:
|
||||
roots = config.unet_roots
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting unet roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
@@ -1,818 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..services.websocket_manager import ws_manager
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..utils.utils import fuzzy_match
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointsRoutes:
|
||||
"""API routes for checkpoint management"""
|
||||
|
||||
def __init__(self):
|
||||
self.scanner = None # Will be initialized in setup_routes
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
self.download_manager = None # Will be initialized in setup_routes
|
||||
self._download_lock = asyncio.Lock()
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
def setup_routes(self, app):
|
||||
"""Register routes with the aiohttp app"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
|
||||
app.router.add_get('/api/checkpoints', self.get_checkpoints)
|
||||
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
|
||||
app.router.add_get('/api/checkpoints/base-models', self.get_base_models)
|
||||
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
|
||||
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
|
||||
app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info)
|
||||
app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots)
|
||||
app.router.add_get('/api/checkpoints/civitai/versions/{model_id}', self.get_civitai_versions) # Add new route
|
||||
|
||||
# Add new routes for model management similar to LoRA routes
|
||||
app.router.add_post('/api/checkpoints/delete', self.delete_model)
|
||||
app.router.add_post('/api/checkpoints/exclude', self.exclude_model) # Add new exclude endpoint
|
||||
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
|
||||
app.router.add_post('/api/checkpoints/relink-civitai', self.relink_civitai) # Add new relink endpoint
|
||||
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
|
||||
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
|
||||
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
|
||||
|
||||
# Add new WebSocket endpoint for checkpoint progress
|
||||
app.router.add_get('/ws/checkpoint-progress', ws_manager.handle_checkpoint_connection)
|
||||
|
||||
# Add new routes for finding duplicates and filename conflicts
|
||||
app.router.add_get('/api/checkpoints/find-duplicates', self.find_duplicate_checkpoints)
|
||||
app.router.add_get('/api/checkpoints/find-filename-conflicts', self.find_filename_conflicts)
|
||||
|
||||
# Add new endpoint for bulk deleting checkpoints
|
||||
app.router.add_post('/api/checkpoints/bulk-delete', self.bulk_delete_checkpoints)
|
||||
|
||||
async def get_checkpoints(self, request):
|
||||
"""Get paginated checkpoint data"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
page = int(request.query.get('page', '1'))
|
||||
page_size = min(int(request.query.get('page_size', '20')), 100)
|
||||
sort_by = request.query.get('sort_by', 'name')
|
||||
folder = request.query.get('folder', None)
|
||||
search = request.query.get('search', None)
|
||||
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
|
||||
base_models = request.query.getall('base_model', [])
|
||||
tags = request.query.getall('tag', [])
|
||||
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' # Add favorites_only parameter
|
||||
|
||||
# Process search options
|
||||
search_options = {
|
||||
'filename': request.query.get('search_filename', 'true').lower() == 'true',
|
||||
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
|
||||
'tags': request.query.get('search_tags', 'false').lower() == 'true',
|
||||
'recursive': request.query.get('recursive', 'false').lower() == 'true',
|
||||
}
|
||||
|
||||
# Process hash filters if provided
|
||||
hash_filters = {}
|
||||
if 'hash' in request.query:
|
||||
hash_filters['single_hash'] = request.query['hash']
|
||||
elif 'hashes' in request.query:
|
||||
try:
|
||||
hash_list = json.loads(request.query['hashes'])
|
||||
if isinstance(hash_list, list):
|
||||
hash_filters['multiple_hashes'] = hash_list
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Get data from scanner
|
||||
result = await self.get_paginated_data(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
sort_by=sort_by,
|
||||
folder=folder,
|
||||
search=search,
|
||||
fuzzy_search=fuzzy_search,
|
||||
base_models=base_models,
|
||||
tags=tags,
|
||||
search_options=search_options,
|
||||
hash_filters=hash_filters,
|
||||
favorites_only=favorites_only # Pass favorites_only parameter
|
||||
)
|
||||
|
||||
# Format response items
|
||||
formatted_result = {
|
||||
'items': [self._format_checkpoint_response(cp) for cp in result['items']],
|
||||
'total': result['total'],
|
||||
'page': result['page'],
|
||||
'page_size': result['page_size'],
|
||||
'total_pages': result['total_pages']
|
||||
}
|
||||
|
||||
# Return as JSON
|
||||
return web.json_response(formatted_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_paginated_data(self, page, page_size, sort_by='name',
|
||||
folder=None, search=None, fuzzy_search=False,
|
||||
base_models=None, tags=None,
|
||||
search_options=None, hash_filters=None,
|
||||
favorites_only=False): # Add favorites_only parameter with default False
|
||||
"""Get paginated and filtered checkpoint data"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Get default search options if not provided
|
||||
if search_options is None:
|
||||
search_options = {
|
||||
'filename': True,
|
||||
'modelname': True,
|
||||
'tags': False,
|
||||
'recursive': False,
|
||||
}
|
||||
|
||||
# Get the base data set
|
||||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
||||
|
||||
# Apply hash filtering if provided (highest priority)
|
||||
if hash_filters:
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower() # Ensure lowercase for matching
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('sha256', '').lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
# Jump to pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
# Apply SFW filtering if enabled in settings
|
||||
if settings.get('show_only_sfw', False):
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||
]
|
||||
|
||||
# Apply favorites filtering if enabled
|
||||
if favorites_only:
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('favorite', False) is True
|
||||
]
|
||||
|
||||
# Apply folder filtering
|
||||
if folder is not None:
|
||||
if search_options.get('recursive', False):
|
||||
# Recursive folder filtering - include all subfolders
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp['folder'].startswith(folder)
|
||||
]
|
||||
else:
|
||||
# Exact folder filtering
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp['folder'] == folder
|
||||
]
|
||||
|
||||
# Apply base model filtering
|
||||
if base_models and len(base_models) > 0:
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('base_model') in base_models
|
||||
]
|
||||
|
||||
# Apply tag filtering
|
||||
if tags and len(tags) > 0:
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if any(tag in cp.get('tags', []) for tag in tags)
|
||||
]
|
||||
|
||||
# Apply search filtering
|
||||
if search:
|
||||
search_results = []
|
||||
|
||||
for cp in filtered_data:
|
||||
# Search by file name
|
||||
if search_options.get('filename', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(cp.get('file_name', ''), search):
|
||||
search_results.append(cp)
|
||||
continue
|
||||
elif search.lower() in cp.get('file_name', '').lower():
|
||||
search_results.append(cp)
|
||||
continue
|
||||
|
||||
# Search by model name
|
||||
if search_options.get('modelname', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(cp.get('model_name', ''), search):
|
||||
search_results.append(cp)
|
||||
continue
|
||||
elif search.lower() in cp.get('model_name', '').lower():
|
||||
search_results.append(cp)
|
||||
continue
|
||||
|
||||
# Search by tags
|
||||
if search_options.get('tags', False) and 'tags' in cp:
|
||||
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) for tag in cp['tags']):
|
||||
search_results.append(cp)
|
||||
continue
|
||||
|
||||
filtered_data = search_results
|
||||
|
||||
# Calculate pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def _format_checkpoint_response(self, checkpoint):
|
||||
"""Format checkpoint data for API response"""
|
||||
return {
|
||||
"model_name": checkpoint["model_name"],
|
||||
"file_name": checkpoint["file_name"],
|
||||
"preview_url": config.get_preview_static_url(checkpoint.get("preview_url", "")),
|
||||
"preview_nsfw_level": checkpoint.get("preview_nsfw_level", 0),
|
||||
"base_model": checkpoint.get("base_model", ""),
|
||||
"folder": checkpoint["folder"],
|
||||
"sha256": checkpoint.get("sha256", ""),
|
||||
"file_path": checkpoint["file_path"].replace(os.sep, "/"),
|
||||
"file_size": checkpoint.get("size", 0),
|
||||
"modified": checkpoint.get("modified", ""),
|
||||
"tags": checkpoint.get("tags", []),
|
||||
"modelDescription": checkpoint.get("modelDescription", ""),
|
||||
"from_civitai": checkpoint.get("from_civitai", True),
|
||||
"notes": checkpoint.get("notes", ""),
|
||||
"model_type": checkpoint.get("model_type", "checkpoint"),
|
||||
"favorite": checkpoint.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {}))
|
||||
}
|
||||
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Fetch CivitAI metadata for all checkpoints in the background"""
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data()
|
||||
total = len(cache.raw_data)
|
||||
processed = 0
|
||||
success = 0
|
||||
needs_resort = False
|
||||
|
||||
# Prepare checkpoints to process
|
||||
to_process = [
|
||||
cp for cp in cache.raw_data
|
||||
if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True)
|
||||
]
|
||||
total_to_process = len(to_process)
|
||||
|
||||
# Send initial progress
|
||||
await ws_manager.broadcast({
|
||||
'status': 'started',
|
||||
'total': total_to_process,
|
||||
'processed': 0,
|
||||
'success': 0
|
||||
})
|
||||
|
||||
# Process each checkpoint
|
||||
for cp in to_process:
|
||||
try:
|
||||
original_name = cp.get('model_name')
|
||||
if await ModelRouteUtils.fetch_and_update_model(
|
||||
sha256=cp['sha256'],
|
||||
file_path=cp['file_path'],
|
||||
model_data=cp,
|
||||
update_cache_func=self.scanner.update_single_model_cache
|
||||
):
|
||||
success += 1
|
||||
if original_name != cp.get('model_name'):
|
||||
needs_resort = True
|
||||
|
||||
processed += 1
|
||||
|
||||
# Send progress update
|
||||
await ws_manager.broadcast({
|
||||
'status': 'processing',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success,
|
||||
'current_name': cp.get('model_name', 'Unknown')
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}")
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort(name_only=True)
|
||||
|
||||
# Send completion message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'completed',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
# Send error message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'error',
|
||||
'error': str(e)
|
||||
})
|
||||
logger.error(f"Error in fetch_all_civitai for checkpoints: {e}")
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
"""Handle request for top tags sorted by frequency"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
# Validate limit
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20 # Default to a reasonable limit
|
||||
|
||||
# Get top tags
|
||||
top_tags = await self.scanner.get_top_tags(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'tags': top_tags
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Internal server error'
|
||||
}, status=500)
|
||||
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
"""Get base models used in loras"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
# Validate limit
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20 # Default to a reasonable limit
|
||||
|
||||
# Get base models
|
||||
base_models = await self.scanner.get_base_models(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'base_models': base_models
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving base models: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def scan_checkpoints(self, request):
|
||||
"""Force a rescan of checkpoint files"""
|
||||
try:
|
||||
# Get the full_rebuild parameter and convert to bool, default to False
|
||||
full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true'
|
||||
|
||||
await self.scanner.get_cached_data(force_refresh=True, rebuild_cache=full_rebuild)
|
||||
return web.json_response({"status": "success", "message": "Checkpoint scan completed"})
|
||||
except Exception as e:
|
||||
logger.error(f"Error in scan_checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_checkpoint_info(self, request):
|
||||
"""Get detailed information for a specific checkpoint by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
checkpoint_info = await self.scanner.get_model_info_by_name(name)
|
||||
|
||||
if checkpoint_info:
|
||||
return web.json_response(checkpoint_info)
|
||||
else:
|
||||
return web.json_response({"error": "Checkpoint not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /checkpoints request"""
|
||||
try:
|
||||
# Check if the CheckpointScanner is initializing
|
||||
# It's initializing if the cache object doesn't exist yet,
|
||||
# OR if the scanner explicitly says it's initializing (background task running).
|
||||
is_initializing = (
|
||||
self.scanner._cache is None or
|
||||
(hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing)
|
||||
)
|
||||
|
||||
if is_initializing:
|
||||
# If still initializing, return loading page
|
||||
template = self.template_env.get_template('checkpoints.html')
|
||||
rendered = template.render(
|
||||
folders=[], # 空文件夹列表
|
||||
is_initializing=True, # 新增标志
|
||||
settings=settings, # Pass settings to template
|
||||
request=request # Pass the request object to the template
|
||||
)
|
||||
|
||||
logger.info("Checkpoints page is initializing, returning loading page")
|
||||
else:
|
||||
# 正常流程 - 获取已经初始化好的缓存数据
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('checkpoints.html')
|
||||
rendered = template.render(
|
||||
folders=cache.folders,
|
||||
is_initializing=False,
|
||||
settings=settings, # Pass settings to template
|
||||
request=request # Pass the request object to the template
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading checkpoints cache data: {cache_error}")
|
||||
# 如果获取缓存失败,也显示初始化页面
|
||||
template = self.template_env.get_template('checkpoints.html')
|
||||
rendered = template.render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Checkpoints cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading checkpoints page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def delete_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle checkpoint model deletion request"""
|
||||
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
|
||||
|
||||
async def exclude_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle checkpoint model exclusion request"""
|
||||
return await ModelRouteUtils.handle_exclude_model(request, self.scanner)
|
||||
|
||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata fetch request for checkpoints"""
|
||||
return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
|
||||
|
||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||
"""Handle preview image replacement for checkpoints"""
|
||||
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
|
||||
|
||||
async def download_checkpoint(self, request: web.Request) -> web.Response:
|
||||
"""Handle checkpoint download request"""
|
||||
async with self._download_lock:
|
||||
# Get the download manager from service registry if not already initialized
|
||||
if self.download_manager is None:
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
try:
|
||||
data = await request.json()
|
||||
|
||||
# Create progress callback that uses checkpoint-specific WebSocket
|
||||
async def progress_callback(progress):
|
||||
await ws_manager.broadcast_checkpoint_progress({
|
||||
'status': 'progress',
|
||||
'progress': progress
|
||||
})
|
||||
|
||||
# Check which identifier is provided
|
||||
download_url = data.get('download_url')
|
||||
model_hash = data.get('model_hash')
|
||||
model_version_id = data.get('model_version_id')
|
||||
|
||||
# Validate that at least one identifier is provided
|
||||
if not any([download_url, model_hash, model_version_id]):
|
||||
return web.Response(
|
||||
status=400,
|
||||
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
|
||||
)
|
||||
|
||||
result = await self.download_manager.download_from_civitai(
|
||||
download_url=download_url,
|
||||
model_hash=model_hash,
|
||||
model_version_id=model_version_id,
|
||||
save_dir=data.get('checkpoint_root'),
|
||||
relative_path=data.get('relative_path', ''),
|
||||
progress_callback=progress_callback,
|
||||
model_type="checkpoint"
|
||||
)
|
||||
|
||||
if not result.get('success', False):
|
||||
error_message = result.get('error', 'Unknown error')
|
||||
|
||||
# Return 401 for early access errors
|
||||
if 'early access' in error_message.lower():
|
||||
logger.warning(f"Early access download failed: {error_message}")
|
||||
return web.Response(
|
||||
status=401,
|
||||
text=f"Early Access Restriction: {error_message}"
|
||||
)
|
||||
|
||||
return web.Response(status=500, text=error_message)
|
||||
|
||||
return web.json_response(result)
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
|
||||
# Check if this might be an early access error
|
||||
if '401' in error_message:
|
||||
logger.warning(f"Early access error (401): {error_message}")
|
||||
return web.Response(
|
||||
status=401,
|
||||
text="Early Access Restriction: This model requires purchase. Please ensure you have purchased early access and are logged in to Civitai."
|
||||
)
|
||||
|
||||
logger.error(f"Error downloading checkpoint: {error_message}")
|
||||
return web.Response(status=500, text=error_message)
|
||||
|
||||
async def get_checkpoint_roots(self, request):
|
||||
"""Return the checkpoint root directories"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
roots = self.scanner.get_model_roots()
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||
"""Handle saving metadata updates for checkpoints"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
if not file_path:
|
||||
return web.Response(text='File path is required', status=400)
|
||||
|
||||
# Remove file path from data to avoid saving it
|
||||
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
|
||||
|
||||
# Get metadata file path
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
|
||||
# Load existing metadata
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
|
||||
# Update metadata
|
||||
metadata.update(metadata_updates)
|
||||
|
||||
# Save updated metadata
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Update cache
|
||||
await self.scanner.update_single_model_cache(file_path, file_path, metadata)
|
||||
|
||||
# If model_name was updated, resort the cache
|
||||
if 'model_name' in metadata_updates:
|
||||
cache = await self.scanner.get_cached_data()
|
||||
await cache.resort(name_only=True)
|
||||
|
||||
return web.json_response({'success': True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint metadata: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_civitai_versions(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai checkpoint model with local availability info"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# Get the civitai client from service registry
|
||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
model_id = request.match_info['model_id']
|
||||
response = await civitai_client.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - should be Checkpoint
|
||||
if (model_type.lower() != 'checkpoint'):
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||
|
||||
# If no primary file found, try to find any model file
|
||||
if not model_file:
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model'), None)
|
||||
|
||||
if model_file:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.scanner.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.scanner.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
else:
|
||||
# No model file found in this version
|
||||
version['existsLocally'] = False
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching checkpoint model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
async def find_duplicate_checkpoints(self, request: web.Request) -> web.Response:
|
||||
"""Find checkpoints with duplicate SHA256 hashes"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# Get duplicate hashes from hash index
|
||||
duplicates = self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
# Format the response
|
||||
result = []
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for sha256, paths in duplicates.items():
|
||||
group = {
|
||||
"hash": sha256,
|
||||
"models": []
|
||||
}
|
||||
# Find matching models for each path
|
||||
for path in paths:
|
||||
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
||||
if model:
|
||||
group["models"].append(self._format_checkpoint_response(model))
|
||||
|
||||
# Add the primary model too
|
||||
primary_path = self.scanner._hash_index.get_path(sha256)
|
||||
if primary_path and primary_path not in paths:
|
||||
primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None)
|
||||
if primary_model:
|
||||
group["models"].insert(0, self._format_checkpoint_response(primary_model))
|
||||
|
||||
if group["models"]:
|
||||
result.append(group)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"duplicates": result,
|
||||
"count": len(result)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding duplicate checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def find_filename_conflicts(self, request: web.Request) -> web.Response:
|
||||
"""Find checkpoints with conflicting filenames"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# Get duplicate filenames from hash index
|
||||
duplicates = self.scanner._hash_index.get_duplicate_filenames()
|
||||
|
||||
# Format the response
|
||||
result = []
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for filename, paths in duplicates.items():
|
||||
group = {
|
||||
"filename": filename,
|
||||
"models": []
|
||||
}
|
||||
# Find matching models for each path
|
||||
for path in paths:
|
||||
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
||||
if model:
|
||||
group["models"].append(self._format_checkpoint_response(model))
|
||||
|
||||
# Find the model from the main index too
|
||||
hash_val = self.scanner._hash_index.get_hash_by_filename(filename)
|
||||
if hash_val:
|
||||
main_path = self.scanner._hash_index.get_path(hash_val)
|
||||
if main_path and main_path not in paths:
|
||||
main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None)
|
||||
if main_model:
|
||||
group["models"].insert(0, self._format_checkpoint_response(main_model))
|
||||
|
||||
if group["models"]:
|
||||
result.append(group)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"conflicts": result,
|
||||
"count": len(result)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding filename conflicts: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def bulk_delete_checkpoints(self, request: web.Request) -> web.Response:
|
||||
"""Handle bulk deletion of checkpoint models"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
return await ModelRouteUtils.handle_bulk_delete_models(request, self.scanner)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in bulk delete checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def relink_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata re-linking request by model version ID for checkpoints"""
|
||||
return await ModelRouteUtils.handle_relink_civitai(request, self.scanner)
|
||||
61
py/routes/embedding_routes.py
Normal file
61
py/routes/embedding_routes.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import logging
|
||||
from aiohttp import web
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from .model_route_registrar import ModelRouteRegistrar
|
||||
from ..services.embedding_service import EmbeddingService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EmbeddingRoutes(BaseModelRoutes):
|
||||
"""Embedding-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Embedding routes with Embedding service"""
|
||||
super().__init__()
|
||||
self.template_name = "embeddings.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
self.service = EmbeddingService(embedding_scanner)
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Setup Embedding routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Setup common routes with 'embeddings' prefix (includes page route)
|
||||
super().setup_routes(app, 'embeddings')
|
||||
|
||||
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
|
||||
"""Setup Embedding-specific routes"""
|
||||
# Embedding info by name
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/info/{name}', prefix, self.get_embedding_info)
|
||||
|
||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||
"""Validate CivitAI model type for Embedding"""
|
||||
return model_type.lower() == 'textualinversion'
|
||||
|
||||
def _get_expected_model_types(self) -> str:
|
||||
"""Get expected model types string for error messages"""
|
||||
return "TextualInversion"
|
||||
|
||||
async def get_embedding_info(self, request: web.Request) -> web.Response:
|
||||
"""Get detailed information for a specific embedding by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
embedding_info = await self.service.get_model_info_by_name(name)
|
||||
|
||||
if embedding_info:
|
||||
return web.json_response(embedding_info)
|
||||
else:
|
||||
return web.json_response({"error": "Embedding not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_embedding_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
62
py/routes/example_images_route_registrar.py
Normal file
62
py/routes/example_images_route_registrar.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Route registrar for example image endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RouteDefinition:
|
||||
"""Declarative configuration for a HTTP route."""
|
||||
|
||||
method: str
|
||||
path: str
|
||||
handler_name: str
|
||||
|
||||
|
||||
ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("POST", "/api/lm/download-example-images", "download_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/import-example-images", "import_example_images"),
|
||||
RouteDefinition("GET", "/api/lm/example-images-status", "get_example_images_status"),
|
||||
RouteDefinition("POST", "/api/lm/pause-example-images", "pause_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/resume-example-images", "resume_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/open-example-images-folder", "open_example_images_folder"),
|
||||
RouteDefinition("GET", "/api/lm/example-image-files", "get_example_image_files"),
|
||||
RouteDefinition("GET", "/api/lm/has-example-images", "has_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/delete-example-image", "delete_example_image"),
|
||||
RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/cleanup-example-image-folders", "cleanup_example_image_folders"),
|
||||
)
|
||||
|
||||
|
||||
class ExampleImagesRouteRegistrar:
|
||||
"""Bind declarative example image routes to an aiohttp router."""
|
||||
|
||||
_METHOD_MAP = {
|
||||
"GET": "add_get",
|
||||
"POST": "add_post",
|
||||
"PUT": "add_put",
|
||||
"DELETE": "add_delete",
|
||||
}
|
||||
|
||||
def __init__(self, app: web.Application) -> None:
|
||||
self._app = app
|
||||
|
||||
def register_routes(
|
||||
self,
|
||||
handler_lookup: Mapping[str, Callable[[web.Request], object]],
|
||||
*,
|
||||
definitions: Iterable[RouteDefinition] = ROUTE_DEFINITIONS,
|
||||
) -> None:
|
||||
"""Register each route definition using the supplied handlers."""
|
||||
|
||||
for definition in definitions:
|
||||
handler = handler_lookup[definition.handler_name]
|
||||
self._bind_route(definition.method, definition.path, handler)
|
||||
|
||||
def _bind_route(self, method: str, path: str, handler: Callable[[web.Request], object]) -> None:
|
||||
add_method_name = self._METHOD_MAP[method.upper()]
|
||||
add_method = getattr(self._app.router, add_method_name)
|
||||
add_method(path, handler)
|
||||
File diff suppressed because it is too large
Load Diff
159
py/routes/handlers/example_images_handlers.py
Normal file
159
py/routes/handlers/example_images_handlers.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""Handler set for example image routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ...services.use_cases.example_images import (
|
||||
DownloadExampleImagesConfigurationError,
|
||||
DownloadExampleImagesInProgressError,
|
||||
DownloadExampleImagesUseCase,
|
||||
ImportExampleImagesUseCase,
|
||||
ImportExampleImagesValidationError,
|
||||
)
|
||||
from ...utils.example_images_download_manager import (
|
||||
DownloadConfigurationError,
|
||||
DownloadInProgressError,
|
||||
DownloadNotRunningError,
|
||||
ExampleImagesDownloadError,
|
||||
)
|
||||
from ...utils.example_images_processor import ExampleImagesImportError
|
||||
|
||||
|
||||
class ExampleImagesDownloadHandler:
|
||||
"""HTTP adapters for download-related example image endpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
download_use_case: DownloadExampleImagesUseCase,
|
||||
download_manager,
|
||||
) -> None:
|
||||
self._download_use_case = download_use_case
|
||||
self._download_manager = download_manager
|
||||
|
||||
async def download_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
payload = await request.json()
|
||||
result = await self._download_use_case.execute(payload)
|
||||
return web.json_response(result)
|
||||
except DownloadExampleImagesInProgressError as exc:
|
||||
response = {
|
||||
'success': False,
|
||||
'error': str(exc),
|
||||
'status': exc.progress,
|
||||
}
|
||||
return web.json_response(response, status=400)
|
||||
except DownloadExampleImagesConfigurationError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
except ExampleImagesDownloadError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
async def get_example_images_status(self, request: web.Request) -> web.StreamResponse:
|
||||
result = await self._download_manager.get_status(request)
|
||||
return web.json_response(result)
|
||||
|
||||
async def pause_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
result = await self._download_manager.pause_download(request)
|
||||
return web.json_response(result)
|
||||
except DownloadNotRunningError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
|
||||
async def resume_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
result = await self._download_manager.resume_download(request)
|
||||
return web.json_response(result)
|
||||
except DownloadNotRunningError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
|
||||
async def force_download_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
payload = await request.json()
|
||||
result = await self._download_manager.start_force_download(payload)
|
||||
return web.json_response(result)
|
||||
except DownloadInProgressError as exc:
|
||||
response = {
|
||||
'success': False,
|
||||
'error': str(exc),
|
||||
'status': exc.progress_snapshot,
|
||||
}
|
||||
return web.json_response(response, status=400)
|
||||
except DownloadConfigurationError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
except ExampleImagesDownloadError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
|
||||
class ExampleImagesManagementHandler:
|
||||
"""HTTP adapters for import/delete endpoints."""
|
||||
|
||||
def __init__(self, import_use_case: ImportExampleImagesUseCase, processor, cleanup_service) -> None:
|
||||
self._import_use_case = import_use_case
|
||||
self._processor = processor
|
||||
self._cleanup_service = cleanup_service
|
||||
|
||||
async def import_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
result = await self._import_use_case.execute(request)
|
||||
return web.json_response(result)
|
||||
except ImportExampleImagesValidationError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
except ExampleImagesImportError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
async def delete_example_image(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._processor.delete_custom_image(request)
|
||||
|
||||
async def cleanup_example_image_folders(self, request: web.Request) -> web.StreamResponse:
|
||||
result = await self._cleanup_service.cleanup_example_image_folders()
|
||||
|
||||
if result.get('success') or result.get('partial_success'):
|
||||
return web.json_response(result, status=200)
|
||||
|
||||
error_code = result.get('error_code')
|
||||
status = 400 if error_code in {'path_not_configured', 'path_not_found'} else 500
|
||||
return web.json_response(result, status=status)
|
||||
|
||||
|
||||
class ExampleImagesFileHandler:
|
||||
"""HTTP adapters for filesystem-centric endpoints."""
|
||||
|
||||
def __init__(self, file_manager) -> None:
|
||||
self._file_manager = file_manager
|
||||
|
||||
async def open_example_images_folder(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._file_manager.open_folder(request)
|
||||
|
||||
async def get_example_image_files(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._file_manager.get_files(request)
|
||||
|
||||
async def has_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._file_manager.has_images(request)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExampleImagesHandlerSet:
|
||||
"""Aggregate of handlers exposed to the registrar."""
|
||||
|
||||
download: ExampleImagesDownloadHandler
|
||||
management: ExampleImagesManagementHandler
|
||||
files: ExampleImagesFileHandler
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
|
||||
"""Flatten handler methods into the registrar mapping."""
|
||||
|
||||
return {
|
||||
"download_example_images": self.download.download_example_images,
|
||||
"get_example_images_status": self.download.get_example_images_status,
|
||||
"pause_example_images": self.download.pause_example_images,
|
||||
"resume_example_images": self.download.resume_example_images,
|
||||
"force_download_example_images": self.download.force_download_example_images,
|
||||
"import_example_images": self.management.import_example_images,
|
||||
"delete_example_image": self.management.delete_example_image,
|
||||
"cleanup_example_image_folders": self.management.cleanup_example_image_folders,
|
||||
"open_example_images_folder": self.files.open_example_images_folder,
|
||||
"get_example_image_files": self.files.get_example_image_files,
|
||||
"has_example_images": self.files.has_example_images,
|
||||
}
|
||||
1020
py/routes/handlers/model_handlers.py
Normal file
1020
py/routes/handlers/model_handlers.py
Normal file
File diff suppressed because it is too large
Load Diff
725
py/routes/handlers/recipe_handlers.py
Normal file
725
py/routes/handlers/recipe_handlers.py
Normal file
@@ -0,0 +1,725 @@
|
||||
"""Dedicated handler objects for recipe-related routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ...config import config
|
||||
from ...services.server_i18n import server_i18n as default_server_i18n
|
||||
from ...services.settings_manager import SettingsManager
|
||||
from ...services.recipes import (
|
||||
RecipeAnalysisService,
|
||||
RecipeDownloadError,
|
||||
RecipeNotFoundError,
|
||||
RecipePersistenceService,
|
||||
RecipeSharingService,
|
||||
RecipeValidationError,
|
||||
)
|
||||
|
||||
Logger = logging.Logger
|
||||
EnsureDependenciesCallable = Callable[[], Awaitable[None]]
|
||||
RecipeScannerGetter = Callable[[], Any]
|
||||
CivitaiClientGetter = Callable[[], Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RecipeHandlerSet:
|
||||
"""Group of handlers providing recipe route implementations."""
|
||||
|
||||
page_view: "RecipePageView"
|
||||
listing: "RecipeListingHandler"
|
||||
query: "RecipeQueryHandler"
|
||||
management: "RecipeManagementHandler"
|
||||
analysis: "RecipeAnalysisHandler"
|
||||
sharing: "RecipeSharingHandler"
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
|
||||
"""Expose handler coroutines keyed by registrar handler names."""
|
||||
|
||||
return {
|
||||
"render_page": self.page_view.render_page,
|
||||
"list_recipes": self.listing.list_recipes,
|
||||
"get_recipe": self.listing.get_recipe,
|
||||
"analyze_uploaded_image": self.analysis.analyze_uploaded_image,
|
||||
"analyze_local_image": self.analysis.analyze_local_image,
|
||||
"save_recipe": self.management.save_recipe,
|
||||
"delete_recipe": self.management.delete_recipe,
|
||||
"get_top_tags": self.query.get_top_tags,
|
||||
"get_base_models": self.query.get_base_models,
|
||||
"share_recipe": self.sharing.share_recipe,
|
||||
"download_shared_recipe": self.sharing.download_shared_recipe,
|
||||
"get_recipe_syntax": self.query.get_recipe_syntax,
|
||||
"update_recipe": self.management.update_recipe,
|
||||
"reconnect_lora": self.management.reconnect_lora,
|
||||
"find_duplicates": self.query.find_duplicates,
|
||||
"bulk_delete": self.management.bulk_delete,
|
||||
"save_recipe_from_widget": self.management.save_recipe_from_widget,
|
||||
"get_recipes_for_lora": self.query.get_recipes_for_lora,
|
||||
"scan_recipes": self.query.scan_recipes,
|
||||
}
|
||||
|
||||
|
||||
class RecipePageView:
|
||||
"""Render the recipe shell page."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
settings_service: SettingsManager,
|
||||
server_i18n=default_server_i18n,
|
||||
template_env,
|
||||
template_name: str,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
logger: Logger,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._settings = settings_service
|
||||
self._server_i18n = server_i18n
|
||||
self._template_env = template_env
|
||||
self._template_name = template_name
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
|
||||
async def render_page(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None: # pragma: no cover - defensive guard
|
||||
raise RuntimeError("Recipe scanner not available")
|
||||
|
||||
user_language = self._settings.get("language", "en")
|
||||
self._server_i18n.set_locale(user_language)
|
||||
|
||||
try:
|
||||
await recipe_scanner.get_cached_data(force_refresh=False)
|
||||
rendered = self._template_env.get_template(self._template_name).render(
|
||||
recipes=[],
|
||||
is_initializing=False,
|
||||
settings=self._settings,
|
||||
request=request,
|
||||
t=self._server_i18n.get_translation,
|
||||
)
|
||||
except Exception as cache_error: # pragma: no cover - logging path
|
||||
self._logger.error("Error loading recipe cache data: %s", cache_error)
|
||||
rendered = self._template_env.get_template(self._template_name).render(
|
||||
is_initializing=True,
|
||||
settings=self._settings,
|
||||
request=request,
|
||||
t=self._server_i18n.get_translation,
|
||||
)
|
||||
return web.Response(text=rendered, content_type="text/html")
|
||||
except Exception as exc: # pragma: no cover - logging path
|
||||
self._logger.error("Error handling recipes request: %s", exc, exc_info=True)
|
||||
return web.Response(text="Error loading recipes page", status=500)
|
||||
|
||||
|
||||
class RecipeListingHandler:
|
||||
"""Provide listing and detail APIs for recipes."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
logger: Logger,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
|
||||
async def list_recipes(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
page = int(request.query.get("page", "1"))
|
||||
page_size = int(request.query.get("page_size", "20"))
|
||||
sort_by = request.query.get("sort_by", "date")
|
||||
search = request.query.get("search")
|
||||
|
||||
search_options = {
|
||||
"title": request.query.get("search_title", "true").lower() == "true",
|
||||
"tags": request.query.get("search_tags", "true").lower() == "true",
|
||||
"lora_name": request.query.get("search_lora_name", "true").lower() == "true",
|
||||
"lora_model": request.query.get("search_lora_model", "true").lower() == "true",
|
||||
}
|
||||
|
||||
filters: Dict[str, list[str]] = {}
|
||||
base_models = request.query.get("base_models")
|
||||
if base_models:
|
||||
filters["base_model"] = base_models.split(",")
|
||||
|
||||
tags = request.query.get("tags")
|
||||
if tags:
|
||||
filters["tags"] = tags.split(",")
|
||||
|
||||
lora_hash = request.query.get("lora_hash")
|
||||
|
||||
result = await recipe_scanner.get_paginated_data(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
sort_by=sort_by,
|
||||
search=search,
|
||||
filters=filters,
|
||||
search_options=search_options,
|
||||
lora_hash=lora_hash,
|
||||
)
|
||||
|
||||
for item in result.get("items", []):
|
||||
file_path = item.get("file_path")
|
||||
if file_path:
|
||||
item["file_url"] = self.format_recipe_file_url(file_path)
|
||||
else:
|
||||
item.setdefault("file_url", "/loras_static/images/no-preview.png")
|
||||
item.setdefault("loras", [])
|
||||
item.setdefault("base_model", "")
|
||||
|
||||
return web.json_response(result)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving recipes: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def get_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
|
||||
|
||||
if not recipe:
|
||||
return web.json_response({"error": "Recipe not found"}, status=404)
|
||||
return web.json_response(recipe)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving recipe details: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
def format_recipe_file_url(self, file_path: str) -> str:
|
||||
try:
|
||||
recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, "/")
|
||||
normalized_path = file_path.replace(os.sep, "/")
|
||||
if normalized_path.startswith(recipes_dir):
|
||||
relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, "/")
|
||||
return f"/loras_static/root1/preview/{relative_path}"
|
||||
|
||||
file_name = os.path.basename(file_path)
|
||||
return f"/loras_static/root1/preview/recipes/{file_name}"
|
||||
except Exception as exc: # pragma: no cover - logging path
|
||||
self._logger.error("Error formatting recipe file URL: %s", exc, exc_info=True)
|
||||
return "/loras_static/images/no-preview.png"
|
||||
|
||||
|
||||
class RecipeQueryHandler:
|
||||
"""Provide read-only insights on recipe data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
format_recipe_file_url: Callable[[str], str],
|
||||
logger: Logger,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._format_recipe_file_url = format_recipe_file_url
|
||||
self._logger = logger
|
||||
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
limit = int(request.query.get("limit", "20"))
|
||||
cache = await recipe_scanner.get_cached_data()
|
||||
|
||||
tag_counts: Dict[str, int] = {}
|
||||
for recipe in getattr(cache, "raw_data", []):
|
||||
for tag in recipe.get("tags", []) or []:
|
||||
tag_counts[tag] = tag_counts.get(tag, 0) + 1
|
||||
|
||||
sorted_tags = [{"tag": tag, "count": count} for tag, count in tag_counts.items()]
|
||||
sorted_tags.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
return web.json_response({"success": True, "tags": sorted_tags[:limit]})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving top tags: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
cache = await recipe_scanner.get_cached_data()
|
||||
|
||||
base_model_counts: Dict[str, int] = {}
|
||||
for recipe in getattr(cache, "raw_data", []):
|
||||
base_model = recipe.get("base_model")
|
||||
if base_model:
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
|
||||
sorted_models = [{"name": model, "count": count} for model, count in base_model_counts.items()]
|
||||
sorted_models.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
return web.json_response({"success": True, "base_models": sorted_models})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving base models: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_recipes_for_lora(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
lora_hash = request.query.get("hash")
|
||||
if not lora_hash:
|
||||
return web.json_response({"success": False, "error": "Lora hash is required"}, status=400)
|
||||
|
||||
matching_recipes = await recipe_scanner.get_recipes_for_lora(lora_hash)
|
||||
return web.json_response({"success": True, "recipes": matching_recipes})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error getting recipes for Lora: %s", exc)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def scan_recipes(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
self._logger.info("Manually triggering recipe cache rebuild")
|
||||
await recipe_scanner.get_cached_data(force_refresh=True)
|
||||
return web.json_response({"success": True, "message": "Recipe cache refreshed successfully"})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error refreshing recipe cache: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def find_duplicates(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
duplicate_groups = await recipe_scanner.find_all_duplicate_recipes()
|
||||
response_data = []
|
||||
|
||||
for fingerprint, recipe_ids in duplicate_groups.items():
|
||||
if len(recipe_ids) <= 1:
|
||||
continue
|
||||
|
||||
recipes = []
|
||||
for recipe_id in recipe_ids:
|
||||
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
|
||||
if recipe:
|
||||
recipes.append(
|
||||
{
|
||||
"id": recipe.get("id"),
|
||||
"title": recipe.get("title"),
|
||||
"file_url": recipe.get("file_url")
|
||||
or self._format_recipe_file_url(recipe.get("file_path", "")),
|
||||
"modified": recipe.get("modified"),
|
||||
"created_date": recipe.get("created_date"),
|
||||
"lora_count": len(recipe.get("loras", [])),
|
||||
}
|
||||
)
|
||||
|
||||
if len(recipes) >= 2:
|
||||
recipes.sort(key=lambda entry: entry.get("modified", 0), reverse=True)
|
||||
response_data.append(
|
||||
{
|
||||
"fingerprint": fingerprint,
|
||||
"count": len(recipes),
|
||||
"recipes": recipes,
|
||||
}
|
||||
)
|
||||
|
||||
response_data.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
return web.json_response({"success": True, "duplicate_groups": response_data})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error finding duplicate recipes: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_recipe_syntax(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
try:
|
||||
syntax_parts = await recipe_scanner.get_recipe_syntax_tokens(recipe_id)
|
||||
except RecipeNotFoundError:
|
||||
return web.json_response({"error": "Recipe not found"}, status=404)
|
||||
|
||||
if not syntax_parts:
|
||||
return web.json_response({"error": "No LoRAs found in this recipe"}, status=400)
|
||||
|
||||
return web.json_response({"success": True, "syntax": " ".join(syntax_parts)})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error generating recipe syntax: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class RecipeManagementHandler:
|
||||
"""Handle create/update/delete style recipe operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
logger: Logger,
|
||||
persistence_service: RecipePersistenceService,
|
||||
analysis_service: RecipeAnalysisService,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
self._persistence_service = persistence_service
|
||||
self._analysis_service = analysis_service
|
||||
|
||||
async def save_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
reader = await request.multipart()
|
||||
payload = await self._parse_save_payload(reader)
|
||||
|
||||
result = await self._persistence_service.save_recipe(
|
||||
recipe_scanner=recipe_scanner,
|
||||
image_bytes=payload["image_bytes"],
|
||||
image_base64=payload["image_base64"],
|
||||
name=payload["name"],
|
||||
tags=payload["tags"],
|
||||
metadata=payload["metadata"],
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error saving recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def delete_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
result = await self._persistence_service.delete_recipe(
|
||||
recipe_scanner=recipe_scanner, recipe_id=recipe_id
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error deleting recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def update_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
data = await request.json()
|
||||
result = await self._persistence_service.update_recipe(
|
||||
recipe_scanner=recipe_scanner, recipe_id=recipe_id, updates=data
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error updating recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def reconnect_lora(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
data = await request.json()
|
||||
for field in ("recipe_id", "lora_index", "target_name"):
|
||||
if field not in data:
|
||||
raise RecipeValidationError(f"Missing required field: {field}")
|
||||
|
||||
result = await self._persistence_service.reconnect_lora(
|
||||
recipe_scanner=recipe_scanner,
|
||||
recipe_id=data["recipe_id"],
|
||||
lora_index=int(data["lora_index"]),
|
||||
target_name=data["target_name"],
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error reconnecting LoRA: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def bulk_delete(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
data = await request.json()
|
||||
recipe_ids = data.get("recipe_ids", [])
|
||||
result = await self._persistence_service.bulk_delete(
|
||||
recipe_scanner=recipe_scanner, recipe_ids=recipe_ids
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error performing bulk delete: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def save_recipe_from_widget(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
analysis = await self._analysis_service.analyze_widget_metadata(
|
||||
recipe_scanner=recipe_scanner
|
||||
)
|
||||
metadata = analysis.payload.get("metadata")
|
||||
image_bytes = analysis.payload.get("image_bytes")
|
||||
if not metadata or image_bytes is None:
|
||||
raise RecipeValidationError("Unable to extract metadata from widget")
|
||||
|
||||
result = await self._persistence_service.save_recipe_from_widget(
|
||||
recipe_scanner=recipe_scanner,
|
||||
metadata=metadata,
|
||||
image_bytes=image_bytes,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error saving recipe from widget: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def _parse_save_payload(self, reader) -> dict[str, Any]:
|
||||
image_bytes: Optional[bytes] = None
|
||||
image_base64: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tags: list[str] = []
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
if field.name == "image":
|
||||
image_chunks = bytearray()
|
||||
while True:
|
||||
chunk = await field.read_chunk()
|
||||
if not chunk:
|
||||
break
|
||||
image_chunks.extend(chunk)
|
||||
image_bytes = bytes(image_chunks)
|
||||
elif field.name == "image_base64":
|
||||
image_base64 = await field.text()
|
||||
elif field.name == "name":
|
||||
name = await field.text()
|
||||
elif field.name == "tags":
|
||||
tags_text = await field.text()
|
||||
try:
|
||||
parsed_tags = json.loads(tags_text)
|
||||
tags = parsed_tags if isinstance(parsed_tags, list) else []
|
||||
except Exception:
|
||||
tags = []
|
||||
elif field.name == "metadata":
|
||||
metadata_text = await field.text()
|
||||
try:
|
||||
metadata = json.loads(metadata_text)
|
||||
except Exception:
|
||||
metadata = {}
|
||||
|
||||
return {
|
||||
"image_bytes": image_bytes,
|
||||
"image_base64": image_base64,
|
||||
"name": name,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
|
||||
class RecipeAnalysisHandler:
|
||||
"""Analyze images to extract recipe metadata."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
civitai_client_getter: CivitaiClientGetter,
|
||||
logger: Logger,
|
||||
analysis_service: RecipeAnalysisService,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._civitai_client_getter = civitai_client_getter
|
||||
self._logger = logger
|
||||
self._analysis_service = analysis_service
|
||||
|
||||
async def analyze_uploaded_image(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
civitai_client = self._civitai_client_getter()
|
||||
if recipe_scanner is None or civitai_client is None:
|
||||
raise RuntimeError("Required services unavailable")
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
if "multipart/form-data" in content_type:
|
||||
reader = await request.multipart()
|
||||
field = await reader.next()
|
||||
if field is None or field.name != "image":
|
||||
raise RecipeValidationError("No image field found")
|
||||
image_chunks = bytearray()
|
||||
while True:
|
||||
chunk = await field.read_chunk()
|
||||
if not chunk:
|
||||
break
|
||||
image_chunks.extend(chunk)
|
||||
result = await self._analysis_service.analyze_uploaded_image(
|
||||
image_bytes=bytes(image_chunks),
|
||||
recipe_scanner=recipe_scanner,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
|
||||
if "application/json" in content_type:
|
||||
data = await request.json()
|
||||
result = await self._analysis_service.analyze_remote_image(
|
||||
url=data.get("url"),
|
||||
recipe_scanner=recipe_scanner,
|
||||
civitai_client=civitai_client,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
|
||||
raise RecipeValidationError("Unsupported content type")
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=400)
|
||||
except RecipeDownloadError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error analyzing recipe image: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=500)
|
||||
|
||||
async def analyze_local_image(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
data = await request.json()
|
||||
result = await self._analysis_service.analyze_local_image(
|
||||
file_path=data.get("path"),
|
||||
recipe_scanner=recipe_scanner,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error analyzing local image: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=500)
|
||||
|
||||
|
||||
class RecipeSharingHandler:
|
||||
"""Serve endpoints related to recipe sharing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
logger: Logger,
|
||||
sharing_service: RecipeSharingService,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
self._sharing_service = sharing_service
|
||||
|
||||
async def share_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
result = await self._sharing_service.share_recipe(
|
||||
recipe_scanner=recipe_scanner, recipe_id=recipe_id
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error sharing recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def download_shared_recipe(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
download_info = await self._sharing_service.prepare_download(
|
||||
recipe_scanner=recipe_scanner, recipe_id=recipe_id
|
||||
)
|
||||
return web.FileResponse(
|
||||
download_info.file_path,
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{download_info.download_filename}"'
|
||||
},
|
||||
)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error downloading shared recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
@@ -1,189 +1,245 @@
|
||||
import os
|
||||
from aiohttp import web
|
||||
import jinja2
|
||||
from typing import Dict
|
||||
import asyncio
|
||||
import logging
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
|
||||
from aiohttp import web
|
||||
from typing import Dict
|
||||
from server import PromptServer # type: ignore
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from .model_route_registrar import ModelRouteRegistrar
|
||||
from ..services.lora_service import LoraService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.utils import get_lora_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
|
||||
|
||||
class LoraRoutes:
|
||||
"""Route handlers for LoRA management endpoints"""
|
||||
class LoraRoutes(BaseModelRoutes):
|
||||
"""LoRA-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
# Initialize service references as None, will be set during async init
|
||||
self.scanner = None
|
||||
self.recipe_scanner = None
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
|
||||
async def init_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
"""Initialize LoRA routes with LoRA service"""
|
||||
super().__init__()
|
||||
self.template_name = "loras.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.service = LoraService(lora_scanner)
|
||||
|
||||
# Attach service dependencies
|
||||
self.attach_service(self.service)
|
||||
|
||||
def format_lora_data(self, lora: Dict) -> Dict:
|
||||
"""Format LoRA data for template rendering"""
|
||||
return {
|
||||
"model_name": lora["model_name"],
|
||||
"file_name": lora["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora["preview_url"]),
|
||||
"preview_nsfw_level": lora.get("preview_nsfw_level", 0),
|
||||
"base_model": lora["base_model"],
|
||||
"folder": lora["folder"],
|
||||
"sha256": lora["sha256"],
|
||||
"file_path": lora["file_path"].replace(os.sep, "/"),
|
||||
"size": lora["size"],
|
||||
"tags": lora["tags"],
|
||||
"modelDescription": lora["modelDescription"],
|
||||
"usage_tips": lora["usage_tips"],
|
||||
"notes": lora["notes"],
|
||||
"modified": lora["modified"],
|
||||
"from_civitai": lora.get("from_civitai", True),
|
||||
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
|
||||
}
|
||||
|
||||
def _filter_civitai_data(self, data: Dict) -> Dict:
|
||||
"""Filter relevant fields from CivitAI data"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
fields = [
|
||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||
"publishedAt", "trainedWords", "baseModel", "description",
|
||||
"model", "images"
|
||||
]
|
||||
return {k: data[k] for k in fields if k in data}
|
||||
|
||||
async def handle_loras_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras request"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Check if the LoraScanner is initializing
|
||||
# It's initializing if the cache object doesn't exist yet,
|
||||
# OR if the scanner explicitly says it's initializing (background task running).
|
||||
is_initializing = (
|
||||
self.scanner._cache is None or
|
||||
(hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing)
|
||||
)
|
||||
|
||||
if is_initializing:
|
||||
# If still initializing, return loading page
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
|
||||
logger.info("Loras page is initializing, returning loading page")
|
||||
else:
|
||||
# Normal flow - get data from initialized cache
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=cache.folders,
|
||||
is_initializing=False,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading cache data: {cache_error}")
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling loras request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading loras page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def handle_recipes_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras/recipes request"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Skip initialization check and directly try to get cached data
|
||||
try:
|
||||
# Recipe scanner will initialize cache if needed
|
||||
await self.recipe_scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('recipes.html')
|
||||
rendered = template.render(
|
||||
recipes=[], # Frontend will load recipes via API
|
||||
is_initializing=False,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading recipe cache data: {cache_error}")
|
||||
# Still keep error handling - show initializing page on error
|
||||
template = self.template_env.get_template('recipes.html')
|
||||
rendered = template.render(
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Recipe cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling recipes request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading recipes page",
|
||||
status=500
|
||||
)
|
||||
|
||||
def _format_recipe_file_url(self, file_path: str) -> str:
|
||||
"""Format file path for recipe image as a URL - same as in recipe_routes"""
|
||||
try:
|
||||
# Return the file URL directly for the first lora root's preview
|
||||
recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/')
|
||||
if file_path.replace(os.sep, '/').startswith(recipes_dir):
|
||||
relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/')
|
||||
return f"/loras_static/root1/preview/{relative_path}"
|
||||
|
||||
# If not in recipes dir, try to create a valid URL from the file path
|
||||
file_name = os.path.basename(file_path)
|
||||
return f"/loras_static/root1/preview/recipes/{file_name}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting recipe file URL: {e}", exc_info=True)
|
||||
return '/loras_static/images/no-preview.png' # Return default image on error
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Register routes with the application"""
|
||||
# Add an app startup handler to initialize services
|
||||
app.on_startup.append(self._on_startup)
|
||||
"""Setup LoRA routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Setup common routes with 'loras' prefix (includes page route)
|
||||
super().setup_routes(app, 'loras')
|
||||
|
||||
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
|
||||
"""Setup LoRA-specific routes"""
|
||||
# LoRA-specific query routes
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/letter-counts', prefix, self.get_letter_counts)
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/get-trigger-words', prefix, self.get_lora_trigger_words)
|
||||
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/usage-tips-by-path', prefix, self.get_lora_usage_tips_by_path)
|
||||
|
||||
# ComfyUI integration
|
||||
registrar.add_prefixed_route('POST', '/api/lm/{prefix}/get_trigger_words', prefix, self.get_trigger_words)
|
||||
|
||||
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||
"""Parse LoRA-specific parameters"""
|
||||
params = {}
|
||||
|
||||
# Register routes
|
||||
app.router.add_get('/loras', self.handle_loras_page)
|
||||
app.router.add_get('/loras/recipes', self.handle_recipes_page)
|
||||
# LoRA-specific parameters
|
||||
if 'first_letter' in request.query:
|
||||
params['first_letter'] = request.query.get('first_letter')
|
||||
|
||||
async def _on_startup(self, app):
|
||||
"""Initialize services when the app starts"""
|
||||
await self.init_services()
|
||||
# Handle fuzzy search parameter name variation
|
||||
if request.query.get('fuzzy') == 'true':
|
||||
params['fuzzy_search'] = True
|
||||
|
||||
# Handle additional filter parameters for LoRAs
|
||||
if 'lora_hash' in request.query:
|
||||
if not params.get('hash_filters'):
|
||||
params['hash_filters'] = {}
|
||||
params['hash_filters']['single_hash'] = request.query['lora_hash'].lower()
|
||||
elif 'lora_hashes' in request.query:
|
||||
if not params.get('hash_filters'):
|
||||
params['hash_filters'] = {}
|
||||
params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')]
|
||||
|
||||
return params
|
||||
|
||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||
"""Validate CivitAI model type for LoRA"""
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
return model_type.lower() in VALID_LORA_TYPES
|
||||
|
||||
def _get_expected_model_types(self) -> str:
|
||||
"""Get expected model types string for error messages"""
|
||||
return "LORA, LoCon, or DORA"
|
||||
|
||||
# LoRA-specific route handlers
|
||||
async def get_letter_counts(self, request: web.Request) -> web.Response:
|
||||
"""Get count of LoRAs for each letter of the alphabet"""
|
||||
try:
|
||||
letter_counts = await self.service.get_letter_counts()
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'letter_counts': letter_counts
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting letter counts: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_notes(self, request: web.Request) -> web.Response:
|
||||
"""Get notes for a specific LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
notes = await self.service.get_lora_notes(lora_name)
|
||||
if notes is not None:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'notes': notes
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'LoRA not found in cache'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora notes: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for a specific LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
trigger_words = await self.service.get_lora_trigger_words(lora_name)
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'trigger_words': trigger_words
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora trigger words: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_usage_tips_by_path(self, request: web.Request) -> web.Response:
|
||||
"""Get usage tips for a LoRA by its relative path"""
|
||||
try:
|
||||
relative_path = request.query.get('relative_path')
|
||||
if not relative_path:
|
||||
return web.Response(text='Relative path is required', status=400)
|
||||
|
||||
usage_tips = await self.service.get_lora_usage_tips_by_relative_path(relative_path)
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'usage_tips': usage_tips or ''
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora usage tips by path: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the static preview URL for a LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
preview_url = await self.service.get_lora_preview_url(lora_name)
|
||||
if preview_url:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'preview_url': preview_url
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No preview URL found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the Civitai URL for a LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
result = await self.service.get_lora_civitai_url(lora_name)
|
||||
if result['civitai_url']:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
**result
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No Civitai data found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for specified LoRA models"""
|
||||
try:
|
||||
json_data = await request.json()
|
||||
lora_names = json_data.get("lora_names", [])
|
||||
node_ids = json_data.get("node_ids", [])
|
||||
|
||||
all_trigger_words = []
|
||||
for lora_name in lora_names:
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Format the trigger words
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Send update to all connected trigger word toggle nodes
|
||||
for node_id in node_ids:
|
||||
PromptServer.instance.send_sync("trigger_word_update", {
|
||||
"id": node_id,
|
||||
"message": trigger_words_text
|
||||
})
|
||||
|
||||
return web.json_response({"success": True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting trigger words: {e}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
@@ -1,100 +1,223 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import asyncio
|
||||
import subprocess
|
||||
import re
|
||||
from server import PromptServer # type: ignore
|
||||
from aiohttp import web
|
||||
from ..services.settings_manager import settings
|
||||
from ..utils.usage_stats import UsageStats
|
||||
from ..utils.lora_metadata import extract_trained_words
|
||||
from ..config import config
|
||||
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
|
||||
import re
|
||||
|
||||
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..services.metadata_service import get_metadata_archive_manager, update_metadata_providers, get_metadata_provider
|
||||
from ..services.websocket_manager import ws_manager
|
||||
from ..services.downloader import get_downloader
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Download status tracking
|
||||
download_task = None
|
||||
is_downloading = False
|
||||
download_progress = {
|
||||
'total': 0,
|
||||
'completed': 0,
|
||||
'current_model': '',
|
||||
'status': 'idle', # idle, running, paused, completed, error
|
||||
'errors': [],
|
||||
'last_error': None,
|
||||
'start_time': None,
|
||||
'end_time': None,
|
||||
'processed_models': set(), # Track models that have been processed
|
||||
'refreshed_models': set() # Track models that had metadata refreshed
|
||||
}
|
||||
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
|
||||
# Node registry for tracking active workflow nodes
|
||||
class NodeRegistry:
|
||||
"""Thread-safe registry for tracking Lora nodes in active workflows"""
|
||||
|
||||
def __init__(self):
|
||||
self._lock = threading.RLock()
|
||||
self._nodes = {} # node_id -> node_info
|
||||
self._registry_updated = threading.Event()
|
||||
|
||||
def register_nodes(self, nodes):
|
||||
"""Register multiple nodes at once, replacing existing registry"""
|
||||
with self._lock:
|
||||
# Clear existing registry
|
||||
self._nodes.clear()
|
||||
|
||||
# Register all new nodes
|
||||
for node in nodes:
|
||||
node_id = node['node_id']
|
||||
node_type = node.get('type', '')
|
||||
|
||||
# Convert node type name to integer
|
||||
type_id = NODE_TYPES.get(node_type, 0) # 0 for unknown types
|
||||
|
||||
# Handle null bgcolor with default color
|
||||
bgcolor = node.get('bgcolor')
|
||||
if bgcolor is None:
|
||||
bgcolor = DEFAULT_NODE_COLOR
|
||||
|
||||
self._nodes[node_id] = {
|
||||
'id': node_id,
|
||||
'bgcolor': bgcolor,
|
||||
'title': node.get('title'),
|
||||
'type': type_id,
|
||||
'type_name': node_type
|
||||
}
|
||||
|
||||
logger.debug(f"Registered {len(nodes)} nodes in registry")
|
||||
|
||||
# Signal that registry has been updated
|
||||
self._registry_updated.set()
|
||||
|
||||
def get_registry(self):
|
||||
"""Get current registry information"""
|
||||
with self._lock:
|
||||
return {
|
||||
'nodes': dict(self._nodes), # Return a copy
|
||||
'node_count': len(self._nodes)
|
||||
}
|
||||
|
||||
def clear_registry(self):
|
||||
"""Clear the entire registry"""
|
||||
with self._lock:
|
||||
self._nodes.clear()
|
||||
logger.info("Node registry cleared")
|
||||
|
||||
def wait_for_update(self, timeout=1.0):
|
||||
"""Wait for registry update with timeout"""
|
||||
self._registry_updated.clear()
|
||||
return self._registry_updated.wait(timeout)
|
||||
|
||||
# Global registry instance
|
||||
node_registry = NodeRegistry()
|
||||
|
||||
class MiscRoutes:
|
||||
"""Miscellaneous routes for various utility functions"""
|
||||
|
||||
@staticmethod
|
||||
def is_dedicated_example_images_folder(folder_path):
|
||||
"""
|
||||
Check if a folder is a dedicated example images folder.
|
||||
|
||||
A dedicated folder should either be:
|
||||
1. Empty
|
||||
2. Only contain .download_progress.json file and/or folders with valid SHA256 hash names (64 hex characters)
|
||||
|
||||
Args:
|
||||
folder_path (str): Path to the folder to check
|
||||
|
||||
Returns:
|
||||
bool: True if the folder is dedicated, False otherwise
|
||||
"""
|
||||
try:
|
||||
if not os.path.exists(folder_path) or not os.path.isdir(folder_path):
|
||||
return False
|
||||
|
||||
items = os.listdir(folder_path)
|
||||
|
||||
# Empty folder is considered dedicated
|
||||
if not items:
|
||||
return True
|
||||
|
||||
# Check each item in the folder
|
||||
for item in items:
|
||||
item_path = os.path.join(folder_path, item)
|
||||
|
||||
# Allow .download_progress.json file
|
||||
if item == '.download_progress.json' and os.path.isfile(item_path):
|
||||
continue
|
||||
|
||||
# Allow folders with valid SHA256 hash names (64 hex characters)
|
||||
if os.path.isdir(item_path):
|
||||
# Check if the folder name is a valid SHA256 hash
|
||||
if re.match(r'^[a-fA-F0-9]{64}$', item):
|
||||
continue
|
||||
|
||||
# If we encounter anything else, it's not a dedicated folder
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking if folder is dedicated: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def setup_routes(app):
|
||||
"""Register miscellaneous routes"""
|
||||
app.router.add_post('/api/settings', MiscRoutes.update_settings)
|
||||
|
||||
# Add new route for clearing cache
|
||||
app.router.add_post('/api/clear-cache', MiscRoutes.clear_cache)
|
||||
app.router.add_get('/api/lm/settings', MiscRoutes.get_settings)
|
||||
app.router.add_post('/api/lm/settings', MiscRoutes.update_settings)
|
||||
|
||||
app.router.add_get('/api/lm/health-check', lambda request: web.json_response({'status': 'ok'}))
|
||||
|
||||
app.router.add_post('/api/lm/open-file-location', MiscRoutes.open_file_location)
|
||||
|
||||
# Usage stats routes
|
||||
app.router.add_post('/api/update-usage-stats', MiscRoutes.update_usage_stats)
|
||||
app.router.add_get('/api/get-usage-stats', MiscRoutes.get_usage_stats)
|
||||
app.router.add_post('/api/lm/update-usage-stats', MiscRoutes.update_usage_stats)
|
||||
app.router.add_get('/api/lm/get-usage-stats', MiscRoutes.get_usage_stats)
|
||||
|
||||
# Lora code update endpoint
|
||||
app.router.add_post('/api/update-lora-code', MiscRoutes.update_lora_code)
|
||||
app.router.add_post('/api/lm/update-lora-code', MiscRoutes.update_lora_code)
|
||||
|
||||
# Add new route for getting trained words
|
||||
app.router.add_get('/api/trained-words', MiscRoutes.get_trained_words)
|
||||
app.router.add_get('/api/lm/trained-words', MiscRoutes.get_trained_words)
|
||||
|
||||
# Add new route for getting model example files
|
||||
app.router.add_get('/api/model-example-files', MiscRoutes.get_model_example_files)
|
||||
app.router.add_get('/api/lm/model-example-files', MiscRoutes.get_model_example_files)
|
||||
|
||||
# Node registry endpoints
|
||||
app.router.add_post('/api/lm/register-nodes', MiscRoutes.register_nodes)
|
||||
app.router.add_get('/api/lm/get-registry', MiscRoutes.get_registry)
|
||||
|
||||
# Add new route for checking if a model exists in the library
|
||||
app.router.add_get('/api/lm/check-model-exists', MiscRoutes.check_model_exists)
|
||||
|
||||
# Add routes for metadata archive database management
|
||||
app.router.add_post('/api/lm/download-metadata-archive', MiscRoutes.download_metadata_archive)
|
||||
app.router.add_post('/api/lm/remove-metadata-archive', MiscRoutes.remove_metadata_archive)
|
||||
app.router.add_get('/api/lm/metadata-archive-status', MiscRoutes.get_metadata_archive_status)
|
||||
|
||||
# Add route for checking model versions in library
|
||||
app.router.add_get('/api/lm/model-versions-status', MiscRoutes.get_model_versions_status)
|
||||
|
||||
@staticmethod
|
||||
async def clear_cache(request):
|
||||
"""Clear all cache files from the cache folder"""
|
||||
async def get_settings(request):
|
||||
"""Get application settings that should be synced to frontend"""
|
||||
try:
|
||||
# Get the cache folder path (relative to project directory)
|
||||
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
cache_folder = os.path.join(project_dir, 'cache')
|
||||
# Define keys that should be synced from backend to frontend
|
||||
sync_keys = [
|
||||
'civitai_api_key',
|
||||
'default_lora_root',
|
||||
'default_checkpoint_root',
|
||||
'default_embedding_root',
|
||||
'base_model_path_mappings',
|
||||
'download_path_templates',
|
||||
'enable_metadata_archive_db',
|
||||
'language',
|
||||
'proxy_enabled',
|
||||
'proxy_type',
|
||||
'proxy_host',
|
||||
'proxy_port',
|
||||
'proxy_username',
|
||||
'proxy_password',
|
||||
'example_images_path',
|
||||
'optimize_example_images',
|
||||
'auto_download_example_images',
|
||||
'blur_mature_content',
|
||||
'autoplay_on_hover',
|
||||
'display_density',
|
||||
'card_info_display',
|
||||
'include_trigger_words',
|
||||
'show_only_sfw',
|
||||
'compact_mode'
|
||||
]
|
||||
|
||||
# Check if cache folder exists
|
||||
if not os.path.exists(cache_folder):
|
||||
logger.info("Cache folder does not exist, nothing to clear")
|
||||
return web.json_response({'success': True, 'message': 'No cache folder found'})
|
||||
|
||||
# Get list of cache files before deleting for reporting
|
||||
cache_files = [f for f in os.listdir(cache_folder) if os.path.isfile(os.path.join(cache_folder, f))]
|
||||
deleted_files = []
|
||||
|
||||
# Delete each .msgpack file in the cache folder
|
||||
for filename in cache_files:
|
||||
if filename.endswith('.msgpack'):
|
||||
file_path = os.path.join(cache_folder, filename)
|
||||
try:
|
||||
os.remove(file_path)
|
||||
deleted_files.append(filename)
|
||||
logger.info(f"Deleted cache file: {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete {filename}: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f"Failed to delete {filename}: {str(e)}"
|
||||
}, status=500)
|
||||
|
||||
# If we want to completely remove the cache folder too (optional,
|
||||
# but we'll keep the folder structure in place here)
|
||||
# shutil.rmtree(cache_folder)
|
||||
# Build response with only the keys that should be synced
|
||||
response_data = {}
|
||||
for key in sync_keys:
|
||||
value = settings.get(key)
|
||||
if value is not None:
|
||||
response_data[key] = value
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f"Successfully cleared {len(deleted_files)} cache files",
|
||||
'deleted_files': deleted_files
|
||||
'settings': response_data
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing cache files: {e}", exc_info=True)
|
||||
logger.error(f"Error getting settings: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
@@ -105,10 +228,15 @@ class MiscRoutes:
|
||||
"""Update application settings"""
|
||||
try:
|
||||
data = await request.json()
|
||||
proxy_keys = {'proxy_enabled', 'proxy_host', 'proxy_port', 'proxy_username', 'proxy_password', 'proxy_type'}
|
||||
proxy_changed = False
|
||||
|
||||
# Validate and update settings
|
||||
for key, value in data.items():
|
||||
# Special handling for example_images_path - verify path exists
|
||||
if value == settings.get(key):
|
||||
# No change, skip
|
||||
continue
|
||||
# Special handling for example_images_path - verify path exists and is dedicated
|
||||
if key == 'example_images_path' and value:
|
||||
if not os.path.exists(value):
|
||||
return web.json_response({
|
||||
@@ -116,14 +244,35 @@ class MiscRoutes:
|
||||
'error': f"Path does not exist: {value}"
|
||||
})
|
||||
|
||||
# Check if folder is dedicated for example images
|
||||
if not MiscRoutes.is_dedicated_example_images_folder(value):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': "Please set a dedicated folder for example images."
|
||||
})
|
||||
|
||||
# Path changed - server restart required for new path to take effect
|
||||
old_path = settings.get('example_images_path')
|
||||
if old_path != value:
|
||||
logger.info(f"Example images path changed to {value} - server restart required")
|
||||
|
||||
# Save to settings
|
||||
settings.set(key, value)
|
||||
|
||||
# Handle deletion for proxy credentials
|
||||
if value == '__DELETE__' and key in ('proxy_username', 'proxy_password'):
|
||||
settings.delete(key)
|
||||
else:
|
||||
# Save to settings
|
||||
settings.set(key, value)
|
||||
|
||||
if key == 'enable_metadata_archive_db':
|
||||
await update_metadata_providers()
|
||||
|
||||
if key in proxy_keys:
|
||||
proxy_changed = True
|
||||
|
||||
if proxy_changed:
|
||||
downloader = await get_downloader()
|
||||
await downloader.refresh_session()
|
||||
|
||||
return web.json_response({'success': True})
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating settings: {e}", exc_info=True)
|
||||
@@ -403,3 +552,507 @@ class MiscRoutes:
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def register_nodes(request):
|
||||
"""
|
||||
Register multiple Lora nodes at once
|
||||
|
||||
Expects a JSON body with:
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"node_id": 123,
|
||||
"bgcolor": "#535",
|
||||
"title": "Lora Loader (LoraManager)"
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
data = await request.json()
|
||||
|
||||
# Validate required fields
|
||||
nodes = data.get('nodes', [])
|
||||
|
||||
if not isinstance(nodes, list):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'nodes must be a list'
|
||||
}, status=400)
|
||||
|
||||
# Validate each node
|
||||
for i, node in enumerate(nodes):
|
||||
if not isinstance(node, dict):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f'Node {i} must be an object'
|
||||
}, status=400)
|
||||
|
||||
node_id = node.get('node_id')
|
||||
if node_id is None:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f'Node {i} missing node_id parameter'
|
||||
}, status=400)
|
||||
|
||||
# Validate node_id is an integer
|
||||
try:
|
||||
node['node_id'] = int(node_id)
|
||||
except (ValueError, TypeError):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f'Node {i} node_id must be an integer'
|
||||
}, status=400)
|
||||
|
||||
# Register all nodes
|
||||
node_registry.register_nodes(nodes)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f'{len(nodes)} nodes registered successfully'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register nodes: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def get_registry(request):
|
||||
"""Get current node registry information by refreshing from frontend"""
|
||||
try:
|
||||
# Check if running in standalone mode
|
||||
if standalone_mode:
|
||||
logger.warning("Registry refresh not available in standalone mode")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Standalone Mode Active',
|
||||
'message': 'Cannot interact with ComfyUI in standalone mode.'
|
||||
}, status=503)
|
||||
|
||||
# Send message to frontend to refresh registry
|
||||
try:
|
||||
PromptServer.instance.send_sync("lora_registry_refresh", {})
|
||||
logger.debug("Sent registry refresh request to frontend")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send registry refresh message: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Communication Error',
|
||||
'message': f'Failed to communicate with ComfyUI frontend: {str(e)}'
|
||||
}, status=500)
|
||||
|
||||
# Wait for registry update with timeout
|
||||
def wait_for_registry():
|
||||
return node_registry.wait_for_update(timeout=1.0)
|
||||
|
||||
# Run the wait in a thread to avoid blocking the event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
registry_updated = await loop.run_in_executor(None, wait_for_registry)
|
||||
|
||||
if not registry_updated:
|
||||
logger.warning("Registry refresh timeout after 1 second")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Timeout Error',
|
||||
'message': 'Registry refresh timeout - ComfyUI frontend may not be responsive'
|
||||
}, status=408)
|
||||
|
||||
# Get updated registry
|
||||
registry_info = node_registry.get_registry()
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': registry_info
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get registry: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Internal Error',
|
||||
'message': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def check_model_exists(request):
|
||||
"""
|
||||
Check if a model with specified modelId and optionally modelVersionId exists in the library
|
||||
|
||||
Expects query parameters:
|
||||
- modelId: int - Civitai model ID (required)
|
||||
- modelVersionId: int - Civitai model version ID (optional)
|
||||
|
||||
Returns:
|
||||
- If modelVersionId is provided: JSON with a boolean 'exists' field
|
||||
- If modelVersionId is not provided: JSON with a list of modelVersionIds that exist in the library
|
||||
"""
|
||||
try:
|
||||
# Get the modelId and modelVersionId from query parameters
|
||||
model_id_str = request.query.get('modelId')
|
||||
model_version_id_str = request.query.get('modelVersionId')
|
||||
|
||||
# Validate modelId parameter (required)
|
||||
if not model_id_str:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing required parameter: modelId'
|
||||
}, status=400)
|
||||
|
||||
try:
|
||||
# Convert modelId to integer
|
||||
model_id = int(model_id_str)
|
||||
except ValueError:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Parameter modelId must be an integer'
|
||||
}, status=400)
|
||||
|
||||
# Get all scanners
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
# If modelVersionId is provided, check for specific version
|
||||
if model_version_id_str:
|
||||
try:
|
||||
model_version_id = int(model_version_id_str)
|
||||
except ValueError:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Parameter modelVersionId must be an integer'
|
||||
}, status=400)
|
||||
|
||||
# Check lora scanner first
|
||||
exists = False
|
||||
model_type = None
|
||||
|
||||
if await lora_scanner.check_model_version_exists(model_version_id):
|
||||
exists = True
|
||||
model_type = 'lora'
|
||||
elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_version_id):
|
||||
exists = True
|
||||
model_type = 'checkpoint'
|
||||
elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_version_id):
|
||||
exists = True
|
||||
model_type = 'embedding'
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'exists': exists,
|
||||
'modelType': model_type if exists else None
|
||||
})
|
||||
|
||||
# If modelVersionId is not provided, return all version IDs for the model
|
||||
else:
|
||||
lora_versions = await lora_scanner.get_model_versions_by_id(model_id)
|
||||
checkpoint_versions = []
|
||||
embedding_versions = []
|
||||
|
||||
# 优先lora,其次checkpoint,最后embedding
|
||||
if not lora_versions:
|
||||
checkpoint_versions = await checkpoint_scanner.get_model_versions_by_id(model_id)
|
||||
if not lora_versions and not checkpoint_versions:
|
||||
embedding_versions = await embedding_scanner.get_model_versions_by_id(model_id)
|
||||
|
||||
model_type = None
|
||||
versions = []
|
||||
|
||||
if lora_versions:
|
||||
model_type = 'lora'
|
||||
versions = lora_versions
|
||||
elif checkpoint_versions:
|
||||
model_type = 'checkpoint'
|
||||
versions = checkpoint_versions
|
||||
elif embedding_versions:
|
||||
model_type = 'embedding'
|
||||
versions = embedding_versions
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'modelId': model_id,
|
||||
'modelType': model_type,
|
||||
'versions': versions
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check model existence: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def download_metadata_archive(request):
|
||||
"""Download and extract the metadata archive database"""
|
||||
try:
|
||||
archive_manager = await get_metadata_archive_manager()
|
||||
|
||||
# Get the download_id from query parameters if provided
|
||||
download_id = request.query.get('download_id')
|
||||
|
||||
# Progress callback to send updates via WebSocket
|
||||
def progress_callback(stage, message):
|
||||
data = {
|
||||
'stage': stage,
|
||||
'message': message,
|
||||
'type': 'metadata_archive_download'
|
||||
}
|
||||
|
||||
if download_id:
|
||||
# Send to specific download WebSocket if download_id is provided
|
||||
asyncio.create_task(ws_manager.broadcast_download_progress(download_id, data))
|
||||
else:
|
||||
# Fallback to general broadcast
|
||||
asyncio.create_task(ws_manager.broadcast(data))
|
||||
|
||||
# Download and extract in background
|
||||
success = await archive_manager.download_and_extract_database(progress_callback)
|
||||
|
||||
if success:
|
||||
# Update settings to enable metadata archive
|
||||
settings.set('enable_metadata_archive_db', True)
|
||||
|
||||
# Update metadata providers
|
||||
await update_metadata_providers()
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': 'Metadata archive database downloaded and extracted successfully'
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Failed to download and extract metadata archive database'
|
||||
}, status=500)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading metadata archive: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def remove_metadata_archive(request):
|
||||
"""Remove the metadata archive database"""
|
||||
try:
|
||||
archive_manager = await get_metadata_archive_manager()
|
||||
|
||||
success = await archive_manager.remove_database()
|
||||
|
||||
if success:
|
||||
# Update settings to disable metadata archive
|
||||
settings.set('enable_metadata_archive_db', False)
|
||||
|
||||
# Update metadata providers
|
||||
await update_metadata_providers()
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': 'Metadata archive database removed successfully'
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Failed to remove metadata archive database'
|
||||
}, status=500)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing metadata archive: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def get_metadata_archive_status(request):
|
||||
"""Get the status of metadata archive database"""
|
||||
try:
|
||||
archive_manager = await get_metadata_archive_manager()
|
||||
|
||||
is_available = archive_manager.is_database_available()
|
||||
is_enabled = settings.get('enable_metadata_archive_db', False)
|
||||
|
||||
db_size = 0
|
||||
if is_available:
|
||||
db_path = archive_manager.get_database_path()
|
||||
if db_path and os.path.exists(db_path):
|
||||
db_size = os.path.getsize(db_path)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'isAvailable': is_available,
|
||||
'isEnabled': is_enabled,
|
||||
'databaseSize': db_size,
|
||||
'databasePath': archive_manager.get_database_path() if is_available else None
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting metadata archive status: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def get_model_versions_status(request):
|
||||
"""
|
||||
Get all versions of a model from metadata provider and check their library status
|
||||
|
||||
Expects query parameters:
|
||||
- modelId: int - Civitai model ID (required)
|
||||
|
||||
Returns:
|
||||
- JSON with model type and versions list, each version includes 'inLibrary' flag
|
||||
"""
|
||||
try:
|
||||
# Get the modelId from query parameters
|
||||
model_id_str = request.query.get('modelId')
|
||||
|
||||
# Validate modelId parameter (required)
|
||||
if not model_id_str:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing required parameter: modelId'
|
||||
}, status=400)
|
||||
|
||||
try:
|
||||
# Convert modelId to integer
|
||||
model_id = int(model_id_str)
|
||||
except ValueError:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Parameter modelId must be an integer'
|
||||
}, status=400)
|
||||
|
||||
# Get metadata provider
|
||||
metadata_provider = await get_metadata_provider()
|
||||
if not metadata_provider:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Metadata provider not available'
|
||||
}, status=503)
|
||||
|
||||
# Get model versions from metadata provider
|
||||
response = await metadata_provider.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Model not found'
|
||||
}, status=404)
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_name = response.get('name', '')
|
||||
model_type = response.get('type', '').lower()
|
||||
|
||||
# Determine scanner based on model type
|
||||
scanner = None
|
||||
normalized_type = None
|
||||
|
||||
if model_type in ['lora', 'locon', 'dora']:
|
||||
scanner = await ServiceRegistry.get_lora_scanner()
|
||||
normalized_type = 'lora'
|
||||
elif model_type == 'checkpoint':
|
||||
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
normalized_type = 'checkpoint'
|
||||
elif model_type == 'textualinversion':
|
||||
scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
normalized_type = 'embedding'
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f'Model type "{model_type}" is not supported'
|
||||
}, status=400)
|
||||
|
||||
if not scanner:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f'Scanner for type "{normalized_type}" is not available'
|
||||
}, status=503)
|
||||
|
||||
# Get local versions from scanner
|
||||
local_versions = await scanner.get_model_versions_by_id(model_id)
|
||||
local_version_ids = set(version['versionId'] for version in local_versions)
|
||||
|
||||
# Add inLibrary flag to each version
|
||||
enriched_versions = []
|
||||
for version in versions:
|
||||
version_id = version.get('id')
|
||||
enriched_version = {
|
||||
'id': version_id,
|
||||
'name': version.get('name', ''),
|
||||
'thumbnailUrl': version.get('images')[0]['url'] if version.get('images') else None,
|
||||
'inLibrary': version_id in local_version_ids
|
||||
}
|
||||
enriched_versions.append(enriched_version)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'modelId': model_id,
|
||||
'modelName': model_name,
|
||||
'modelType': model_type,
|
||||
'versions': enriched_versions
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get model versions status: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def open_file_location(request):
|
||||
"""
|
||||
Open the folder containing the specified file and select the file in the file explorer.
|
||||
|
||||
Expects a JSON request body with:
|
||||
{
|
||||
"file_path": "absolute/path/to/file"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
|
||||
if not file_path:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing file_path parameter'
|
||||
}, status=400)
|
||||
|
||||
file_path = os.path.abspath(file_path)
|
||||
|
||||
if not os.path.isfile(file_path):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'File does not exist'
|
||||
}, status=404)
|
||||
|
||||
# Open the folder and select the file
|
||||
if os.name == 'nt': # Windows
|
||||
# explorer /select,"C:\path\to\file"
|
||||
subprocess.Popen(['explorer', '/select,', file_path])
|
||||
elif os.name == 'posix':
|
||||
if sys.platform == 'darwin': # macOS
|
||||
subprocess.Popen(['open', '-R', file_path])
|
||||
else: # Linux (selecting file is not standard, just open folder)
|
||||
folder = os.path.dirname(file_path)
|
||||
subprocess.Popen(['xdg-open', folder])
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f'Opened folder and selected file: {file_path}'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to open file location: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
99
py/routes/model_route_registrar.py
Normal file
99
py/routes/model_route_registrar.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Route registrar for model endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RouteDefinition:
|
||||
"""Declarative definition for a HTTP route."""
|
||||
|
||||
method: str
|
||||
path_template: str
|
||||
handler_name: str
|
||||
|
||||
def build_path(self, prefix: str) -> str:
|
||||
return self.path_template.replace("{prefix}", prefix)
|
||||
|
||||
|
||||
COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/list", "get_models"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/delete", "delete_model"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/exclude", "exclude_model"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/fetch-civitai", "fetch_civitai"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/fetch-all-civitai", "fetch_all_civitai"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/relink-civitai", "relink_civitai"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/replace-preview", "replace_preview"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/save-metadata", "save_metadata"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/add-tags", "add_tags"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/rename", "rename_model"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/bulk-delete", "bulk_delete_models"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/verify-duplicates", "verify_duplicates"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/move_model", "move_model"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/move_models_bulk", "move_models_bulk"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize", "auto_organize_models"),
|
||||
RouteDefinition("POST", "/api/lm/{prefix}/auto-organize", "auto_organize_models"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/scan", "scan_models"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/folder-tree", "get_folder_tree"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/unified-folder-tree", "get_unified_folder_tree"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/find-duplicates", "find_duplicate_models"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/find-filename-conflicts", "find_filename_conflicts"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/get-notes", "get_model_notes"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/preview-url", "get_model_preview_url"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai-url", "get_model_civitai_url"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/metadata", "get_model_metadata"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/model-description", "get_model_description"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/relative-paths", "get_relative_paths"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
|
||||
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
|
||||
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
|
||||
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),
|
||||
RouteDefinition("GET", "/api/lm/download-progress/{download_id}", "get_download_progress"),
|
||||
RouteDefinition("GET", "/{prefix}", "handle_models_page"),
|
||||
)
|
||||
|
||||
|
||||
class ModelRouteRegistrar:
|
||||
"""Bind declarative definitions to an aiohttp router."""
|
||||
|
||||
_METHOD_MAP = {
|
||||
"GET": "add_get",
|
||||
"POST": "add_post",
|
||||
"PUT": "add_put",
|
||||
"DELETE": "add_delete",
|
||||
}
|
||||
|
||||
def __init__(self, app: web.Application) -> None:
|
||||
self._app = app
|
||||
|
||||
def register_common_routes(
|
||||
self,
|
||||
prefix: str,
|
||||
handler_lookup: Mapping[str, Callable[[web.Request], object]],
|
||||
*,
|
||||
definitions: Iterable[RouteDefinition] = COMMON_ROUTE_DEFINITIONS,
|
||||
) -> None:
|
||||
for definition in definitions:
|
||||
self._bind_route(definition.method, definition.build_path(prefix), handler_lookup[definition.handler_name])
|
||||
|
||||
def add_route(self, method: str, path: str, handler: Callable) -> None:
|
||||
self._bind_route(method, path, handler)
|
||||
|
||||
def add_prefixed_route(self, method: str, path_template: str, prefix: str, handler: Callable) -> None:
|
||||
self._bind_route(method, path_template.replace("{prefix}", prefix), handler)
|
||||
|
||||
def _bind_route(self, method: str, path: str, handler: Callable) -> None:
|
||||
add_method_name = self._METHOD_MAP[method.upper()]
|
||||
add_method = getattr(self._app.router, add_method_name)
|
||||
add_method(path, handler)
|
||||
|
||||
64
py/routes/recipe_route_registrar.py
Normal file
64
py/routes/recipe_route_registrar.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Route registrar for recipe endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RouteDefinition:
|
||||
"""Declarative definition for a recipe HTTP route."""
|
||||
|
||||
method: str
|
||||
path: str
|
||||
handler_name: str
|
||||
|
||||
|
||||
ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/loras/recipes", "render_page"),
|
||||
RouteDefinition("GET", "/api/lm/recipes", "list_recipes"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"),
|
||||
RouteDefinition("DELETE", "/api/lm/recipe/{recipe_id}", "delete_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/top-tags", "get_top_tags"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/base-models", "get_base_models"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share", "share_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/syntax", "get_recipe_syntax"),
|
||||
RouteDefinition("PUT", "/api/lm/recipe/{recipe_id}/update", "update_recipe"),
|
||||
RouteDefinition("POST", "/api/lm/recipe/lora/reconnect", "reconnect_lora"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/find-duplicates", "find_duplicates"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/bulk-delete", "bulk_delete"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"),
|
||||
)
|
||||
|
||||
|
||||
class RecipeRouteRegistrar:
|
||||
"""Bind declarative recipe definitions to an aiohttp router."""
|
||||
|
||||
_METHOD_MAP = {
|
||||
"GET": "add_get",
|
||||
"POST": "add_post",
|
||||
"PUT": "add_put",
|
||||
"DELETE": "add_delete",
|
||||
}
|
||||
|
||||
def __init__(self, app: web.Application) -> None:
|
||||
self._app = app
|
||||
|
||||
def register_routes(self, handler_lookup: Mapping[str, Callable[[web.Request], object]]) -> None:
|
||||
for definition in ROUTE_DEFINITIONS:
|
||||
handler = handler_lookup[definition.handler_name]
|
||||
self._bind_route(definition.method, definition.path, handler)
|
||||
|
||||
def _bind_route(self, method: str, path: str, handler: Callable) -> None:
|
||||
add_method_name = self._METHOD_MAP[method.upper()]
|
||||
add_method = getattr(self._app.router, add_method_name)
|
||||
add_method(path, handler)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
519
py/routes/stats_routes.py
Normal file
519
py/routes/stats_routes.py
Normal file
@@ -0,0 +1,519 @@
|
||||
import os
|
||||
import json
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict, Counter
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.server_i18n import server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.usage_stats import UsageStats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class StatsRoutes:
|
||||
"""Route handlers for Statistics page and API endpoints"""
|
||||
|
||||
def __init__(self):
|
||||
self.lora_scanner = None
|
||||
self.checkpoint_scanner = None
|
||||
self.embedding_scanner = None
|
||||
self.usage_stats = None
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
|
||||
async def init_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
# Only initialize usage stats if we have valid paths configured
|
||||
try:
|
||||
self.usage_stats = UsageStats()
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Could not initialize usage statistics: {e}")
|
||||
self.usage_stats = None
|
||||
|
||||
async def handle_stats_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /statistics request"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Check if scanners are initializing
|
||||
lora_initializing = (
|
||||
self.lora_scanner._cache is None or
|
||||
(hasattr(self.lora_scanner, 'is_initializing') and self.lora_scanner.is_initializing())
|
||||
)
|
||||
|
||||
checkpoint_initializing = (
|
||||
self.checkpoint_scanner._cache is None or
|
||||
(hasattr(self.checkpoint_scanner, '_is_initializing') and self.checkpoint_scanner._is_initializing)
|
||||
)
|
||||
|
||||
embedding_initializing = (
|
||||
self.embedding_scanner._cache is None or
|
||||
(hasattr(self.embedding_scanner, 'is_initializing') and self.embedding_scanner.is_initializing())
|
||||
)
|
||||
|
||||
is_initializing = lora_initializing or checkpoint_initializing or embedding_initializing
|
||||
|
||||
# 获取用户语言设置
|
||||
user_language = settings.get('language', 'en')
|
||||
|
||||
# 设置服务端i18n语言
|
||||
server_i18n.set_locale(user_language)
|
||||
|
||||
# 为模板环境添加i18n过滤器
|
||||
if not hasattr(self.template_env, '_i18n_filter_added'):
|
||||
self.template_env.filters['t'] = server_i18n.create_template_filter()
|
||||
self.template_env._i18n_filter_added = True
|
||||
|
||||
template = self.template_env.get_template('statistics.html')
|
||||
rendered = template.render(
|
||||
is_initializing=is_initializing,
|
||||
settings=settings,
|
||||
request=request,
|
||||
t=server_i18n.get_translation,
|
||||
)
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling statistics request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading statistics page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def get_collection_overview(self, request: web.Request) -> web.Response:
|
||||
"""Get collection overview statistics"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get LoRA statistics
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
lora_count = len(lora_cache.raw_data)
|
||||
lora_size = sum(lora.get('size', 0) for lora in lora_cache.raw_data)
|
||||
|
||||
# Get Checkpoint statistics
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
checkpoint_count = len(checkpoint_cache.raw_data)
|
||||
checkpoint_size = sum(cp.get('size', 0) for cp in checkpoint_cache.raw_data)
|
||||
|
||||
# Get Embedding statistics
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
embedding_count = len(embedding_cache.raw_data)
|
||||
embedding_size = sum(emb.get('size', 0) for emb in embedding_cache.raw_data)
|
||||
|
||||
# Get usage statistics
|
||||
usage_data = await self.usage_stats.get_stats()
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'total_models': lora_count + checkpoint_count + embedding_count,
|
||||
'lora_count': lora_count,
|
||||
'checkpoint_count': checkpoint_count,
|
||||
'embedding_count': embedding_count,
|
||||
'total_size': lora_size + checkpoint_size + embedding_size,
|
||||
'lora_size': lora_size,
|
||||
'checkpoint_size': checkpoint_size,
|
||||
'embedding_size': embedding_size,
|
||||
'total_generations': usage_data.get('total_executions', 0),
|
||||
'unused_loras': self._count_unused_models(lora_cache.raw_data, usage_data.get('loras', {})),
|
||||
'unused_checkpoints': self._count_unused_models(checkpoint_cache.raw_data, usage_data.get('checkpoints', {})),
|
||||
'unused_embeddings': self._count_unused_models(embedding_cache.raw_data, usage_data.get('embeddings', {}))
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting collection overview: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_usage_analytics(self, request: web.Request) -> web.Response:
|
||||
"""Get usage analytics data"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get usage statistics
|
||||
usage_data = await self.usage_stats.get_stats()
|
||||
|
||||
# Get model data for enrichment
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
# Create hash to model mapping
|
||||
lora_map = {lora['sha256']: lora for lora in lora_cache.raw_data}
|
||||
checkpoint_map = {cp['sha256']: cp for cp in checkpoint_cache.raw_data}
|
||||
embedding_map = {emb['sha256']: emb for emb in embedding_cache.raw_data}
|
||||
|
||||
# Prepare top used models
|
||||
top_loras = self._get_top_used_models(usage_data.get('loras', {}), lora_map, 10)
|
||||
top_checkpoints = self._get_top_used_models(usage_data.get('checkpoints', {}), checkpoint_map, 10)
|
||||
top_embeddings = self._get_top_used_models(usage_data.get('embeddings', {}), embedding_map, 10)
|
||||
|
||||
# Prepare usage timeline (last 30 days)
|
||||
timeline = self._get_usage_timeline(usage_data, 30)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'top_loras': top_loras,
|
||||
'top_checkpoints': top_checkpoints,
|
||||
'top_embeddings': top_embeddings,
|
||||
'usage_timeline': timeline,
|
||||
'total_executions': usage_data.get('total_executions', 0)
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage analytics: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_base_model_distribution(self, request: web.Request) -> web.Response:
|
||||
"""Get base model distribution statistics"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get model data
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
# Count by base model
|
||||
lora_base_models = Counter(lora.get('base_model', 'Unknown') for lora in lora_cache.raw_data)
|
||||
checkpoint_base_models = Counter(cp.get('base_model', 'Unknown') for cp in checkpoint_cache.raw_data)
|
||||
embedding_base_models = Counter(emb.get('base_model', 'Unknown') for emb in embedding_cache.raw_data)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'loras': dict(lora_base_models),
|
||||
'checkpoints': dict(checkpoint_base_models),
|
||||
'embeddings': dict(embedding_base_models)
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting base model distribution: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_tag_analytics(self, request: web.Request) -> web.Response:
|
||||
"""Get tag usage analytics"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get model data
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
# Count tag frequencies
|
||||
all_tags = []
|
||||
for lora in lora_cache.raw_data:
|
||||
all_tags.extend(lora.get('tags', []))
|
||||
for cp in checkpoint_cache.raw_data:
|
||||
all_tags.extend(cp.get('tags', []))
|
||||
for emb in embedding_cache.raw_data:
|
||||
all_tags.extend(emb.get('tags', []))
|
||||
|
||||
tag_counts = Counter(all_tags)
|
||||
|
||||
# Get top 50 tags
|
||||
top_tags = [{'tag': tag, 'count': count} for tag, count in tag_counts.most_common(50)]
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'top_tags': top_tags,
|
||||
'total_unique_tags': len(tag_counts)
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting tag analytics: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_storage_analytics(self, request: web.Request) -> web.Response:
|
||||
"""Get storage usage analytics"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get usage statistics
|
||||
usage_data = await self.usage_stats.get_stats()
|
||||
|
||||
# Get model data
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
# Create models with usage data
|
||||
lora_storage = []
|
||||
for lora in lora_cache.raw_data:
|
||||
usage_count = 0
|
||||
if lora['sha256'] in usage_data.get('loras', {}):
|
||||
usage_count = usage_data['loras'][lora['sha256']].get('total', 0)
|
||||
|
||||
lora_storage.append({
|
||||
'name': lora['model_name'],
|
||||
'size': lora.get('size', 0),
|
||||
'usage_count': usage_count,
|
||||
'folder': lora.get('folder', ''),
|
||||
'base_model': lora.get('base_model', 'Unknown')
|
||||
})
|
||||
|
||||
checkpoint_storage = []
|
||||
for cp in checkpoint_cache.raw_data:
|
||||
usage_count = 0
|
||||
if cp['sha256'] in usage_data.get('checkpoints', {}):
|
||||
usage_count = usage_data['checkpoints'][cp['sha256']].get('total', 0)
|
||||
|
||||
checkpoint_storage.append({
|
||||
'name': cp['model_name'],
|
||||
'size': cp.get('size', 0),
|
||||
'usage_count': usage_count,
|
||||
'folder': cp.get('folder', ''),
|
||||
'base_model': cp.get('base_model', 'Unknown')
|
||||
})
|
||||
|
||||
embedding_storage = []
|
||||
for emb in embedding_cache.raw_data:
|
||||
usage_count = 0
|
||||
if emb['sha256'] in usage_data.get('embeddings', {}):
|
||||
usage_count = usage_data['embeddings'][emb['sha256']].get('total', 0)
|
||||
|
||||
embedding_storage.append({
|
||||
'name': emb['model_name'],
|
||||
'size': emb.get('size', 0),
|
||||
'usage_count': usage_count,
|
||||
'folder': emb.get('folder', ''),
|
||||
'base_model': emb.get('base_model', 'Unknown')
|
||||
})
|
||||
|
||||
# Sort by size
|
||||
lora_storage.sort(key=lambda x: x['size'], reverse=True)
|
||||
checkpoint_storage.sort(key=lambda x: x['size'], reverse=True)
|
||||
embedding_storage.sort(key=lambda x: x['size'], reverse=True)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'loras': lora_storage[:20], # Top 20 by size
|
||||
'checkpoints': checkpoint_storage[:20],
|
||||
'embeddings': embedding_storage[:20]
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting storage analytics: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_insights(self, request: web.Request) -> web.Response:
|
||||
"""Get smart insights about the collection"""
|
||||
try:
|
||||
await self.init_services()
|
||||
|
||||
# Get usage statistics
|
||||
usage_data = await self.usage_stats.get_stats()
|
||||
|
||||
# Get model data
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
insights = []
|
||||
|
||||
# Calculate unused models
|
||||
unused_loras = self._count_unused_models(lora_cache.raw_data, usage_data.get('loras', {}))
|
||||
unused_checkpoints = self._count_unused_models(checkpoint_cache.raw_data, usage_data.get('checkpoints', {}))
|
||||
unused_embeddings = self._count_unused_models(embedding_cache.raw_data, usage_data.get('embeddings', {}))
|
||||
|
||||
total_loras = len(lora_cache.raw_data)
|
||||
total_checkpoints = len(checkpoint_cache.raw_data)
|
||||
total_embeddings = len(embedding_cache.raw_data)
|
||||
|
||||
if total_loras > 0:
|
||||
unused_lora_percent = (unused_loras / total_loras) * 100
|
||||
if unused_lora_percent > 50:
|
||||
insights.append({
|
||||
'type': 'warning',
|
||||
'title': 'High Number of Unused LoRAs',
|
||||
'description': f'{unused_lora_percent:.1f}% of your LoRAs ({unused_loras}/{total_loras}) have never been used.',
|
||||
'suggestion': 'Consider organizing or archiving unused models to free up storage space.'
|
||||
})
|
||||
|
||||
if total_checkpoints > 0:
|
||||
unused_checkpoint_percent = (unused_checkpoints / total_checkpoints) * 100
|
||||
if unused_checkpoint_percent > 30:
|
||||
insights.append({
|
||||
'type': 'warning',
|
||||
'title': 'Unused Checkpoints Detected',
|
||||
'description': f'{unused_checkpoint_percent:.1f}% of your checkpoints ({unused_checkpoints}/{total_checkpoints}) have never been used.',
|
||||
'suggestion': 'Review and consider removing checkpoints you no longer need.'
|
||||
})
|
||||
|
||||
if total_embeddings > 0:
|
||||
unused_embedding_percent = (unused_embeddings / total_embeddings) * 100
|
||||
if unused_embedding_percent > 50:
|
||||
insights.append({
|
||||
'type': 'warning',
|
||||
'title': 'High Number of Unused Embeddings',
|
||||
'description': f'{unused_embedding_percent:.1f}% of your embeddings ({unused_embeddings}/{total_embeddings}) have never been used.',
|
||||
'suggestion': 'Consider organizing or archiving unused embeddings to optimize your collection.'
|
||||
})
|
||||
|
||||
# Storage insights
|
||||
total_size = sum(lora.get('size', 0) for lora in lora_cache.raw_data) + \
|
||||
sum(cp.get('size', 0) for cp in checkpoint_cache.raw_data) + \
|
||||
sum(emb.get('size', 0) for emb in embedding_cache.raw_data)
|
||||
|
||||
if total_size > 100 * 1024 * 1024 * 1024: # 100GB
|
||||
insights.append({
|
||||
'type': 'info',
|
||||
'title': 'Large Collection Detected',
|
||||
'description': f'Your model collection is using {self._format_size(total_size)} of storage.',
|
||||
'suggestion': 'Consider using external storage or cloud solutions for better organization.'
|
||||
})
|
||||
|
||||
# Recent activity insight
|
||||
if usage_data.get('total_executions', 0) > 100:
|
||||
insights.append({
|
||||
'type': 'success',
|
||||
'title': 'Active User',
|
||||
'description': f'You\'ve completed {usage_data["total_executions"]} generations so far!',
|
||||
'suggestion': 'Keep exploring and creating amazing content with your models.'
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'insights': insights
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting insights: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
def _count_unused_models(self, models: List[Dict], usage_data: Dict) -> int:
|
||||
"""Count models that have never been used"""
|
||||
used_hashes = set(usage_data.keys())
|
||||
unused_count = 0
|
||||
|
||||
for model in models:
|
||||
if model.get('sha256') not in used_hashes:
|
||||
unused_count += 1
|
||||
|
||||
return unused_count
|
||||
|
||||
def _get_top_used_models(self, usage_data: Dict, model_map: Dict, limit: int) -> List[Dict]:
|
||||
"""Get top used models with their metadata"""
|
||||
sorted_usage = sorted(usage_data.items(), key=lambda x: x[1].get('total', 0), reverse=True)
|
||||
|
||||
top_models = []
|
||||
for sha256, usage_info in sorted_usage[:limit]:
|
||||
if sha256 in model_map:
|
||||
model = model_map[sha256]
|
||||
top_models.append({
|
||||
'name': model['model_name'],
|
||||
'usage_count': usage_info.get('total', 0),
|
||||
'base_model': model.get('base_model', 'Unknown'),
|
||||
'preview_url': config.get_preview_static_url(model.get('preview_url', '')),
|
||||
'folder': model.get('folder', '')
|
||||
})
|
||||
|
||||
return top_models
|
||||
|
||||
def _get_usage_timeline(self, usage_data: Dict, days: int) -> List[Dict]:
|
||||
"""Get usage timeline for the past N days"""
|
||||
timeline = []
|
||||
today = datetime.now()
|
||||
|
||||
for i in range(days):
|
||||
date = today - timedelta(days=i)
|
||||
date_str = date.strftime('%Y-%m-%d')
|
||||
|
||||
lora_usage = 0
|
||||
checkpoint_usage = 0
|
||||
embedding_usage = 0
|
||||
|
||||
# Count usage for this date
|
||||
for model_usage in usage_data.get('loras', {}).values():
|
||||
if isinstance(model_usage, dict) and 'history' in model_usage:
|
||||
lora_usage += model_usage['history'].get(date_str, 0)
|
||||
|
||||
for model_usage in usage_data.get('checkpoints', {}).values():
|
||||
if isinstance(model_usage, dict) and 'history' in model_usage:
|
||||
checkpoint_usage += model_usage['history'].get(date_str, 0)
|
||||
|
||||
for model_usage in usage_data.get('embeddings', {}).values():
|
||||
if isinstance(model_usage, dict) and 'history' in model_usage:
|
||||
embedding_usage += model_usage['history'].get(date_str, 0)
|
||||
|
||||
timeline.append({
|
||||
'date': date_str,
|
||||
'lora_usage': lora_usage,
|
||||
'checkpoint_usage': checkpoint_usage,
|
||||
'embedding_usage': embedding_usage,
|
||||
'total_usage': lora_usage + checkpoint_usage + embedding_usage
|
||||
})
|
||||
|
||||
return list(reversed(timeline)) # Oldest to newest
|
||||
|
||||
def _format_size(self, size_bytes: int) -> str:
|
||||
"""Format file size in human readable format"""
|
||||
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
|
||||
if size_bytes < 1024.0:
|
||||
return f"{size_bytes:.1f} {unit}"
|
||||
size_bytes /= 1024.0
|
||||
return f"{size_bytes:.1f} PB"
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Register routes with the application"""
|
||||
# Add an app startup handler to initialize services
|
||||
app.on_startup.append(self._on_startup)
|
||||
|
||||
# Register page route
|
||||
app.router.add_get('/statistics', self.handle_stats_page)
|
||||
|
||||
# Register API routes
|
||||
app.router.add_get('/api/lm/stats/collection-overview', self.get_collection_overview)
|
||||
app.router.add_get('/api/lm/stats/usage-analytics', self.get_usage_analytics)
|
||||
app.router.add_get('/api/lm/stats/base-model-distribution', self.get_base_model_distribution)
|
||||
app.router.add_get('/api/lm/stats/tag-analytics', self.get_tag_analytics)
|
||||
app.router.add_get('/api/lm/stats/storage-analytics', self.get_storage_analytics)
|
||||
app.router.add_get('/api/lm/stats/insights', self.get_insights)
|
||||
|
||||
async def _on_startup(self, app):
|
||||
"""Initialize services when the app starts"""
|
||||
await self.init_services()
|
||||
@@ -1,9 +1,13 @@
|
||||
import os
|
||||
import aiohttp
|
||||
import logging
|
||||
import toml
|
||||
import git
|
||||
import zipfile
|
||||
import shutil
|
||||
import tempfile
|
||||
from aiohttp import web
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, List
|
||||
from ..services.downloader import get_downloader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -13,7 +17,9 @@ class UpdateRoutes:
|
||||
@staticmethod
|
||||
def setup_routes(app):
|
||||
"""Register update check routes"""
|
||||
app.router.add_get('/loras/api/check-updates', UpdateRoutes.check_updates)
|
||||
app.router.add_get('/api/lm/check-updates', UpdateRoutes.check_updates)
|
||||
app.router.add_get('/api/lm/version-info', UpdateRoutes.get_version_info)
|
||||
app.router.add_post('/api/lm/perform-update', UpdateRoutes.perform_update)
|
||||
|
||||
@staticmethod
|
||||
async def check_updates(request):
|
||||
@@ -22,24 +28,39 @@ class UpdateRoutes:
|
||||
Returns update status and version information
|
||||
"""
|
||||
try:
|
||||
nightly = request.query.get('nightly', 'false').lower() == 'true'
|
||||
|
||||
# Read local version from pyproject.toml
|
||||
local_version = UpdateRoutes._get_local_version()
|
||||
|
||||
# Get git info (commit hash, branch)
|
||||
git_info = UpdateRoutes._get_git_info()
|
||||
|
||||
# Fetch remote version from GitHub
|
||||
remote_version, changelog = await UpdateRoutes._get_remote_version()
|
||||
if nightly:
|
||||
remote_version, changelog = await UpdateRoutes._get_nightly_version()
|
||||
else:
|
||||
remote_version, changelog = await UpdateRoutes._get_remote_version()
|
||||
|
||||
# Compare versions
|
||||
update_available = UpdateRoutes._compare_versions(
|
||||
local_version.replace('v', ''),
|
||||
remote_version.replace('v', '')
|
||||
)
|
||||
if nightly:
|
||||
# For nightly, compare commit hashes
|
||||
update_available = UpdateRoutes._compare_nightly_versions(git_info, remote_version)
|
||||
else:
|
||||
# For stable, compare semantic versions
|
||||
update_available = UpdateRoutes._compare_versions(
|
||||
local_version.replace('v', ''),
|
||||
remote_version.replace('v', '')
|
||||
)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'current_version': local_version,
|
||||
'latest_version': remote_version,
|
||||
'update_available': update_available,
|
||||
'changelog': changelog
|
||||
'changelog': changelog,
|
||||
'git_info': git_info,
|
||||
'nightly': nightly
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
@@ -48,6 +69,301 @@ class UpdateRoutes:
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def get_version_info(request):
|
||||
"""
|
||||
Returns the current version in the format 'version-short_hash'
|
||||
"""
|
||||
try:
|
||||
# Read local version from pyproject.toml
|
||||
local_version = UpdateRoutes._get_local_version().replace('v', '')
|
||||
|
||||
# Get git info (commit hash, branch)
|
||||
git_info = UpdateRoutes._get_git_info()
|
||||
short_hash = git_info['short_hash']
|
||||
|
||||
# Format: version-short_hash
|
||||
version_string = f"{local_version}-{short_hash}"
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'version': version_string
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get version info: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def perform_update(request):
|
||||
"""
|
||||
Perform Git-based update to latest release tag or main branch.
|
||||
If .git is missing, fallback to ZIP download.
|
||||
"""
|
||||
try:
|
||||
body = await request.json() if request.has_body else {}
|
||||
nightly = body.get('nightly', False)
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
plugin_root = os.path.dirname(os.path.dirname(current_dir))
|
||||
|
||||
settings_path = os.path.join(plugin_root, 'settings.json')
|
||||
settings_backup = None
|
||||
if os.path.exists(settings_path):
|
||||
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||
settings_backup = f.read()
|
||||
logger.info("Backed up settings.json")
|
||||
|
||||
git_folder = os.path.join(plugin_root, '.git')
|
||||
if os.path.exists(git_folder):
|
||||
# Git update
|
||||
success, new_version = await UpdateRoutes._perform_git_update(plugin_root, nightly)
|
||||
else:
|
||||
# Fallback: Download ZIP and replace files
|
||||
success, new_version = await UpdateRoutes._download_and_replace_zip(plugin_root)
|
||||
|
||||
if settings_backup and success:
|
||||
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||
f.write(settings_backup)
|
||||
logger.info("Restored settings.json")
|
||||
|
||||
if success:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f'Successfully updated to {new_version}',
|
||||
'new_version': new_version
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Failed to complete update'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to perform update: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def _download_and_replace_zip(plugin_root: str) -> tuple[bool, str]:
|
||||
"""
|
||||
Download latest release ZIP from GitHub and replace plugin files.
|
||||
Skips settings.json and civitai folder. Writes extracted file list to .tracking.
|
||||
"""
|
||||
repo_owner = "willmiao"
|
||||
repo_name = "ComfyUI-Lora-Manager"
|
||||
github_api = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
|
||||
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
|
||||
# Get release info
|
||||
success, data = await downloader.make_request(
|
||||
'GET',
|
||||
github_api,
|
||||
use_auth=False
|
||||
)
|
||||
if not success:
|
||||
logger.error(f"Failed to fetch release info: {data}")
|
||||
return False, ""
|
||||
|
||||
zip_url = data.get("zipball_url")
|
||||
version = data.get("tag_name", "unknown")
|
||||
|
||||
# Download ZIP to temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
|
||||
tmp_zip_path = tmp_zip.name
|
||||
|
||||
success, result = await downloader.download_file(
|
||||
url=zip_url,
|
||||
save_path=tmp_zip_path,
|
||||
use_auth=False,
|
||||
allow_resume=False
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error(f"Failed to download ZIP: {result}")
|
||||
return False, ""
|
||||
|
||||
zip_path = tmp_zip_path
|
||||
|
||||
# Skip both settings.json and civitai folder
|
||||
UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json', 'civitai'])
|
||||
|
||||
# Extract ZIP to temp dir
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(tmp_dir)
|
||||
# Find extracted folder (GitHub ZIP contains a root folder)
|
||||
extracted_root = next(os.scandir(tmp_dir)).path
|
||||
|
||||
# Copy files, skipping settings.json and civitai folder
|
||||
for item in os.listdir(extracted_root):
|
||||
if item == 'settings.json' or item == 'civitai':
|
||||
continue
|
||||
src = os.path.join(extracted_root, item)
|
||||
dst = os.path.join(plugin_root, item)
|
||||
if os.path.isdir(src):
|
||||
if os.path.exists(dst):
|
||||
shutil.rmtree(dst)
|
||||
shutil.copytree(src, dst, ignore=shutil.ignore_patterns('settings.json', 'civitai'))
|
||||
else:
|
||||
shutil.copy2(src, dst)
|
||||
|
||||
# Write .tracking file: list all files under extracted_root, relative to extracted_root
|
||||
# for ComfyUI Manager to work properly
|
||||
tracking_info_file = os.path.join(plugin_root, '.tracking')
|
||||
tracking_files = []
|
||||
for root, dirs, files in os.walk(extracted_root):
|
||||
# Skip civitai folder and its contents
|
||||
rel_root = os.path.relpath(root, extracted_root)
|
||||
if rel_root == 'civitai' or rel_root.startswith('civitai' + os.sep):
|
||||
continue
|
||||
for file in files:
|
||||
rel_path = os.path.relpath(os.path.join(root, file), extracted_root)
|
||||
# Skip settings.json and any file under civitai
|
||||
if rel_path == 'settings.json' or rel_path.startswith('civitai' + os.sep):
|
||||
continue
|
||||
tracking_files.append(rel_path.replace("\\", "/"))
|
||||
with open(tracking_info_file, "w", encoding='utf-8') as file:
|
||||
file.write('\n'.join(tracking_files))
|
||||
|
||||
os.remove(zip_path)
|
||||
logger.info(f"Updated plugin via ZIP to {version}")
|
||||
return True, version
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ZIP update failed: {e}", exc_info=True)
|
||||
return False, ""
|
||||
|
||||
def _clean_plugin_folder(plugin_root, skip_files=None):
|
||||
skip_files = skip_files or []
|
||||
for item in os.listdir(plugin_root):
|
||||
if item in skip_files:
|
||||
continue
|
||||
path = os.path.join(plugin_root, item)
|
||||
if os.path.isdir(path):
|
||||
shutil.rmtree(path)
|
||||
else:
|
||||
os.remove(path)
|
||||
|
||||
@staticmethod
|
||||
async def _get_nightly_version() -> tuple[str, List[str]]:
|
||||
"""
|
||||
Fetch latest commit from main branch
|
||||
"""
|
||||
repo_owner = "willmiao"
|
||||
repo_name = "ComfyUI-Lora-Manager"
|
||||
|
||||
# Use GitHub API to fetch the latest commit from main branch
|
||||
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
|
||||
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
|
||||
|
||||
if not success:
|
||||
logger.warning(f"Failed to fetch GitHub commit: {data}")
|
||||
return "main", []
|
||||
|
||||
commit_sha = data.get('sha', '')[:7] # Short hash
|
||||
commit_message = data.get('commit', {}).get('message', '')
|
||||
|
||||
# Format as "main-{short_hash}"
|
||||
version = f"main-{commit_sha}"
|
||||
|
||||
# Use commit message as changelog
|
||||
changelog = [commit_message] if commit_message else []
|
||||
|
||||
return version, changelog
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching nightly version: {e}", exc_info=True)
|
||||
return "main", []
|
||||
|
||||
@staticmethod
|
||||
def _compare_nightly_versions(local_git_info: Dict[str, str], remote_version: str) -> bool:
|
||||
"""
|
||||
Compare local commit hash with remote main branch
|
||||
"""
|
||||
try:
|
||||
local_hash = local_git_info.get('short_hash', 'unknown')
|
||||
if local_hash == 'unknown':
|
||||
return True # Assume update available if we can't get local hash
|
||||
|
||||
# Extract remote hash from version string (format: "main-{hash}")
|
||||
if '-' in remote_version:
|
||||
remote_hash = remote_version.split('-')[-1]
|
||||
return local_hash != remote_hash
|
||||
|
||||
return True # Default to update available
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error comparing nightly versions: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _perform_git_update(plugin_root: str, nightly: bool = False) -> tuple[bool, str]:
|
||||
"""
|
||||
Perform Git-based update using GitPython
|
||||
|
||||
Args:
|
||||
plugin_root: Path to the plugin root directory
|
||||
nightly: Whether to update to main branch or latest release
|
||||
|
||||
Returns:
|
||||
tuple: (success, new_version)
|
||||
"""
|
||||
try:
|
||||
# Open the Git repository
|
||||
repo = git.Repo(plugin_root)
|
||||
|
||||
# Fetch latest changes
|
||||
origin = repo.remotes.origin
|
||||
origin.fetch()
|
||||
|
||||
if nightly:
|
||||
# Switch to main branch and pull latest
|
||||
main_branch = 'main'
|
||||
if main_branch not in [branch.name for branch in repo.branches]:
|
||||
# Create local main branch if it doesn't exist
|
||||
repo.create_head(main_branch, origin.refs.main)
|
||||
|
||||
repo.heads[main_branch].checkout()
|
||||
origin.pull(main_branch)
|
||||
|
||||
# Get new commit hash
|
||||
new_version = f"main-{repo.head.commit.hexsha[:7]}"
|
||||
|
||||
else:
|
||||
# Get latest release tag
|
||||
tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime, reverse=True)
|
||||
if not tags:
|
||||
logger.error("No tags found in repository")
|
||||
return False, ""
|
||||
|
||||
latest_tag = tags[0]
|
||||
|
||||
# Checkout to latest tag
|
||||
repo.git.checkout(latest_tag.name)
|
||||
|
||||
new_version = latest_tag.name
|
||||
|
||||
logger.info(f"Successfully updated to {new_version}")
|
||||
return True, new_version
|
||||
|
||||
except git.exc.GitError as e:
|
||||
logger.error(f"Git error during update: {e}")
|
||||
return False, ""
|
||||
except Exception as e:
|
||||
logger.error(f"Error during Git update: {e}")
|
||||
return False, ""
|
||||
|
||||
@staticmethod
|
||||
def _get_local_version() -> str:
|
||||
@@ -72,6 +388,35 @@ class UpdateRoutes:
|
||||
logger.error(f"Failed to get local version: {e}", exc_info=True)
|
||||
return "v0.0.0"
|
||||
|
||||
@staticmethod
|
||||
def _get_git_info() -> Dict[str, str]:
|
||||
"""Get Git repository information"""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
plugin_root = os.path.dirname(os.path.dirname(current_dir))
|
||||
|
||||
git_info = {
|
||||
'commit_hash': 'unknown',
|
||||
'short_hash': 'stable',
|
||||
'branch': 'unknown',
|
||||
'commit_date': 'unknown'
|
||||
}
|
||||
|
||||
try:
|
||||
# Check if we're in a git repository
|
||||
if not os.path.exists(os.path.join(plugin_root, '.git')):
|
||||
return git_info
|
||||
|
||||
repo = git.Repo(plugin_root)
|
||||
commit = repo.head.commit
|
||||
git_info['commit_hash'] = commit.hexsha
|
||||
git_info['short_hash'] = commit.hexsha[:7]
|
||||
git_info['branch'] = repo.active_branch.name if not repo.head.is_detached else 'detached'
|
||||
git_info['commit_date'] = commit.committed_datetime.strftime('%Y-%m-%d')
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting git info: {e}")
|
||||
|
||||
return git_info
|
||||
|
||||
@staticmethod
|
||||
async def _get_remote_version() -> tuple[str, List[str]]:
|
||||
"""
|
||||
@@ -86,22 +431,22 @@ class UpdateRoutes:
|
||||
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response:
|
||||
if response.status != 200:
|
||||
logger.warning(f"Failed to fetch GitHub release: {response.status}")
|
||||
return "v0.0.0", []
|
||||
|
||||
data = await response.json()
|
||||
version = data.get('tag_name', '')
|
||||
if not version.startswith('v'):
|
||||
version = f"v{version}"
|
||||
|
||||
# Extract changelog from release notes
|
||||
body = data.get('body', '')
|
||||
changelog = UpdateRoutes._parse_changelog(body)
|
||||
|
||||
return version, changelog
|
||||
downloader = await get_downloader()
|
||||
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
|
||||
|
||||
if not success:
|
||||
logger.warning(f"Failed to fetch GitHub release: {data}")
|
||||
return "v0.0.0", []
|
||||
|
||||
version = data.get('tag_name', '')
|
||||
if not version.startswith('v'):
|
||||
version = f"v{version}"
|
||||
|
||||
# Extract changelog from release notes
|
||||
body = data.get('body', '')
|
||||
changelog = UpdateRoutes._parse_changelog(body)
|
||||
|
||||
return version, changelog
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching remote version: {e}", exc_info=True)
|
||||
|
||||
375
py/services/base_model_service.py
Normal file
375
py/services/base_model_service.py
Normal file
@@ -0,0 +1,375 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Type
|
||||
import logging
|
||||
import os
|
||||
|
||||
from ..utils.models import BaseModelMetadata
|
||||
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
|
||||
from .settings_manager import settings as default_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseModelService(ABC):
|
||||
"""Base service class for all model types"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type: str,
|
||||
scanner,
|
||||
metadata_class: Type[BaseModelMetadata],
|
||||
*,
|
||||
cache_repository: Optional[ModelCacheRepository] = None,
|
||||
filter_set: Optional[ModelFilterSet] = None,
|
||||
search_strategy: Optional[SearchStrategy] = None,
|
||||
settings_provider: Optional[SettingsProvider] = None,
|
||||
):
|
||||
"""Initialize the service.
|
||||
|
||||
Args:
|
||||
model_type: Type of model (lora, checkpoint, etc.).
|
||||
scanner: Model scanner instance.
|
||||
metadata_class: Metadata class for this model type.
|
||||
cache_repository: Custom repository for cache access (primarily for tests).
|
||||
filter_set: Filter component controlling folder/tag/favorites logic.
|
||||
search_strategy: Search component for fuzzy/text matching.
|
||||
settings_provider: Settings object; defaults to the global settings manager.
|
||||
"""
|
||||
self.model_type = model_type
|
||||
self.scanner = scanner
|
||||
self.metadata_class = metadata_class
|
||||
self.settings = settings_provider or default_settings
|
||||
self.cache_repository = cache_repository or ModelCacheRepository(scanner)
|
||||
self.filter_set = filter_set or ModelFilterSet(self.settings)
|
||||
self.search_strategy = search_strategy or SearchStrategy()
|
||||
|
||||
async def get_paginated_data(
|
||||
self,
|
||||
page: int,
|
||||
page_size: int,
|
||||
sort_by: str = 'name',
|
||||
folder: str = None,
|
||||
search: str = None,
|
||||
fuzzy_search: bool = False,
|
||||
base_models: list = None,
|
||||
tags: list = None,
|
||||
search_options: dict = None,
|
||||
hash_filters: dict = None,
|
||||
favorites_only: bool = False,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
"""Get paginated and filtered model data"""
|
||||
sort_params = self.cache_repository.parse_sort(sort_by)
|
||||
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
|
||||
|
||||
if hash_filters:
|
||||
filtered_data = await self._apply_hash_filters(sorted_data, hash_filters)
|
||||
return self._paginate(filtered_data, page, page_size)
|
||||
|
||||
filtered_data = await self._apply_common_filters(
|
||||
sorted_data,
|
||||
folder=folder,
|
||||
base_models=base_models,
|
||||
tags=tags,
|
||||
favorites_only=favorites_only,
|
||||
search_options=search_options,
|
||||
)
|
||||
|
||||
if search:
|
||||
filtered_data = await self._apply_search_filters(
|
||||
filtered_data,
|
||||
search,
|
||||
fuzzy_search,
|
||||
search_options,
|
||||
)
|
||||
|
||||
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
||||
|
||||
return self._paginate(filtered_data, page, page_size)
|
||||
|
||||
|
||||
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
|
||||
"""Apply hash-based filtering"""
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower()
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('sha256', '').lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes)
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
return data
|
||||
|
||||
async def _apply_common_filters(
|
||||
self,
|
||||
data: List[Dict],
|
||||
folder: str = None,
|
||||
base_models: list = None,
|
||||
tags: list = None,
|
||||
favorites_only: bool = False,
|
||||
search_options: dict = None,
|
||||
) -> List[Dict]:
|
||||
"""Apply common filters that work across all model types"""
|
||||
normalized_options = self.search_strategy.normalize_options(search_options)
|
||||
criteria = FilterCriteria(
|
||||
folder=folder,
|
||||
base_models=base_models,
|
||||
tags=tags,
|
||||
favorites_only=favorites_only,
|
||||
search_options=normalized_options,
|
||||
)
|
||||
return self.filter_set.apply(data, criteria)
|
||||
|
||||
async def _apply_search_filters(
|
||||
self,
|
||||
data: List[Dict],
|
||||
search: str,
|
||||
fuzzy_search: bool,
|
||||
search_options: dict,
|
||||
) -> List[Dict]:
|
||||
"""Apply search filtering"""
|
||||
normalized_options = self.search_strategy.normalize_options(search_options)
|
||||
return self.search_strategy.apply(data, search, normalized_options, fuzzy_search)
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
||||
return data
|
||||
|
||||
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
|
||||
"""Apply pagination to filtered data"""
|
||||
total_items = len(data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
return {
|
||||
'items': data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def format_response(self, model_data: Dict) -> Dict:
|
||||
"""Format model data for API response - must be implemented by subclasses"""
|
||||
pass
|
||||
|
||||
# Common service methods that delegate to scanner
|
||||
async def get_top_tags(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get top tags sorted by frequency"""
|
||||
return await self.scanner.get_top_tags(limit)
|
||||
|
||||
async def get_base_models(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get base models sorted by frequency"""
|
||||
return await self.scanner.get_base_models(limit)
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
"""Check if a model with given hash exists"""
|
||||
return self.scanner.has_hash(sha256)
|
||||
|
||||
def get_path_by_hash(self, sha256: str) -> Optional[str]:
|
||||
"""Get file path for a model by its hash"""
|
||||
return self.scanner.get_path_by_hash(sha256)
|
||||
|
||||
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a model by its file path"""
|
||||
return self.scanner.get_hash_by_path(file_path)
|
||||
|
||||
async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False):
|
||||
"""Trigger model scanning"""
|
||||
return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache)
|
||||
|
||||
async def get_model_info_by_name(self, name: str):
|
||||
"""Get model information by name"""
|
||||
return await self.scanner.get_model_info_by_name(name)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get model root directories"""
|
||||
return self.scanner.get_model_roots()
|
||||
|
||||
def filter_civitai_data(self, data: Dict, minimal: bool = False) -> Dict:
|
||||
"""Filter relevant fields from CivitAI data"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
fields = ["id", "modelId", "name", "trainedWords"] if minimal else [
|
||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||
"publishedAt", "trainedWords", "baseModel", "description",
|
||||
"model", "images", "customImages", "creator"
|
||||
]
|
||||
return {k: data[k] for k in fields if k in data}
|
||||
|
||||
async def get_folder_tree(self, model_root: str) -> Dict:
|
||||
"""Get hierarchical folder tree for a specific model root"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Build tree structure from folders
|
||||
tree = {}
|
||||
|
||||
for folder in cache.folders:
|
||||
# Check if this folder belongs to the specified model root
|
||||
folder_belongs_to_root = False
|
||||
for root in self.scanner.get_model_roots():
|
||||
if root == model_root:
|
||||
folder_belongs_to_root = True
|
||||
break
|
||||
|
||||
if not folder_belongs_to_root:
|
||||
continue
|
||||
|
||||
# Split folder path into components
|
||||
parts = folder.split('/') if folder else []
|
||||
current_level = tree
|
||||
|
||||
for part in parts:
|
||||
if part not in current_level:
|
||||
current_level[part] = {}
|
||||
current_level = current_level[part]
|
||||
|
||||
return tree
|
||||
|
||||
async def get_unified_folder_tree(self) -> Dict:
|
||||
"""Get unified folder tree across all model roots"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Build unified tree structure by analyzing all relative paths
|
||||
unified_tree = {}
|
||||
|
||||
# Get all model roots for path normalization
|
||||
model_roots = self.scanner.get_model_roots()
|
||||
|
||||
for folder in cache.folders:
|
||||
if not folder: # Skip empty folders
|
||||
continue
|
||||
|
||||
# Find which root this folder belongs to by checking the actual file paths
|
||||
# This is a simplified approach - we'll use the folder as-is since it should already be relative
|
||||
relative_path = folder
|
||||
|
||||
# Split folder path into components
|
||||
parts = relative_path.split('/')
|
||||
current_level = unified_tree
|
||||
|
||||
for part in parts:
|
||||
if part not in current_level:
|
||||
current_level[part] = {}
|
||||
current_level = current_level[part]
|
||||
|
||||
return unified_tree
|
||||
|
||||
async def get_model_notes(self, model_name: str) -> Optional[str]:
|
||||
"""Get notes for a specific model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
return model.get('notes', '')
|
||||
|
||||
return None
|
||||
|
||||
async def get_model_preview_url(self, model_name: str) -> Optional[str]:
|
||||
"""Get the static preview URL for a model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
preview_url = model.get('preview_url')
|
||||
if preview_url:
|
||||
from ..config import config
|
||||
return config.get_preview_static_url(preview_url)
|
||||
|
||||
return '/loras_static/images/no-preview.png'
|
||||
|
||||
async def get_model_civitai_url(self, model_name: str) -> Dict[str, Optional[str]]:
|
||||
"""Get the Civitai URL for a model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
civitai_data = model.get('civitai', {})
|
||||
model_id = civitai_data.get('modelId')
|
||||
version_id = civitai_data.get('id')
|
||||
|
||||
if model_id:
|
||||
civitai_url = f"https://civitai.com/models/{model_id}"
|
||||
if version_id:
|
||||
civitai_url += f"?modelVersionId={version_id}"
|
||||
|
||||
return {
|
||||
'civitai_url': civitai_url,
|
||||
'model_id': str(model_id),
|
||||
'version_id': str(version_id) if version_id else None
|
||||
}
|
||||
|
||||
return {'civitai_url': None, 'model_id': None, 'version_id': None}
|
||||
|
||||
async def get_model_metadata(self, file_path: str) -> Optional[Dict]:
|
||||
"""Get filtered CivitAI metadata for a model by file path"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model.get('file_path') == file_path:
|
||||
return self.filter_civitai_data(model.get("civitai", {}))
|
||||
|
||||
return None
|
||||
|
||||
async def get_model_description(self, file_path: str) -> Optional[str]:
|
||||
"""Get model description by file path"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model.get('file_path') == file_path:
|
||||
return model.get('modelDescription', '')
|
||||
|
||||
return None
|
||||
|
||||
async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]:
|
||||
"""Search model relative file paths for autocomplete functionality"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
matching_paths = []
|
||||
search_lower = search_term.lower()
|
||||
|
||||
# Get model roots for path calculation
|
||||
model_roots = self.scanner.get_model_roots()
|
||||
|
||||
for model in cache.raw_data:
|
||||
file_path = model.get('file_path', '')
|
||||
if not file_path:
|
||||
continue
|
||||
|
||||
# Calculate relative path from model root
|
||||
relative_path = None
|
||||
for root in model_roots:
|
||||
# Normalize paths for comparison
|
||||
normalized_root = os.path.normpath(root)
|
||||
normalized_file = os.path.normpath(file_path)
|
||||
|
||||
if normalized_file.startswith(normalized_root):
|
||||
# Remove root and leading separator to get relative path
|
||||
relative_path = normalized_file[len(normalized_root):].lstrip(os.sep)
|
||||
break
|
||||
|
||||
if relative_path and search_lower in relative_path.lower():
|
||||
matching_paths.append(relative_path)
|
||||
|
||||
if len(matching_paths) >= limit * 2: # Get more for better sorting
|
||||
break
|
||||
|
||||
# Sort by relevance (exact matches first, then by length)
|
||||
matching_paths.sort(key=lambda x: (
|
||||
not x.lower().startswith(search_lower), # Exact prefix matches first
|
||||
len(x), # Then by length (shorter first)
|
||||
x.lower() # Then alphabetically
|
||||
))
|
||||
|
||||
return matching_paths[:limit]
|
||||
@@ -1,131 +1,34 @@
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional, Set
|
||||
import folder_paths # type: ignore
|
||||
from typing import List
|
||||
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex
|
||||
from .service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointScanner(ModelScanner):
|
||||
"""Service for scanning and managing checkpoint files"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, '_initialized'):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
|
||||
super().__init__(
|
||||
model_type="checkpoint",
|
||||
model_class=CheckpointMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
self._checkpoint_roots = self._init_checkpoint_roots()
|
||||
self._initialized = True
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'}
|
||||
super().__init__(
|
||||
model_type="checkpoint",
|
||||
model_class=CheckpointMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
|
||||
def adjust_metadata(self, metadata, file_path, root_path):
|
||||
if hasattr(metadata, "model_type"):
|
||||
if root_path in config.checkpoints_roots:
|
||||
metadata.model_type = "checkpoint"
|
||||
elif root_path in config.unet_roots:
|
||||
metadata.model_type = "diffusion_model"
|
||||
return metadata
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def _init_checkpoint_roots(self) -> List[str]:
|
||||
"""Initialize checkpoint roots from ComfyUI settings"""
|
||||
# Get both checkpoint and diffusion_models paths
|
||||
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
diffusion_paths = folder_paths.get_folder_paths("diffusion_models")
|
||||
|
||||
# Combine, normalize and deduplicate paths
|
||||
all_paths = set()
|
||||
for path in checkpoint_paths + diffusion_paths:
|
||||
if os.path.exists(path):
|
||||
norm_path = path.replace(os.sep, "/")
|
||||
all_paths.add(norm_path)
|
||||
|
||||
# Sort for consistent order
|
||||
sorted_paths = sorted(all_paths, key=lambda p: p.lower())
|
||||
|
||||
return sorted_paths
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get checkpoint root directories"""
|
||||
return self._checkpoint_roots
|
||||
|
||||
async def scan_all_models(self) -> List[Dict]:
|
||||
"""Scan all checkpoint directories and return metadata"""
|
||||
all_checkpoints = []
|
||||
|
||||
# Create scan tasks for each directory
|
||||
scan_tasks = []
|
||||
for root in self._checkpoint_roots:
|
||||
task = asyncio.create_task(self._scan_directory(root))
|
||||
scan_tasks.append(task)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for task in scan_tasks:
|
||||
try:
|
||||
checkpoints = await task
|
||||
all_checkpoints.extend(checkpoints)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning checkpoint directory: {e}")
|
||||
|
||||
return all_checkpoints
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Scan a directory for checkpoint files"""
|
||||
checkpoints = []
|
||||
original_root = root_path
|
||||
|
||||
async def scan_recursive(path: str, visited_paths: set):
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
logger.debug(f"Skipping already visited path: {path}")
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True):
|
||||
# Check if file has supported extension
|
||||
ext = os.path.splitext(entry.name)[1].lower()
|
||||
if ext in self.file_extensions:
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
await self._process_single_file(file_path, original_root, checkpoints)
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
# For directories, continue scanning with original path
|
||||
await scan_recursive(entry.path, visited_paths)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
await scan_recursive(root_path, set())
|
||||
return checkpoints
|
||||
|
||||
async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list):
|
||||
"""Process a single checkpoint file and add to results"""
|
||||
try:
|
||||
result = await self._process_model_file(file_path, root_path)
|
||||
if result:
|
||||
checkpoints.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
return config.base_models_roots
|
||||
49
py/services/checkpoint_service.py
Normal file
49
py/services/checkpoint_service.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointService(BaseModelService):
|
||||
"""Checkpoint-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize Checkpoint service
|
||||
|
||||
Args:
|
||||
scanner: Checkpoint scanner instance
|
||||
"""
|
||||
super().__init__("checkpoint", scanner, CheckpointMetadata)
|
||||
|
||||
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
||||
"""Format Checkpoint data for API response"""
|
||||
return {
|
||||
"model_name": checkpoint_data["model_name"],
|
||||
"file_name": checkpoint_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
|
||||
"base_model": checkpoint_data.get("base_model", ""),
|
||||
"folder": checkpoint_data["folder"],
|
||||
"sha256": checkpoint_data.get("sha256", ""),
|
||||
"file_path": checkpoint_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": checkpoint_data.get("size", 0),
|
||||
"modified": checkpoint_data.get("modified", ""),
|
||||
"tags": checkpoint_data.get("tags", []),
|
||||
"from_civitai": checkpoint_data.get("from_civitai", True),
|
||||
"notes": checkpoint_data.get("notes", ""),
|
||||
"model_type": checkpoint_data.get("model_type", "checkpoint"),
|
||||
"favorite": checkpoint_data.get("favorite", False),
|
||||
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find Checkpoints with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find Checkpoints with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
@@ -1,13 +1,10 @@
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import os
|
||||
import json
|
||||
import copy
|
||||
import logging
|
||||
import asyncio
|
||||
from email.parser import Parser
|
||||
from typing import Optional, Dict, Tuple, List
|
||||
from urllib.parse import unquote
|
||||
from ..utils.models import LoraMetadata
|
||||
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
||||
from .downloader import get_downloader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,6 +18,11 @@ class CivitaiClient:
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
|
||||
# Register this client as a metadata provider
|
||||
provider_manager = await ModelMetadataProviderManager.get_instance()
|
||||
provider_manager.register_provider('civitai', CivitaiModelMetadataProvider(cls._instance), True)
|
||||
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
@@ -30,81 +32,9 @@ class CivitaiClient:
|
||||
self._initialized = True
|
||||
|
||||
self.base_url = "https://civitai.com/api/v1"
|
||||
self.headers = {
|
||||
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
|
||||
}
|
||||
self._session = None
|
||||
self._session_created_at = None
|
||||
# Set default buffer size to 1MB for higher throughput
|
||||
self.chunk_size = 1024 * 1024
|
||||
|
||||
@property
|
||||
async def session(self) -> aiohttp.ClientSession:
|
||||
"""Lazy initialize the session"""
|
||||
if self._session is None:
|
||||
# Optimize TCP connection parameters
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=True,
|
||||
limit=3, # Further reduced from 5 to 3
|
||||
ttl_dns_cache=0, # Disabled DNS caching completely
|
||||
force_close=False, # Keep connections for reuse
|
||||
enable_cleanup_closed=True
|
||||
)
|
||||
trust_env = True # Allow using system environment proxy settings
|
||||
# Configure timeout parameters
|
||||
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
|
||||
self._session = aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
trust_env=trust_env,
|
||||
timeout=timeout
|
||||
)
|
||||
self._session_created_at = datetime.now()
|
||||
return self._session
|
||||
|
||||
async def _ensure_fresh_session(self):
|
||||
"""Refresh session if it's been open too long"""
|
||||
if self._session is not None:
|
||||
if not hasattr(self, '_session_created_at') or \
|
||||
(datetime.now() - self._session_created_at).total_seconds() > 300: # 5 minutes
|
||||
await self.close()
|
||||
self._session = None
|
||||
|
||||
return await self.session
|
||||
|
||||
def _parse_content_disposition(self, header: str) -> str:
|
||||
"""Parse filename from content-disposition header"""
|
||||
if not header:
|
||||
return None
|
||||
|
||||
# Handle quoted filenames
|
||||
if 'filename="' in header:
|
||||
start = header.index('filename="') + 10
|
||||
end = header.index('"', start)
|
||||
return unquote(header[start:end])
|
||||
|
||||
# Fallback to original parsing
|
||||
disposition = Parser().parsestr(f'Content-Disposition: {header}')
|
||||
filename = disposition.get_param('filename')
|
||||
if filename:
|
||||
return unquote(filename)
|
||||
return None
|
||||
|
||||
def _get_request_headers(self) -> dict:
|
||||
"""Get request headers with optional API key"""
|
||||
headers = {
|
||||
'User-Agent': 'ComfyUI-LoRA-Manager/1.0',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
from .settings_manager import settings
|
||||
api_key = settings.get('civitai_api_key')
|
||||
if (api_key):
|
||||
headers['Authorization'] = f'Bearer {api_key}'
|
||||
|
||||
return headers
|
||||
|
||||
async def _download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
|
||||
"""Download file with content-disposition support and progress tracking
|
||||
async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
|
||||
"""Download file with resumable downloads and retry mechanism
|
||||
|
||||
Args:
|
||||
url: Download URL
|
||||
@@ -115,115 +45,231 @@ class CivitaiClient:
|
||||
Returns:
|
||||
Tuple[bool, str]: (success, save_path or error message)
|
||||
"""
|
||||
logger.debug(f"Resolving DNS for: {url}")
|
||||
session = await self._ensure_fresh_session()
|
||||
downloader = await get_downloader()
|
||||
save_path = os.path.join(save_dir, default_filename)
|
||||
|
||||
# Use unified downloader with CivitAI authentication
|
||||
success, result = await downloader.download_file(
|
||||
url=url,
|
||||
save_path=save_path,
|
||||
progress_callback=progress_callback,
|
||||
use_auth=True, # Enable CivitAI authentication
|
||||
allow_resume=True
|
||||
)
|
||||
|
||||
return success, result
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
try:
|
||||
headers = self._get_request_headers()
|
||||
|
||||
# Add Range header to allow resumable downloads
|
||||
headers['Accept-Encoding'] = 'identity' # Disable compression for better chunked downloads
|
||||
|
||||
logger.debug(f"Starting download from: {url}")
|
||||
async with session.get(url, headers=headers, allow_redirects=True) as response:
|
||||
if response.status != 200:
|
||||
# Handle 401 unauthorized responses
|
||||
if response.status == 401:
|
||||
logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
|
||||
downloader = await get_downloader()
|
||||
success, result = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
# Get model ID from version data
|
||||
model_id = result.get('modelId')
|
||||
if model_id:
|
||||
# Fetch additional model metadata
|
||||
success_model, data = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if success_model:
|
||||
# Enrich version_info with model data
|
||||
result['model']['description'] = data.get("description")
|
||||
result['model']['tags'] = data.get("tags", [])
|
||||
|
||||
return False, "Invalid or missing CivitAI API key, or early access restriction."
|
||||
|
||||
# Handle other client errors that might be permission-related
|
||||
if response.status == 403:
|
||||
logger.warning(f"Forbidden access to resource: {url} (Status 403)")
|
||||
return False, "Access forbidden: You don't have permission to download this file."
|
||||
|
||||
# Generic error response for other status codes
|
||||
logger.error(f"Download failed for {url} with status {response.status}")
|
||||
return False, f"Download failed with status {response.status}"
|
||||
|
||||
# Get filename from content-disposition header
|
||||
content_disposition = response.headers.get('Content-Disposition')
|
||||
filename = self._parse_content_disposition(content_disposition)
|
||||
if not filename:
|
||||
filename = default_filename
|
||||
# Add creator from model data
|
||||
result['creator'] = data.get("creator")
|
||||
|
||||
save_path = os.path.join(save_dir, filename)
|
||||
|
||||
# Get total file size for progress calculation
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
current_size = 0
|
||||
last_progress_report_time = datetime.now()
|
||||
|
||||
# Stream download to file with progress updates using larger buffer
|
||||
with open(save_path, 'wb') as f:
|
||||
async for chunk in response.content.iter_chunked(self.chunk_size):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
current_size += len(chunk)
|
||||
|
||||
# Limit progress update frequency to reduce overhead
|
||||
now = datetime.now()
|
||||
time_diff = (now - last_progress_report_time).total_seconds()
|
||||
|
||||
if progress_callback and total_size and time_diff >= 0.5:
|
||||
progress = (current_size / total_size) * 100
|
||||
await progress_callback(progress)
|
||||
last_progress_report_time = now
|
||||
|
||||
# Ensure 100% progress is reported
|
||||
if progress_callback:
|
||||
await progress_callback(100)
|
||||
|
||||
return True, save_path
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"Network error during download: {e}")
|
||||
return False, f"Network error: {str(e)}"
|
||||
except Exception as e:
|
||||
logger.error(f"Download error: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||
try:
|
||||
session = await self._ensure_fresh_session()
|
||||
async with session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
return None
|
||||
return result, None
|
||||
|
||||
# Handle specific error cases
|
||||
if "not found" in str(result):
|
||||
return None, "Model not found"
|
||||
|
||||
# Other error cases
|
||||
logger.error(f"Failed to fetch model info for {model_hash[:10]}: {result}")
|
||||
return None, str(result)
|
||||
except Exception as e:
|
||||
logger.error(f"API Error: {str(e)}")
|
||||
return None
|
||||
return None, str(e)
|
||||
|
||||
async def download_preview_image(self, image_url: str, save_path: str):
|
||||
try:
|
||||
session = await self._ensure_fresh_session()
|
||||
async with session.get(image_url) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
with open(save_path, 'wb') as f:
|
||||
f.write(content)
|
||||
return True
|
||||
return False
|
||||
downloader = await get_downloader()
|
||||
success, content, headers = await downloader.download_to_memory(
|
||||
image_url,
|
||||
use_auth=False # Preview images don't need auth
|
||||
)
|
||||
if success:
|
||||
# Ensure directory exists
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
with open(save_path, 'wb') as f:
|
||||
f.write(content)
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Download Error: {str(e)}")
|
||||
logger.error(f"Download Error: {str(e)}")
|
||||
return False
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> List[Dict]:
|
||||
"""Get all versions of a model with local availability info"""
|
||||
try:
|
||||
session = await self._ensure_fresh_session() # Use fresh session
|
||||
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
data = await response.json()
|
||||
downloader = await get_downloader()
|
||||
success, result = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
# Also return model type along with versions
|
||||
return {
|
||||
'modelVersions': data.get('modelVersions', []),
|
||||
'type': data.get('type', '')
|
||||
'modelVersions': result.get('modelVersions', []),
|
||||
'type': result.get('type', ''),
|
||||
'name': result.get('name', '')
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model versions: {e}")
|
||||
return None
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
"""Get specific model version with additional metadata
|
||||
|
||||
Args:
|
||||
model_id: The Civitai model ID (optional if version_id is provided)
|
||||
version_id: Optional specific version ID to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: The model version data with additional fields or None if not found
|
||||
"""
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
|
||||
# Case 1: Only version_id is provided
|
||||
if model_id is None and version_id is not None:
|
||||
# First get the version info to extract model_id
|
||||
success, version = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/{version_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if not success:
|
||||
return None
|
||||
|
||||
model_id = version.get('modelId')
|
||||
if not model_id:
|
||||
logger.error(f"No modelId found in version {version_id}")
|
||||
return None
|
||||
|
||||
# Now get the model data for additional metadata
|
||||
success, model_data = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
# Enrich version with model data
|
||||
version['model']['description'] = model_data.get("description")
|
||||
version['model']['tags'] = model_data.get("tags", [])
|
||||
version['creator'] = model_data.get("creator")
|
||||
|
||||
return version
|
||||
|
||||
# Case 2: model_id is provided (with or without version_id)
|
||||
elif model_id is not None:
|
||||
# Step 1: Get model data to find version_id if not provided and get additional metadata
|
||||
success, data = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if not success:
|
||||
return None
|
||||
|
||||
model_versions = data.get('modelVersions', [])
|
||||
if not model_versions:
|
||||
logger.warning(f"No model versions found for model {model_id}")
|
||||
return None
|
||||
|
||||
# Step 2: Determine the target version entry to use
|
||||
target_version = None
|
||||
if version_id is not None:
|
||||
target_version = next(
|
||||
(item for item in model_versions if item.get('id') == version_id),
|
||||
None
|
||||
)
|
||||
if target_version is None:
|
||||
logger.warning(
|
||||
f"Version {version_id} not found for model {model_id}, defaulting to first version"
|
||||
)
|
||||
if target_version is None:
|
||||
target_version = model_versions[0]
|
||||
|
||||
target_version_id = target_version.get('id')
|
||||
|
||||
# Step 3: Get detailed version info using the SHA256 hash
|
||||
model_hash = None
|
||||
for file_info in target_version.get('files', []):
|
||||
if file_info.get('type') == 'Model' and file_info.get('primary'):
|
||||
model_hash = file_info.get('hashes', {}).get('SHA256')
|
||||
if model_hash:
|
||||
break
|
||||
|
||||
version = None
|
||||
if model_hash:
|
||||
success, version = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||
use_auth=True
|
||||
)
|
||||
if not success:
|
||||
logger.warning(
|
||||
f"Failed to fetch version by hash for model {model_id} version {target_version_id}: {version}"
|
||||
)
|
||||
version = None
|
||||
else:
|
||||
logger.warning(
|
||||
f"No primary model hash found for model {model_id} version {target_version_id}"
|
||||
)
|
||||
|
||||
if version is None:
|
||||
version = copy.deepcopy(target_version)
|
||||
version.pop('index', None)
|
||||
version['modelId'] = model_id
|
||||
version['model'] = {
|
||||
'name': data.get('name'),
|
||||
'type': data.get('type'),
|
||||
'nsfw': data.get('nsfw'),
|
||||
'poi': data.get('poi')
|
||||
}
|
||||
|
||||
# Step 4: Enrich version_info with model data
|
||||
# Add description and tags from model data
|
||||
model_info = version.get('model')
|
||||
if not isinstance(model_info, dict):
|
||||
model_info = {}
|
||||
version['model'] = model_info
|
||||
model_info['description'] = data.get("description")
|
||||
model_info['tags'] = data.get("tags", [])
|
||||
|
||||
# Add creator from model data
|
||||
version['creator'] = data.get("creator")
|
||||
|
||||
return version
|
||||
|
||||
# Case 3: Neither model_id nor version_id provided
|
||||
else:
|
||||
logger.error("Either model_id or version_id must be provided")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model version: {e}")
|
||||
return None
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Fetch model version metadata from Civitai
|
||||
@@ -237,116 +283,34 @@ class CivitaiClient:
|
||||
- An error message if there was an error, or None on success
|
||||
"""
|
||||
try:
|
||||
session = await self._ensure_fresh_session()
|
||||
downloader = await get_downloader()
|
||||
url = f"{self.base_url}/model-versions/{version_id}"
|
||||
headers = self._get_request_headers()
|
||||
|
||||
logger.debug(f"Resolving DNS for model version info: {url}")
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
logger.debug(f"Successfully fetched model version info for: {version_id}")
|
||||
return await response.json(), None
|
||||
|
||||
# Handle specific error cases
|
||||
if response.status == 404:
|
||||
# Try to parse the error message
|
||||
try:
|
||||
error_data = await response.json()
|
||||
error_msg = error_data.get('error', f"Model not found (status 404)")
|
||||
logger.warning(f"Model version not found: {version_id} - {error_msg}")
|
||||
return None, error_msg
|
||||
except:
|
||||
return None, "Model not found (status 404)"
|
||||
|
||||
# Other error cases
|
||||
logger.error(f"Failed to fetch model info for {version_id} (status {response.status})")
|
||||
return None, f"Failed to fetch model info (status {response.status})"
|
||||
success, result = await downloader.make_request(
|
||||
'GET',
|
||||
url,
|
||||
use_auth=True
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.debug(f"Successfully fetched model version info for: {version_id}")
|
||||
return result, None
|
||||
|
||||
# Handle specific error cases
|
||||
if "not found" in str(result):
|
||||
error_msg = f"Model not found"
|
||||
logger.warning(f"Model version not found: {version_id} - {error_msg}")
|
||||
return None, error_msg
|
||||
|
||||
# Other error cases
|
||||
logger.error(f"Failed to fetch model info for {version_id}: {result}")
|
||||
return None, str(result)
|
||||
except Exception as e:
|
||||
error_msg = f"Error fetching model version info: {e}"
|
||||
logger.error(error_msg)
|
||||
return None, error_msg
|
||||
|
||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||
"""Fetch model metadata (description, tags, and creator info) from Civitai API
|
||||
|
||||
Args:
|
||||
model_id: The Civitai model ID
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[Dict], int]: A tuple containing:
|
||||
- A dictionary with model metadata or None if not found
|
||||
- The HTTP status code from the request
|
||||
"""
|
||||
try:
|
||||
session = await self._ensure_fresh_session()
|
||||
headers = self._get_request_headers()
|
||||
url = f"{self.base_url}/models/{model_id}"
|
||||
|
||||
async with session.get(url, headers=headers) as response:
|
||||
status_code = response.status
|
||||
|
||||
if status_code != 200:
|
||||
logger.warning(f"Failed to fetch model metadata: Status {status_code}")
|
||||
return None, status_code
|
||||
|
||||
data = await response.json()
|
||||
|
||||
# Extract relevant metadata
|
||||
metadata = {
|
||||
"description": data.get("description") or "No model description available",
|
||||
"tags": data.get("tags", []),
|
||||
"creator": {
|
||||
"username": data.get("creator", {}).get("username"),
|
||||
"image": data.get("creator", {}).get("image")
|
||||
}
|
||||
}
|
||||
|
||||
if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]:
|
||||
return metadata, status_code
|
||||
else:
|
||||
logger.warning(f"No metadata found for model {model_id}")
|
||||
return None, status_code
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model metadata: {e}", exc_info=True)
|
||||
return None, 0
|
||||
|
||||
# Keep old method for backward compatibility, delegating to the new one
|
||||
async def get_model_description(self, model_id: str) -> Optional[str]:
|
||||
"""Fetch the model description from Civitai API (Legacy method)"""
|
||||
metadata, _ = await self.get_model_metadata(model_id)
|
||||
return metadata.get("description") if metadata else None
|
||||
|
||||
async def close(self):
|
||||
"""Close the session if it exists"""
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]:
|
||||
"""Get hash from Civitai API"""
|
||||
try:
|
||||
session = await self._ensure_fresh_session()
|
||||
if not session:
|
||||
return None
|
||||
|
||||
version_info = await session.get(f"{self.base_url}/model-versions/{model_version_id}")
|
||||
|
||||
if not version_info or not version_info.json().get('files'):
|
||||
return None
|
||||
|
||||
# Get hash from the first file
|
||||
for file_info in version_info.json().get('files', []):
|
||||
if file_info.get('hashes', {}).get('SHA256'):
|
||||
# Convert hash to lowercase to standardize
|
||||
hash_value = file_info['hashes']['SHA256'].lower()
|
||||
return hash_value
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting hash from Civitai: {e}")
|
||||
return None
|
||||
|
||||
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
||||
"""Fetch image information from Civitai API
|
||||
|
||||
@@ -357,22 +321,25 @@ class CivitaiClient:
|
||||
Optional[Dict]: The image data or None if not found
|
||||
"""
|
||||
try:
|
||||
session = await self._ensure_fresh_session()
|
||||
headers = self._get_request_headers()
|
||||
downloader = await get_downloader()
|
||||
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
|
||||
|
||||
logger.debug(f"Fetching image info for ID: {image_id}")
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
if data and "items" in data and len(data["items"]) > 0:
|
||||
logger.debug(f"Successfully fetched image info for ID: {image_id}")
|
||||
return data["items"][0]
|
||||
logger.warning(f"No image found with ID: {image_id}")
|
||||
return None
|
||||
|
||||
logger.error(f"Failed to fetch image info for ID: {image_id} (status {response.status})")
|
||||
success, result = await downloader.make_request(
|
||||
'GET',
|
||||
url,
|
||||
use_auth=True
|
||||
)
|
||||
|
||||
if success:
|
||||
if result and "items" in result and len(result["items"]) > 0:
|
||||
logger.debug(f"Successfully fetched image info for ID: {image_id}")
|
||||
return result["items"][0]
|
||||
logger.warning(f"No image found with ID: {image_id}")
|
||||
return None
|
||||
|
||||
logger.error(f"Failed to fetch image info for ID: {image_id}: {result}")
|
||||
return None
|
||||
except Exception as e:
|
||||
error_msg = f"Error fetching image info: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
100
py/services/download_coordinator.py
Normal file
100
py/services/download_coordinator.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Service wrapper for coordinating download lifecycle events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DownloadCoordinator:
|
||||
"""Manage download scheduling, cancellation and introspection."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ws_manager,
|
||||
download_manager_factory: Callable[[], Awaitable],
|
||||
) -> None:
|
||||
self._ws_manager = ws_manager
|
||||
self._download_manager_factory = download_manager_factory
|
||||
|
||||
async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Schedule a download using the provided payload."""
|
||||
|
||||
download_manager = await self._download_manager_factory()
|
||||
|
||||
download_id = payload.get("download_id") or self._ws_manager.generate_download_id()
|
||||
payload.setdefault("download_id", download_id)
|
||||
|
||||
async def progress_callback(progress: Any) -> None:
|
||||
await self._ws_manager.broadcast_download_progress(
|
||||
download_id,
|
||||
{
|
||||
"status": "progress",
|
||||
"progress": progress,
|
||||
"download_id": download_id,
|
||||
},
|
||||
)
|
||||
|
||||
model_id = self._parse_optional_int(payload.get("model_id"), "model_id")
|
||||
model_version_id = self._parse_optional_int(
|
||||
payload.get("model_version_id"), "model_version_id"
|
||||
)
|
||||
|
||||
if model_id is None and model_version_id is None:
|
||||
raise ValueError(
|
||||
"Missing required parameter: Please provide either 'model_id' or 'model_version_id'"
|
||||
)
|
||||
|
||||
result = await download_manager.download_from_civitai(
|
||||
model_id=model_id,
|
||||
model_version_id=model_version_id,
|
||||
save_dir=payload.get("model_root"),
|
||||
relative_path=payload.get("relative_path", ""),
|
||||
use_default_paths=payload.get("use_default_paths", False),
|
||||
progress_callback=progress_callback,
|
||||
download_id=download_id,
|
||||
source=payload.get("source"),
|
||||
)
|
||||
|
||||
result["download_id"] = download_id
|
||||
return result
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||
"""Cancel an active download and emit a broadcast event."""
|
||||
|
||||
download_manager = await self._download_manager_factory()
|
||||
result = await download_manager.cancel_download(download_id)
|
||||
|
||||
await self._ws_manager.broadcast_download_progress(
|
||||
download_id,
|
||||
{
|
||||
"status": "cancelled",
|
||||
"progress": 0,
|
||||
"download_id": download_id,
|
||||
"message": "Download cancelled by user",
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def list_active_downloads(self) -> Dict[str, Any]:
|
||||
"""Return the active download map from the underlying manager."""
|
||||
|
||||
download_manager = await self._download_manager_factory()
|
||||
return await download_manager.get_active_downloads()
|
||||
|
||||
def _parse_optional_int(self, value: Any, field: str) -> Optional[int]:
|
||||
"""Parse an optional integer from user input."""
|
||||
|
||||
if value is None or value == "":
|
||||
return None
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Invalid {field}: Must be an integer") from exc
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
import uuid
|
||||
from typing import Dict
|
||||
from ..utils.models import LoraMetadata, CheckpointMetadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .service_registry import ServiceRegistry
|
||||
from .settings_manager import settings
|
||||
from .metadata_service import get_default_metadata_provider
|
||||
from .downloader import get_downloader
|
||||
|
||||
# Download to temporary file first
|
||||
import tempfile
|
||||
@@ -31,21 +36,10 @@ class DownloadManager:
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
self._civitai_client = None # Will be lazily initialized
|
||||
|
||||
async def _get_civitai_client(self):
|
||||
"""Lazily initialize CivitaiClient from registry"""
|
||||
if self._civitai_client is None:
|
||||
self._civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
return self._civitai_client
|
||||
|
||||
async def _get_lora_monitor(self):
|
||||
"""Get the lora file monitor from registry"""
|
||||
return await ServiceRegistry.get_lora_monitor()
|
||||
|
||||
async def _get_checkpoint_monitor(self):
|
||||
"""Get the checkpoint file monitor from registry"""
|
||||
return await ServiceRegistry.get_checkpoint_monitor()
|
||||
# Add download management
|
||||
self._active_downloads = OrderedDict() # download_id -> download_info
|
||||
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
||||
self._download_tasks = {} # download_id -> asyncio.Task
|
||||
|
||||
async def _get_lora_scanner(self):
|
||||
"""Get the lora scanner from registry"""
|
||||
@@ -54,56 +48,225 @@ class DownloadManager:
|
||||
async def _get_checkpoint_scanner(self):
|
||||
"""Get the checkpoint scanner from registry"""
|
||||
return await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
async def download_from_civitai(self, download_url: str = None, model_hash: str = None,
|
||||
model_version_id: str = None, save_dir: str = None,
|
||||
relative_path: str = '', progress_callback=None,
|
||||
model_type: str = "lora") -> Dict:
|
||||
"""Download model from Civitai
|
||||
|
||||
async def download_from_civitai(self, model_id: int = None, model_version_id: int = None,
|
||||
save_dir: str = None, relative_path: str = '',
|
||||
progress_callback=None, use_default_paths: bool = False,
|
||||
download_id: str = None, source: str = None) -> Dict:
|
||||
"""Download model from Civitai with task tracking and concurrency control
|
||||
|
||||
Args:
|
||||
download_url: Direct download URL for the model
|
||||
model_hash: SHA256 hash of the model
|
||||
model_version_id: Civitai model version ID
|
||||
save_dir: Directory to save the model to
|
||||
model_id: Civitai model ID (optional if model_version_id is provided)
|
||||
model_version_id: Civitai model version ID (optional if model_id is provided)
|
||||
save_dir: Directory to save the model
|
||||
relative_path: Relative path within save_dir
|
||||
progress_callback: Callback function for progress updates
|
||||
model_type: Type of model ('lora' or 'checkpoint')
|
||||
use_default_paths: Flag to use default paths
|
||||
download_id: Unique identifier for this download task
|
||||
source: Optional source parameter to specify metadata provider
|
||||
|
||||
Returns:
|
||||
Dict with download result
|
||||
"""
|
||||
# Validate that at least one identifier is provided
|
||||
if not model_id and not model_version_id:
|
||||
return {'success': False, 'error': 'Either model_id or model_version_id must be provided'}
|
||||
|
||||
# Use provided download_id or generate new one
|
||||
task_id = download_id or str(uuid.uuid4())
|
||||
|
||||
# Register download task in tracking dict
|
||||
self._active_downloads[task_id] = {
|
||||
'model_id': model_id,
|
||||
'model_version_id': model_version_id,
|
||||
'progress': 0,
|
||||
'status': 'queued'
|
||||
}
|
||||
|
||||
# Create tracking task
|
||||
download_task = asyncio.create_task(
|
||||
self._download_with_semaphore(
|
||||
task_id, model_id, model_version_id, save_dir,
|
||||
relative_path, progress_callback, use_default_paths, source
|
||||
)
|
||||
)
|
||||
|
||||
# Store task for tracking and cancellation
|
||||
self._download_tasks[task_id] = download_task
|
||||
|
||||
try:
|
||||
# Wait for download to complete
|
||||
result = await download_task
|
||||
result['download_id'] = task_id # Include download_id in result
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
return {'success': False, 'error': 'Download was cancelled', 'download_id': task_id}
|
||||
finally:
|
||||
# Clean up task reference
|
||||
if task_id in self._download_tasks:
|
||||
del self._download_tasks[task_id]
|
||||
|
||||
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
|
||||
save_dir: str, relative_path: str,
|
||||
progress_callback=None, use_default_paths: bool = False,
|
||||
source: str = None):
|
||||
"""Execute download with semaphore to limit concurrency"""
|
||||
# Update status to waiting
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'waiting'
|
||||
|
||||
# Wrap progress callback to track progress in active_downloads
|
||||
original_callback = progress_callback
|
||||
async def tracking_callback(progress):
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['progress'] = progress
|
||||
if original_callback:
|
||||
await original_callback(progress)
|
||||
|
||||
# Acquire semaphore to limit concurrent downloads
|
||||
try:
|
||||
async with self._download_semaphore:
|
||||
# Update status to downloading
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'downloading'
|
||||
|
||||
# Use original download implementation
|
||||
try:
|
||||
# Check for cancellation before starting
|
||||
if asyncio.current_task().cancelled():
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
result = await self._execute_original_download(
|
||||
model_id, model_version_id, save_dir,
|
||||
relative_path, tracking_callback, use_default_paths,
|
||||
task_id, source
|
||||
)
|
||||
|
||||
# Update status based on result
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed'
|
||||
if not result['success']:
|
||||
self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error')
|
||||
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
# Handle cancellation
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'cancelled'
|
||||
logger.info(f"Download cancelled for task {task_id}")
|
||||
raise
|
||||
except Exception as e:
|
||||
# Handle other errors
|
||||
logger.error(f"Download error for task {task_id}: {str(e)}", exc_info=True)
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'failed'
|
||||
self._active_downloads[task_id]['error'] = str(e)
|
||||
return {'success': False, 'error': str(e)}
|
||||
finally:
|
||||
# Schedule cleanup of download record after delay
|
||||
asyncio.create_task(self._cleanup_download_record(task_id))
|
||||
|
||||
async def _cleanup_download_record(self, task_id: str):
|
||||
"""Keep completed downloads in history for a short time"""
|
||||
await asyncio.sleep(600) # Keep for 10 minutes
|
||||
if task_id in self._active_downloads:
|
||||
del self._active_downloads[task_id]
|
||||
|
||||
async def _execute_original_download(self, model_id, model_version_id, save_dir,
|
||||
relative_path, progress_callback, use_default_paths,
|
||||
download_id=None, source=None):
|
||||
"""Wrapper for original download_from_civitai implementation"""
|
||||
try:
|
||||
# Check if model version already exists in library
|
||||
if model_version_id is not None:
|
||||
# Check both scanners
|
||||
lora_scanner = await self._get_lora_scanner()
|
||||
checkpoint_scanner = await self._get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
# Check lora scanner first
|
||||
if await lora_scanner.check_model_version_exists(model_version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in lora library'}
|
||||
|
||||
# Check checkpoint scanner
|
||||
if await checkpoint_scanner.check_model_version_exists(model_version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
||||
|
||||
# Check embedding scanner
|
||||
if await embedding_scanner.check_model_version_exists(model_version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in embedding library'}
|
||||
|
||||
# Get metadata provider based on source parameter
|
||||
if source == 'civarchive':
|
||||
from .metadata_service import get_metadata_provider
|
||||
metadata_provider = await get_metadata_provider('civarchive')
|
||||
else:
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
# Get version info based on the provided identifier
|
||||
version_info = await metadata_provider.get_model_version(model_id, model_version_id)
|
||||
|
||||
if not version_info:
|
||||
return {'success': False, 'error': 'Failed to fetch model metadata'}
|
||||
|
||||
model_type_from_info = version_info.get('model', {}).get('type', '').lower()
|
||||
if model_type_from_info == 'checkpoint':
|
||||
model_type = 'checkpoint'
|
||||
elif model_type_from_info in VALID_LORA_TYPES:
|
||||
model_type = 'lora'
|
||||
elif model_type_from_info == 'textualinversion':
|
||||
model_type = 'embedding'
|
||||
else:
|
||||
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
|
||||
|
||||
# Case 2: model_version_id was None, check after getting version_info
|
||||
if model_version_id is None:
|
||||
version_id = version_info.get('id')
|
||||
|
||||
if model_type == 'lora':
|
||||
# Check lora scanner
|
||||
lora_scanner = await self._get_lora_scanner()
|
||||
if await lora_scanner.check_model_version_exists(version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in lora library'}
|
||||
elif model_type == 'checkpoint':
|
||||
# Check checkpoint scanner
|
||||
checkpoint_scanner = await self._get_checkpoint_scanner()
|
||||
if await checkpoint_scanner.check_model_version_exists(version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
||||
elif model_type == 'embedding':
|
||||
# Embeddings are not checked in scanners, but we can still check if it exists
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
if await embedding_scanner.check_model_version_exists(version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in embedding library'}
|
||||
|
||||
# Handle use_default_paths
|
||||
if use_default_paths:
|
||||
# Set save_dir based on model type
|
||||
if model_type == 'checkpoint':
|
||||
default_path = settings.get('default_checkpoint_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
|
||||
save_dir = default_path
|
||||
elif model_type == 'lora':
|
||||
default_path = settings.get('default_lora_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default lora root path not set in settings'}
|
||||
save_dir = default_path
|
||||
elif model_type == 'embedding':
|
||||
default_path = settings.get('default_embedding_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default embedding root path not set in settings'}
|
||||
save_dir = default_path
|
||||
|
||||
# Calculate relative path using template
|
||||
relative_path = self._calculate_relative_path(version_info, model_type)
|
||||
|
||||
# Update save directory with relative path if provided
|
||||
if relative_path:
|
||||
save_dir = os.path.join(save_dir, relative_path)
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Get civitai client
|
||||
civitai_client = await self._get_civitai_client()
|
||||
|
||||
# Get version info based on the provided identifier
|
||||
version_info = None
|
||||
error_msg = None
|
||||
|
||||
if model_hash:
|
||||
# Get model by hash
|
||||
version_info = await civitai_client.get_model_by_hash(model_hash)
|
||||
elif model_version_id:
|
||||
# Use model version ID directly
|
||||
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
|
||||
elif download_url:
|
||||
# Extract version ID from download URL
|
||||
version_id = download_url.split('/')[-1]
|
||||
version_info, error_msg = await civitai_client.get_model_version_info(version_id)
|
||||
|
||||
|
||||
if not version_info:
|
||||
if error_msg and "model not found" in error_msg.lower():
|
||||
return {'success': False, 'error': f'Model not found on Civitai: {error_msg}'}
|
||||
return {'success': False, 'error': error_msg or 'Failed to fetch model metadata'}
|
||||
|
||||
# Check if this is an early access model
|
||||
if version_info.get('earlyAccessEndsAt'):
|
||||
early_access_date = version_info.get('earlyAccessEndsAt', '')
|
||||
@@ -112,9 +275,9 @@ class DownloadManager:
|
||||
from datetime import datetime
|
||||
date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00'))
|
||||
formatted_date = date_obj.strftime('%Y-%m-%d')
|
||||
early_access_msg = f"This model requires early access payment (until {formatted_date}). "
|
||||
early_access_msg = f"This model requires payment (until {formatted_date}). "
|
||||
except:
|
||||
early_access_msg = "This model requires early access payment. "
|
||||
early_access_msg = "This model requires payment. "
|
||||
|
||||
early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai."
|
||||
logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}")
|
||||
@@ -131,33 +294,23 @@ class DownloadManager:
|
||||
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
|
||||
if not file_info:
|
||||
return {'success': False, 'error': 'No primary file found in metadata'}
|
||||
if not file_info.get('downloadUrl'):
|
||||
return {'success': False, 'error': 'No download URL found for primary file'}
|
||||
|
||||
# 3. Prepare download
|
||||
file_name = file_info['name']
|
||||
save_path = os.path.join(save_dir, file_name)
|
||||
|
||||
# 4. Notify file monitor - use normalized path and file size
|
||||
# file monitor is despreted, so we don't need to use it
|
||||
|
||||
# 5. Prepare metadata based on model type
|
||||
if model_type == "checkpoint":
|
||||
metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
logger.info(f"Creating CheckpointMetadata for {file_name}")
|
||||
else:
|
||||
elif model_type == "lora":
|
||||
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
logger.info(f"Creating LoraMetadata for {file_name}")
|
||||
|
||||
# 5.1 Get and update model tags, description and creator info
|
||||
model_id = version_info.get('modelId')
|
||||
if model_id:
|
||||
model_metadata, _ = await civitai_client.get_model_metadata(str(model_id))
|
||||
if model_metadata:
|
||||
if model_metadata.get("tags"):
|
||||
metadata.tags = model_metadata.get("tags", [])
|
||||
if model_metadata.get("description"):
|
||||
metadata.modelDescription = model_metadata.get("description", "")
|
||||
if model_metadata.get("creator"):
|
||||
metadata.civitai["creator"] = model_metadata.get("creator")
|
||||
elif model_type == "embedding":
|
||||
metadata = EmbeddingMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
logger.info(f"Creating EmbeddingMetadata for {file_name}")
|
||||
|
||||
# 6. Start download process
|
||||
result = await self._execute_download(
|
||||
@@ -167,9 +320,14 @@ class DownloadManager:
|
||||
version_info=version_info,
|
||||
relative_path=relative_path,
|
||||
progress_callback=progress_callback,
|
||||
model_type=model_type
|
||||
model_type=model_type,
|
||||
download_id=download_id
|
||||
)
|
||||
|
||||
# If early_access_msg exists and download failed, replace error message
|
||||
if 'early_access_msg' in locals() and not result.get('success', False):
|
||||
result['error'] = early_access_msg
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
@@ -180,15 +338,98 @@ class DownloadManager:
|
||||
return {'success': False, 'error': f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai."}
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
def _calculate_relative_path(self, version_info: Dict, model_type: str = 'lora') -> str:
|
||||
"""Calculate relative path using template from settings
|
||||
|
||||
Args:
|
||||
version_info: Version info from Civitai API
|
||||
model_type: Type of model ('lora', 'checkpoint', 'embedding')
|
||||
|
||||
Returns:
|
||||
Relative path string
|
||||
"""
|
||||
# Get path template from settings for specific model type
|
||||
path_template = settings.get_download_path_template(model_type)
|
||||
|
||||
# If template is empty, return empty path (flat structure)
|
||||
if not path_template:
|
||||
return ''
|
||||
|
||||
# Get base model name
|
||||
base_model = version_info.get('baseModel', '')
|
||||
|
||||
# Get author from creator data
|
||||
creator_info = version_info.get('creator')
|
||||
if creator_info and isinstance(creator_info, dict):
|
||||
author = creator_info.get('username') or 'Anonymous'
|
||||
else:
|
||||
author = 'Anonymous'
|
||||
|
||||
# Apply mapping if available
|
||||
base_model_mappings = settings.get('base_model_path_mappings', {})
|
||||
mapped_base_model = base_model_mappings.get(base_model, base_model)
|
||||
|
||||
# Get model tags
|
||||
model_tags = version_info.get('model', {}).get('tags', [])
|
||||
|
||||
# Find the first Civitai model tag that exists in model_tags
|
||||
first_tag = ''
|
||||
for civitai_tag in CIVITAI_MODEL_TAGS:
|
||||
if civitai_tag in model_tags:
|
||||
first_tag = civitai_tag
|
||||
break
|
||||
|
||||
# If no Civitai model tag found, fallback to first tag
|
||||
if not first_tag and model_tags:
|
||||
first_tag = model_tags[0]
|
||||
|
||||
# Format the template with available data
|
||||
formatted_path = path_template
|
||||
formatted_path = formatted_path.replace('{base_model}', mapped_base_model)
|
||||
formatted_path = formatted_path.replace('{first_tag}', first_tag)
|
||||
formatted_path = formatted_path.replace('{author}', author)
|
||||
|
||||
return formatted_path
|
||||
|
||||
async def _execute_download(self, download_url: str, save_dir: str,
|
||||
metadata, version_info: Dict,
|
||||
relative_path: str, progress_callback=None,
|
||||
model_type: str = "lora") -> Dict:
|
||||
metadata, version_info: Dict,
|
||||
relative_path: str, progress_callback=None,
|
||||
model_type: str = "lora", download_id: str = None) -> Dict:
|
||||
"""Execute the actual download process including preview images and model files"""
|
||||
try:
|
||||
civitai_client = await self._get_civitai_client()
|
||||
save_path = metadata.file_path
|
||||
# Extract original filename details
|
||||
original_filename = os.path.basename(metadata.file_path)
|
||||
base_name, extension = os.path.splitext(original_filename)
|
||||
|
||||
# Check for filename conflicts and generate unique filename if needed
|
||||
# Use the hash from metadata for conflict resolution
|
||||
def hash_provider():
|
||||
return metadata.sha256
|
||||
|
||||
unique_filename = metadata.generate_unique_filename(
|
||||
save_dir,
|
||||
base_name,
|
||||
extension,
|
||||
hash_provider=hash_provider
|
||||
)
|
||||
|
||||
# Update paths if filename changed
|
||||
if unique_filename != original_filename:
|
||||
logger.info(f"Filename conflict detected. Changing '{original_filename}' to '{unique_filename}'")
|
||||
save_path = os.path.join(save_dir, unique_filename)
|
||||
# Update metadata with new file path and name
|
||||
metadata.file_path = save_path.replace(os.sep, '/')
|
||||
metadata.file_name = os.path.splitext(unique_filename)[0]
|
||||
else:
|
||||
save_path = metadata.file_path
|
||||
|
||||
part_path = save_path + '.part'
|
||||
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
||||
|
||||
# Store file paths in active_downloads for potential cleanup
|
||||
if download_id and download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]['file_path'] = save_path
|
||||
self._active_downloads[download_id]['part_path'] = part_path
|
||||
|
||||
# Download preview image if available
|
||||
images = version_info.get('images', [])
|
||||
@@ -205,19 +446,31 @@ class DownloadManager:
|
||||
preview_ext = '.mp4'
|
||||
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
||||
|
||||
# Download video directly
|
||||
if await civitai_client.download_preview_image(images[0]['url'], preview_path):
|
||||
# Download video directly using downloader
|
||||
downloader = await get_downloader()
|
||||
success, result = await downloader.download_file(
|
||||
images[0]['url'],
|
||||
preview_path,
|
||||
use_auth=False # Preview images typically don't need auth
|
||||
)
|
||||
if success:
|
||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
else:
|
||||
# For images, use WebP format for better performance
|
||||
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
|
||||
# Download the original image to temp path
|
||||
if await civitai_client.download_preview_image(images[0]['url'], temp_path):
|
||||
# Download the original image to temp path using downloader
|
||||
downloader = await get_downloader()
|
||||
success, content, headers = await downloader.download_to_memory(
|
||||
images[0]['url'],
|
||||
use_auth=False
|
||||
)
|
||||
if success:
|
||||
# Save to temp file
|
||||
with open(temp_path, 'wb') as f:
|
||||
f.write(content)
|
||||
# Optimize and convert to WebP
|
||||
preview_path = os.path.splitext(save_path)[0] + '.webp'
|
||||
|
||||
@@ -237,8 +490,6 @@ class DownloadManager:
|
||||
# Update metadata
|
||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Remove temporary file
|
||||
try:
|
||||
@@ -250,35 +501,52 @@ class DownloadManager:
|
||||
if progress_callback:
|
||||
await progress_callback(3) # 3% progress after preview download
|
||||
|
||||
# Download model file with progress tracking
|
||||
success, result = await civitai_client._download_file(
|
||||
# Download model file with progress tracking using downloader
|
||||
downloader = await get_downloader()
|
||||
# Determine if the download URL is from Civitai
|
||||
use_auth = download_url.startswith("https://civitai.com/api/download/")
|
||||
success, result = await downloader.download_file(
|
||||
download_url,
|
||||
save_dir,
|
||||
os.path.basename(save_path),
|
||||
progress_callback=lambda p: self._handle_download_progress(p, progress_callback)
|
||||
save_path, # Use full path instead of separate dir and filename
|
||||
progress_callback=lambda p: self._handle_download_progress(p, progress_callback),
|
||||
use_auth=use_auth # Only use authentication for Civitai downloads
|
||||
)
|
||||
|
||||
if not success:
|
||||
# Clean up files on failure
|
||||
for path in [save_path, metadata_path, metadata.preview_url]:
|
||||
# Clean up files on failure, but preserve .part file for resume
|
||||
cleanup_files = [metadata_path]
|
||||
if metadata.preview_url and os.path.exists(metadata.preview_url):
|
||||
cleanup_files.append(metadata.preview_url)
|
||||
|
||||
for path in cleanup_files:
|
||||
if path and os.path.exists(path):
|
||||
os.remove(path)
|
||||
try:
|
||||
os.remove(path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup file {path}: {e}")
|
||||
|
||||
# Log but don't remove .part file to allow resume
|
||||
if os.path.exists(part_path):
|
||||
logger.info(f"Preserving partial download for resume: {part_path}")
|
||||
|
||||
return {'success': False, 'error': result}
|
||||
|
||||
# 4. Update file information (size and modified time)
|
||||
metadata.update_file_info(save_path)
|
||||
|
||||
# 5. Final metadata update
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
await MetadataManager.save_metadata(save_path, metadata)
|
||||
|
||||
# 6. Update cache based on model type
|
||||
if model_type == "checkpoint":
|
||||
scanner = await self._get_checkpoint_scanner()
|
||||
logger.info(f"Updating checkpoint cache for {save_path}")
|
||||
else:
|
||||
elif model_type == "lora":
|
||||
scanner = await self._get_lora_scanner()
|
||||
logger.info(f"Updating lora cache for {save_path}")
|
||||
elif model_type == "embedding":
|
||||
scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
logger.info(f"Updating embedding cache for {save_path}")
|
||||
|
||||
# Convert metadata to dictionary
|
||||
metadata_dict = metadata.to_dict()
|
||||
@@ -296,10 +564,18 @@ class DownloadManager:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _execute_download: {e}", exc_info=True)
|
||||
# Clean up partial downloads
|
||||
for path in [save_path, metadata_path]:
|
||||
# Clean up partial downloads except .part file
|
||||
cleanup_files = [metadata_path]
|
||||
if hasattr(metadata, 'preview_url') and metadata.preview_url and os.path.exists(metadata.preview_url):
|
||||
cleanup_files.append(metadata.preview_url)
|
||||
|
||||
for path in cleanup_files:
|
||||
if path and os.path.exists(path):
|
||||
os.remove(path)
|
||||
try:
|
||||
os.remove(path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup file {path}: {e}")
|
||||
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
async def _handle_download_progress(self, file_progress: float, progress_callback):
|
||||
@@ -312,4 +588,99 @@ class DownloadManager:
|
||||
if progress_callback:
|
||||
# Scale file progress to 3-100 range (after preview download)
|
||||
overall_progress = 3 + (file_progress * 0.97) # 97% of progress for file download
|
||||
await progress_callback(round(overall_progress))
|
||||
await progress_callback(round(overall_progress))
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict:
|
||||
"""Cancel an active download by download_id
|
||||
|
||||
Args:
|
||||
download_id: The unique identifier of the download task
|
||||
|
||||
Returns:
|
||||
Dict: Status of the cancellation operation
|
||||
"""
|
||||
if download_id not in self._download_tasks:
|
||||
return {'success': False, 'error': 'Download task not found'}
|
||||
|
||||
try:
|
||||
# Get the task and cancel it
|
||||
task = self._download_tasks[download_id]
|
||||
task.cancel()
|
||||
|
||||
# Update status in active downloads
|
||||
if download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]['status'] = 'cancelling'
|
||||
|
||||
# Wait briefly for the task to acknowledge cancellation
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(task), timeout=2.0)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
pass
|
||||
|
||||
# Clean up ALL files including .part when user cancels
|
||||
download_info = self._active_downloads.get(download_id)
|
||||
if download_info:
|
||||
# Delete the main file
|
||||
if 'file_path' in download_info:
|
||||
file_path = download_info['file_path']
|
||||
if os.path.exists(file_path):
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
logger.debug(f"Deleted cancelled download: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting file: {e}")
|
||||
|
||||
# Delete the .part file (only on user cancellation)
|
||||
if 'part_path' in download_info:
|
||||
part_path = download_info['part_path']
|
||||
if os.path.exists(part_path):
|
||||
try:
|
||||
os.unlink(part_path)
|
||||
logger.debug(f"Deleted partial download: {part_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting part file: {e}")
|
||||
|
||||
# Delete metadata file if exists
|
||||
if 'file_path' in download_info:
|
||||
file_path = download_info['file_path']
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
if os.path.exists(metadata_path):
|
||||
try:
|
||||
os.unlink(metadata_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting metadata file: {e}")
|
||||
|
||||
# Delete preview file if exists (.webp or .mp4)
|
||||
for preview_ext in ['.webp', '.mp4']:
|
||||
preview_path = os.path.splitext(file_path)[0] + preview_ext
|
||||
if os.path.exists(preview_path):
|
||||
try:
|
||||
os.unlink(preview_path)
|
||||
logger.debug(f"Deleted preview file: {preview_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting preview file: {e}")
|
||||
|
||||
return {'success': True, 'message': 'Download cancelled successfully'}
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling download: {e}", exc_info=True)
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
async def get_active_downloads(self) -> Dict:
|
||||
"""Get information about all active downloads
|
||||
|
||||
Returns:
|
||||
Dict: List of active downloads and their status
|
||||
"""
|
||||
return {
|
||||
'downloads': [
|
||||
{
|
||||
'download_id': task_id,
|
||||
'model_id': info.get('model_id'),
|
||||
'model_version_id': info.get('model_version_id'),
|
||||
'progress': info.get('progress', 0),
|
||||
'status': info.get('status', 'unknown'),
|
||||
'error': info.get('error', None)
|
||||
}
|
||||
for task_id, info in self._active_downloads.items()
|
||||
]
|
||||
}
|
||||
539
py/services/downloader.py
Normal file
539
py/services/downloader.py
Normal file
@@ -0,0 +1,539 @@
|
||||
"""
|
||||
Unified download manager for all HTTP/HTTPS downloads in the application.
|
||||
|
||||
This module provides a centralized download service with:
|
||||
- Singleton pattern for global session management
|
||||
- Support for authenticated downloads (e.g., CivitAI API key)
|
||||
- Resumable downloads with automatic retry
|
||||
- Progress tracking and callbacks
|
||||
- Optimized connection pooling and timeouts
|
||||
- Unified error handling and logging
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Tuple, Callable, Union
|
||||
from ..services.settings_manager import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Downloader:
|
||||
"""Unified downloader for all HTTP/HTTPS downloads in the application."""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance of Downloader"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the downloader with optimal settings"""
|
||||
# Check if already initialized for singleton pattern
|
||||
if hasattr(self, '_initialized'):
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
# Session management
|
||||
self._session = None
|
||||
self._session_created_at = None
|
||||
self._proxy_url = None # Store proxy URL for current session
|
||||
|
||||
# Configuration
|
||||
self.chunk_size = 4 * 1024 * 1024 # 4MB chunks for better throughput
|
||||
self.max_retries = 5
|
||||
self.base_delay = 2.0 # Base delay for exponential backoff
|
||||
self.session_timeout = 300 # 5 minutes
|
||||
|
||||
# Default headers
|
||||
self.default_headers = {
|
||||
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
|
||||
}
|
||||
|
||||
@property
|
||||
async def session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create the global aiohttp session with optimized settings"""
|
||||
if self._session is None or self._should_refresh_session():
|
||||
await self._create_session()
|
||||
return self._session
|
||||
|
||||
@property
|
||||
def proxy_url(self) -> Optional[str]:
|
||||
"""Get the current proxy URL (initialize if needed)"""
|
||||
if not hasattr(self, '_proxy_url'):
|
||||
self._proxy_url = None
|
||||
return self._proxy_url
|
||||
|
||||
def _should_refresh_session(self) -> bool:
|
||||
"""Check if session should be refreshed"""
|
||||
if self._session is None:
|
||||
return True
|
||||
|
||||
if not hasattr(self, '_session_created_at') or self._session_created_at is None:
|
||||
return True
|
||||
|
||||
# Refresh if session is older than timeout
|
||||
if (datetime.now() - self._session_created_at).total_seconds() > self.session_timeout:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _create_session(self):
|
||||
"""Create a new aiohttp session with optimized settings"""
|
||||
# Close existing session if any
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
|
||||
# Check for app-level proxy settings
|
||||
proxy_url = None
|
||||
if settings.get('proxy_enabled', False):
|
||||
proxy_host = settings.get('proxy_host', '').strip()
|
||||
proxy_port = settings.get('proxy_port', '').strip()
|
||||
proxy_type = settings.get('proxy_type', 'http').lower()
|
||||
proxy_username = settings.get('proxy_username', '').strip()
|
||||
proxy_password = settings.get('proxy_password', '').strip()
|
||||
|
||||
if proxy_host and proxy_port:
|
||||
# Build proxy URL
|
||||
if proxy_username and proxy_password:
|
||||
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
|
||||
else:
|
||||
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
||||
|
||||
logger.debug(f"Using app-level proxy: {proxy_type}://{proxy_host}:{proxy_port}")
|
||||
logger.debug("Proxy mode: app-level proxy is active.")
|
||||
else:
|
||||
logger.debug("Proxy mode: system-level proxy (trust_env) will be used if configured in environment.")
|
||||
# Optimize TCP connection parameters
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=True,
|
||||
limit=8, # Concurrent connections
|
||||
ttl_dns_cache=300, # DNS cache timeout
|
||||
force_close=False, # Keep connections for reuse
|
||||
enable_cleanup_closed=True
|
||||
)
|
||||
|
||||
# Configure timeout parameters
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=None, # No total timeout for large downloads
|
||||
connect=60, # Connection timeout
|
||||
sock_read=300 # 5 minute socket read timeout
|
||||
)
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
trust_env=proxy_url is None, # Only use system proxy if no app-level proxy is set
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# Store proxy URL for use in requests
|
||||
self._proxy_url = proxy_url
|
||||
self._session_created_at = datetime.now()
|
||||
|
||||
logger.debug("Created new HTTP session with proxy settings. App-level proxy: %s, System-level proxy (trust_env): %s", bool(proxy_url), proxy_url is None)
|
||||
|
||||
def _get_auth_headers(self, use_auth: bool = False) -> Dict[str, str]:
|
||||
"""Get headers with optional authentication"""
|
||||
headers = self.default_headers.copy()
|
||||
|
||||
if use_auth:
|
||||
# Add CivitAI API key if available
|
||||
api_key = settings.get('civitai_api_key')
|
||||
if api_key:
|
||||
headers['Authorization'] = f'Bearer {api_key}'
|
||||
headers['Content-Type'] = 'application/json'
|
||||
|
||||
return headers
|
||||
|
||||
async def download_file(
|
||||
self,
|
||||
url: str,
|
||||
save_path: str,
|
||||
progress_callback: Optional[Callable[[float], None]] = None,
|
||||
use_auth: bool = False,
|
||||
custom_headers: Optional[Dict[str, str]] = None,
|
||||
allow_resume: bool = True
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Download a file with resumable downloads and retry mechanism
|
||||
|
||||
Args:
|
||||
url: Download URL
|
||||
save_path: Full path where the file should be saved
|
||||
progress_callback: Optional callback for progress updates (0-100)
|
||||
use_auth: Whether to include authentication headers (e.g., CivitAI API key)
|
||||
custom_headers: Additional headers to include in request
|
||||
allow_resume: Whether to support resumable downloads
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (success, save_path or error message)
|
||||
"""
|
||||
retry_count = 0
|
||||
part_path = save_path + '.part' if allow_resume else save_path
|
||||
|
||||
# Prepare headers
|
||||
headers = self._get_auth_headers(use_auth)
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
# Get existing file size for resume
|
||||
resume_offset = 0
|
||||
if allow_resume and os.path.exists(part_path):
|
||||
resume_offset = os.path.getsize(part_path)
|
||||
logger.info(f"Resuming download from offset {resume_offset} bytes")
|
||||
|
||||
total_size = 0
|
||||
|
||||
while retry_count <= self.max_retries:
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
if self.proxy_url:
|
||||
logger.debug(f"[download_file] Using app-level proxy: {self.proxy_url}")
|
||||
else:
|
||||
logger.debug("[download_file] Using system-level proxy (trust_env) if configured.")
|
||||
|
||||
# Add Range header for resume if we have partial data
|
||||
request_headers = headers.copy()
|
||||
if allow_resume and resume_offset > 0:
|
||||
request_headers['Range'] = f'bytes={resume_offset}-'
|
||||
|
||||
# Disable compression for better chunked downloads
|
||||
request_headers['Accept-Encoding'] = 'identity'
|
||||
|
||||
logger.debug(f"Download attempt {retry_count + 1}/{self.max_retries + 1} from: {url}")
|
||||
if resume_offset > 0:
|
||||
logger.debug(f"Requesting range from byte {resume_offset}")
|
||||
|
||||
async with session.get(url, headers=request_headers, allow_redirects=True, proxy=self.proxy_url) as response:
|
||||
# Handle different response codes
|
||||
if response.status == 200:
|
||||
# Full content response
|
||||
if resume_offset > 0:
|
||||
# Server doesn't support ranges, restart from beginning
|
||||
logger.warning("Server doesn't support range requests, restarting download")
|
||||
resume_offset = 0
|
||||
if os.path.exists(part_path):
|
||||
os.remove(part_path)
|
||||
elif response.status == 206:
|
||||
# Partial content response (resume successful)
|
||||
content_range = response.headers.get('Content-Range')
|
||||
if content_range:
|
||||
# Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048")
|
||||
range_parts = content_range.split('/')
|
||||
if len(range_parts) == 2:
|
||||
total_size = int(range_parts[1])
|
||||
logger.info(f"Successfully resumed download from byte {resume_offset}")
|
||||
elif response.status == 416:
|
||||
# Range not satisfiable - file might be complete or corrupted
|
||||
if allow_resume and os.path.exists(part_path):
|
||||
part_size = os.path.getsize(part_path)
|
||||
logger.warning(f"Range not satisfiable. Part file size: {part_size}")
|
||||
# Try to get actual file size
|
||||
head_response = await session.head(url, headers=headers, proxy=self.proxy_url)
|
||||
if head_response.status == 200:
|
||||
actual_size = int(head_response.headers.get('content-length', 0))
|
||||
if part_size == actual_size:
|
||||
# File is complete, just rename it
|
||||
if allow_resume:
|
||||
os.rename(part_path, save_path)
|
||||
if progress_callback:
|
||||
await progress_callback(100)
|
||||
return True, save_path
|
||||
# Remove corrupted part file and restart
|
||||
os.remove(part_path)
|
||||
resume_offset = 0
|
||||
continue
|
||||
elif response.status == 401:
|
||||
logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
|
||||
return False, "Invalid or missing API key, or early access restriction."
|
||||
elif response.status == 403:
|
||||
logger.warning(f"Forbidden access to resource: {url} (Status 403)")
|
||||
return False, "Access forbidden: You don't have permission to download this file."
|
||||
elif response.status == 404:
|
||||
logger.warning(f"Resource not found: {url} (Status 404)")
|
||||
return False, "File not found - the download link may be invalid or expired."
|
||||
else:
|
||||
logger.error(f"Download failed for {url} with status {response.status}")
|
||||
return False, f"Download failed with status {response.status}"
|
||||
|
||||
# Get total file size for progress calculation (if not set from Content-Range)
|
||||
if total_size == 0:
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
if response.status == 206:
|
||||
# For partial content, add the offset to get total file size
|
||||
total_size += resume_offset
|
||||
|
||||
current_size = resume_offset
|
||||
last_progress_report_time = datetime.now()
|
||||
|
||||
# Ensure directory exists
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
|
||||
# Stream download to file with progress updates
|
||||
loop = asyncio.get_running_loop()
|
||||
mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb'
|
||||
with open(part_path, mode) as f:
|
||||
async for chunk in response.content.iter_chunked(self.chunk_size):
|
||||
if chunk:
|
||||
# Run blocking file write in executor
|
||||
await loop.run_in_executor(None, f.write, chunk)
|
||||
current_size += len(chunk)
|
||||
|
||||
# Limit progress update frequency to reduce overhead
|
||||
now = datetime.now()
|
||||
time_diff = (now - last_progress_report_time).total_seconds()
|
||||
|
||||
if progress_callback and total_size and time_diff >= 1.0:
|
||||
progress = (current_size / total_size) * 100
|
||||
await progress_callback(progress)
|
||||
last_progress_report_time = now
|
||||
|
||||
# Download completed successfully
|
||||
# Verify file size if total_size was provided
|
||||
final_size = os.path.getsize(part_path)
|
||||
if total_size > 0 and final_size != total_size:
|
||||
logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}")
|
||||
# Don't treat this as fatal error, continue anyway
|
||||
|
||||
# Atomically rename .part to final file (only if using resume)
|
||||
if allow_resume and part_path != save_path:
|
||||
max_rename_attempts = 5
|
||||
rename_attempt = 0
|
||||
rename_success = False
|
||||
|
||||
while rename_attempt < max_rename_attempts and not rename_success:
|
||||
try:
|
||||
# If the destination file exists, remove it first (Windows safe)
|
||||
if os.path.exists(save_path):
|
||||
os.remove(save_path)
|
||||
|
||||
os.rename(part_path, save_path)
|
||||
rename_success = True
|
||||
except PermissionError as e:
|
||||
rename_attempt += 1
|
||||
if rename_attempt < max_rename_attempts:
|
||||
logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})")
|
||||
await asyncio.sleep(2)
|
||||
else:
|
||||
logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}")
|
||||
return False, f"Failed to finalize download: {str(e)}"
|
||||
|
||||
# Ensure 100% progress is reported
|
||||
if progress_callback:
|
||||
await progress_callback(100)
|
||||
|
||||
return True, save_path
|
||||
|
||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError,
|
||||
aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e:
|
||||
retry_count += 1
|
||||
logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}")
|
||||
|
||||
if retry_count <= self.max_retries:
|
||||
# Calculate delay with exponential backoff
|
||||
delay = self.base_delay * (2 ** (retry_count - 1))
|
||||
logger.info(f"Retrying in {delay} seconds...")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# Update resume offset for next attempt
|
||||
if allow_resume and os.path.exists(part_path):
|
||||
resume_offset = os.path.getsize(part_path)
|
||||
logger.info(f"Will resume from byte {resume_offset}")
|
||||
|
||||
# Refresh session to get new connection
|
||||
await self._create_session()
|
||||
continue
|
||||
else:
|
||||
logger.error(f"Max retries exceeded for download: {e}")
|
||||
return False, f"Network error after {self.max_retries + 1} attempts: {str(e)}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected download error: {e}")
|
||||
return False, str(e)
|
||||
|
||||
return False, f"Download failed after {self.max_retries + 1} attempts"
|
||||
|
||||
async def download_to_memory(
|
||||
self,
|
||||
url: str,
|
||||
use_auth: bool = False,
|
||||
custom_headers: Optional[Dict[str, str]] = None,
|
||||
return_headers: bool = False
|
||||
) -> Tuple[bool, Union[bytes, str], Optional[Dict]]:
|
||||
"""
|
||||
Download a file to memory (for small files like preview images)
|
||||
|
||||
Args:
|
||||
url: Download URL
|
||||
use_auth: Whether to include authentication headers
|
||||
custom_headers: Additional headers to include in request
|
||||
return_headers: Whether to return response headers along with content
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[bytes, str], Optional[Dict]]: (success, content or error message, response headers if requested)
|
||||
"""
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
if self.proxy_url:
|
||||
logger.debug(f"[download_to_memory] Using app-level proxy: {self.proxy_url}")
|
||||
else:
|
||||
logger.debug("[download_to_memory] Using system-level proxy (trust_env) if configured.")
|
||||
|
||||
# Prepare headers
|
||||
headers = self._get_auth_headers(use_auth)
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
async with session.get(url, headers=headers, proxy=self.proxy_url) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
if return_headers:
|
||||
return True, content, dict(response.headers)
|
||||
else:
|
||||
return True, content, None
|
||||
elif response.status == 401:
|
||||
error_msg = "Unauthorized access - invalid or missing API key"
|
||||
return False, error_msg, None
|
||||
elif response.status == 403:
|
||||
error_msg = "Access forbidden"
|
||||
return False, error_msg, None
|
||||
elif response.status == 404:
|
||||
error_msg = "File not found"
|
||||
return False, error_msg, None
|
||||
else:
|
||||
error_msg = f"Download failed with status {response.status}"
|
||||
return False, error_msg, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading to memory from {url}: {e}")
|
||||
return False, str(e), None
|
||||
|
||||
async def get_response_headers(
|
||||
self,
|
||||
url: str,
|
||||
use_auth: bool = False,
|
||||
custom_headers: Optional[Dict[str, str]] = None
|
||||
) -> Tuple[bool, Union[Dict, str]]:
|
||||
"""
|
||||
Get response headers without downloading the full content
|
||||
|
||||
Args:
|
||||
url: URL to check
|
||||
use_auth: Whether to include authentication headers
|
||||
custom_headers: Additional headers to include in request
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict, str]]: (success, headers dict or error message)
|
||||
"""
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
if self.proxy_url:
|
||||
logger.debug(f"[get_response_headers] Using app-level proxy: {self.proxy_url}")
|
||||
else:
|
||||
logger.debug("[get_response_headers] Using system-level proxy (trust_env) if configured.")
|
||||
|
||||
# Prepare headers
|
||||
headers = self._get_auth_headers(use_auth)
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
async with session.head(url, headers=headers, proxy=self.proxy_url) as response:
|
||||
if response.status == 200:
|
||||
return True, dict(response.headers)
|
||||
else:
|
||||
return False, f"Head request failed with status {response.status}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting headers from {url}: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def make_request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
use_auth: bool = False,
|
||||
custom_headers: Optional[Dict[str, str]] = None,
|
||||
**kwargs
|
||||
) -> Tuple[bool, Union[Dict, str]]:
|
||||
"""
|
||||
Make a generic HTTP request and return JSON response
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
url: Request URL
|
||||
use_auth: Whether to include authentication headers
|
||||
custom_headers: Additional headers to include in request
|
||||
**kwargs: Additional arguments for aiohttp request
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict, str]]: (success, response data or error message)
|
||||
"""
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
if self.proxy_url:
|
||||
logger.debug(f"[make_request] Using app-level proxy: {self.proxy_url}")
|
||||
else:
|
||||
logger.debug("[make_request] Using system-level proxy (trust_env) if configured.")
|
||||
|
||||
# Prepare headers
|
||||
headers = self._get_auth_headers(use_auth)
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
# Add proxy to kwargs if not already present
|
||||
if 'proxy' not in kwargs:
|
||||
kwargs['proxy'] = self.proxy_url
|
||||
|
||||
async with session.request(method, url, headers=headers, **kwargs) as response:
|
||||
if response.status == 200:
|
||||
# Try to parse as JSON, fall back to text
|
||||
try:
|
||||
data = await response.json()
|
||||
return True, data
|
||||
except:
|
||||
text = await response.text()
|
||||
return True, text
|
||||
elif response.status == 401:
|
||||
return False, "Unauthorized access - invalid or missing API key"
|
||||
elif response.status == 403:
|
||||
return False, "Access forbidden"
|
||||
elif response.status == 404:
|
||||
return False, "Resource not found"
|
||||
else:
|
||||
return False, f"Request failed with status {response.status}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error making {method} request to {url}: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP session"""
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
self._session_created_at = None
|
||||
self._proxy_url = None
|
||||
logger.debug("Closed HTTP session")
|
||||
|
||||
async def refresh_session(self):
|
||||
"""Force refresh the HTTP session (useful when proxy settings change)"""
|
||||
await self._create_session()
|
||||
logger.info("HTTP session refreshed due to settings change")
|
||||
|
||||
|
||||
# Global instance accessor
|
||||
async def get_downloader() -> Downloader:
|
||||
"""Get the global downloader instance"""
|
||||
return await Downloader.get_instance()
|
||||
26
py/services/embedding_scanner.py
Normal file
26
py/services/embedding_scanner.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from ..utils.models import EmbeddingMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EmbeddingScanner(ModelScanner):
|
||||
"""Service for scanning and managing embedding files"""
|
||||
|
||||
def __init__(self):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
||||
super().__init__(
|
||||
model_type="embedding",
|
||||
model_class=EmbeddingMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get embedding root directories"""
|
||||
return config.embeddings_roots
|
||||
49
py/services/embedding_service.py
Normal file
49
py/services/embedding_service.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import EmbeddingMetadata
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EmbeddingService(BaseModelService):
|
||||
"""Embedding-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize Embedding service
|
||||
|
||||
Args:
|
||||
scanner: Embedding scanner instance
|
||||
"""
|
||||
super().__init__("embedding", scanner, EmbeddingMetadata)
|
||||
|
||||
async def format_response(self, embedding_data: Dict) -> Dict:
|
||||
"""Format Embedding data for API response"""
|
||||
return {
|
||||
"model_name": embedding_data["model_name"],
|
||||
"file_name": embedding_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(embedding_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": embedding_data.get("preview_nsfw_level", 0),
|
||||
"base_model": embedding_data.get("base_model", ""),
|
||||
"folder": embedding_data["folder"],
|
||||
"sha256": embedding_data.get("sha256", ""),
|
||||
"file_path": embedding_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": embedding_data.get("size", 0),
|
||||
"modified": embedding_data.get("modified", ""),
|
||||
"tags": embedding_data.get("tags", []),
|
||||
"from_civitai": embedding_data.get("from_civitai", True),
|
||||
"notes": embedding_data.get("notes", ""),
|
||||
"model_type": embedding_data.get("model_type", "embedding"),
|
||||
"favorite": embedding_data.get("favorite", False),
|
||||
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find Embeddings with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find Embeddings with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
246
py/services/example_images_cleanup_service.py
Normal file
246
py/services/example_images_cleanup_service.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""Service for cleaning up example image folders."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from .service_registry import ServiceRegistry
|
||||
from .settings_manager import settings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CleanupResult:
|
||||
"""Structured result returned from cleanup operations."""
|
||||
|
||||
success: bool
|
||||
checked_folders: int
|
||||
moved_empty_folders: int
|
||||
moved_orphaned_folders: int
|
||||
skipped_non_hash: int
|
||||
move_failures: int
|
||||
errors: List[str]
|
||||
deleted_root: str | None
|
||||
partial_success: bool
|
||||
|
||||
def to_dict(self) -> Dict[str, object]:
|
||||
"""Convert the dataclass to a serialisable dictionary."""
|
||||
|
||||
data = {
|
||||
"success": self.success,
|
||||
"checked_folders": self.checked_folders,
|
||||
"moved_empty_folders": self.moved_empty_folders,
|
||||
"moved_orphaned_folders": self.moved_orphaned_folders,
|
||||
"moved_total": self.moved_empty_folders + self.moved_orphaned_folders,
|
||||
"skipped_non_hash": self.skipped_non_hash,
|
||||
"move_failures": self.move_failures,
|
||||
"errors": self.errors,
|
||||
"deleted_root": self.deleted_root,
|
||||
"partial_success": self.partial_success,
|
||||
}
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class ExampleImagesCleanupService:
|
||||
"""Encapsulates logic for cleaning example image folders."""
|
||||
|
||||
DELETED_FOLDER_NAME = "_deleted"
|
||||
|
||||
def __init__(self, deleted_folder_name: str | None = None) -> None:
|
||||
self._deleted_folder_name = deleted_folder_name or self.DELETED_FOLDER_NAME
|
||||
|
||||
async def cleanup_example_image_folders(self) -> Dict[str, object]:
|
||||
"""Clean empty or orphaned example image folders by moving them under a deleted bucket."""
|
||||
|
||||
example_images_path = settings.get("example_images_path")
|
||||
if not example_images_path:
|
||||
logger.debug("Cleanup skipped: example images path not configured")
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Example images path is not configured.",
|
||||
"error_code": "path_not_configured",
|
||||
}
|
||||
|
||||
example_root = Path(example_images_path)
|
||||
if not example_root.exists():
|
||||
logger.debug("Cleanup skipped: example images path missing -> %s", example_root)
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Example images path does not exist.",
|
||||
"error_code": "path_not_found",
|
||||
}
|
||||
|
||||
try:
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.error("Failed to acquire scanners for cleanup: %s", exc, exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Failed to load model scanners: {exc}",
|
||||
"error_code": "scanner_initialization_failed",
|
||||
}
|
||||
|
||||
deleted_bucket = example_root / self._deleted_folder_name
|
||||
deleted_bucket.mkdir(exist_ok=True)
|
||||
|
||||
checked_folders = 0
|
||||
moved_empty = 0
|
||||
moved_orphaned = 0
|
||||
skipped_non_hash = 0
|
||||
move_failures = 0
|
||||
errors: List[str] = []
|
||||
|
||||
for entry in os.scandir(example_root):
|
||||
if not entry.is_dir(follow_symlinks=False):
|
||||
continue
|
||||
|
||||
if entry.name == self._deleted_folder_name:
|
||||
continue
|
||||
|
||||
checked_folders += 1
|
||||
folder_path = Path(entry.path)
|
||||
|
||||
try:
|
||||
if self._is_folder_empty(folder_path):
|
||||
if await self._remove_empty_folder(folder_path):
|
||||
moved_empty += 1
|
||||
else:
|
||||
move_failures += 1
|
||||
continue
|
||||
|
||||
if not self._is_hash_folder(entry.name):
|
||||
skipped_non_hash += 1
|
||||
continue
|
||||
|
||||
hash_exists = (
|
||||
lora_scanner.has_hash(entry.name)
|
||||
or checkpoint_scanner.has_hash(entry.name)
|
||||
or embedding_scanner.has_hash(entry.name)
|
||||
)
|
||||
|
||||
if not hash_exists:
|
||||
if await self._move_folder(folder_path, deleted_bucket):
|
||||
moved_orphaned += 1
|
||||
else:
|
||||
move_failures += 1
|
||||
|
||||
except Exception as exc: # pragma: no cover - filesystem guard
|
||||
move_failures += 1
|
||||
error_message = f"{entry.name}: {exc}"
|
||||
errors.append(error_message)
|
||||
logger.error("Error processing example images folder %s: %s", folder_path, exc, exc_info=True)
|
||||
|
||||
partial_success = move_failures > 0 and (moved_empty > 0 or moved_orphaned > 0)
|
||||
success = move_failures == 0 and not errors
|
||||
|
||||
result = CleanupResult(
|
||||
success=success,
|
||||
checked_folders=checked_folders,
|
||||
moved_empty_folders=moved_empty,
|
||||
moved_orphaned_folders=moved_orphaned,
|
||||
skipped_non_hash=skipped_non_hash,
|
||||
move_failures=move_failures,
|
||||
errors=errors,
|
||||
deleted_root=str(deleted_bucket),
|
||||
partial_success=partial_success,
|
||||
)
|
||||
|
||||
summary = result.to_dict()
|
||||
if success:
|
||||
logger.info(
|
||||
"Example images cleanup complete: checked=%s, moved_empty=%s, moved_orphaned=%s",
|
||||
checked_folders,
|
||||
moved_empty,
|
||||
moved_orphaned,
|
||||
)
|
||||
elif partial_success:
|
||||
logger.warning(
|
||||
"Example images cleanup partially complete: moved=%s, failures=%s",
|
||||
summary["moved_total"],
|
||||
move_failures,
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
"Example images cleanup failed: move_failures=%s, errors=%s",
|
||||
move_failures,
|
||||
errors,
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
@staticmethod
|
||||
def _is_folder_empty(folder_path: Path) -> bool:
|
||||
try:
|
||||
with os.scandir(folder_path) as iterator:
|
||||
return not any(iterator)
|
||||
except FileNotFoundError:
|
||||
return True
|
||||
except OSError as exc: # pragma: no cover - defensive guard
|
||||
logger.debug("Failed to inspect folder %s: %s", folder_path, exc)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_hash_folder(name: str) -> bool:
|
||||
if len(name) != 64:
|
||||
return False
|
||||
hex_chars = set("0123456789abcdefABCDEF")
|
||||
return all(char in hex_chars for char in name)
|
||||
|
||||
async def _remove_empty_folder(self, folder_path: Path) -> bool:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
try:
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
shutil.rmtree,
|
||||
str(folder_path),
|
||||
)
|
||||
logger.debug("Removed empty example images folder %s", folder_path)
|
||||
return True
|
||||
except Exception as exc: # pragma: no cover - filesystem guard
|
||||
logger.error("Failed to remove empty example images folder %s: %s", folder_path, exc, exc_info=True)
|
||||
return False
|
||||
|
||||
async def _move_folder(self, folder_path: Path, deleted_bucket: Path) -> bool:
|
||||
destination = self._build_destination(folder_path.name, deleted_bucket)
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
try:
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
shutil.move,
|
||||
str(folder_path),
|
||||
str(destination),
|
||||
)
|
||||
logger.debug("Moved example images folder %s -> %s", folder_path, destination)
|
||||
return True
|
||||
except Exception as exc: # pragma: no cover - filesystem guard
|
||||
logger.error(
|
||||
"Failed to move example images folder %s to %s: %s",
|
||||
folder_path,
|
||||
destination,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
return False
|
||||
|
||||
def _build_destination(self, folder_name: str, deleted_bucket: Path) -> Path:
|
||||
destination = deleted_bucket / folder_name
|
||||
suffix = 1
|
||||
|
||||
while destination.exists():
|
||||
destination = deleted_bucket / f"{folder_name}_{suffix}"
|
||||
suffix += 1
|
||||
|
||||
return destination
|
||||
@@ -1,542 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
import time
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
from typing import List, Dict, Set, Optional
|
||||
from threading import Lock
|
||||
|
||||
from ..config import config
|
||||
from .service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration constant to control file monitoring functionality
|
||||
ENABLE_FILE_MONITORING = False
|
||||
|
||||
class BaseFileHandler(FileSystemEventHandler):
|
||||
"""Base handler for file system events"""
|
||||
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
||||
self.loop = loop # Store event loop reference
|
||||
self.pending_changes = set() # Pending changes
|
||||
self.lock = Lock() # Thread-safe lock
|
||||
self.update_task = None # Async update task
|
||||
self._ignore_paths = set() # Paths to ignore
|
||||
self._min_ignore_timeout = 5 # Minimum timeout in seconds
|
||||
self._download_speed = 1024 * 1024 # Assume 1MB/s as base speed
|
||||
|
||||
# Track modified files with timestamps for debouncing
|
||||
self.modified_files: Dict[str, float] = {}
|
||||
self.debounce_timer = None
|
||||
self.debounce_delay = 3.0 # Seconds to wait after last modification
|
||||
|
||||
# Track files already scheduled for processing
|
||||
self.scheduled_files: Set[str] = set()
|
||||
|
||||
# File extensions to monitor - should be overridden by subclasses
|
||||
self.file_extensions = set()
|
||||
|
||||
def _should_ignore(self, path: str) -> bool:
|
||||
"""Check if path should be ignored"""
|
||||
real_path = os.path.realpath(path) # Resolve any symbolic links
|
||||
return real_path.replace(os.sep, '/') in self._ignore_paths
|
||||
|
||||
def add_ignore_path(self, path: str, file_size: int = 0):
|
||||
"""Add path to ignore list with dynamic timeout based on file size"""
|
||||
real_path = os.path.realpath(path) # Resolve any symbolic links
|
||||
self._ignore_paths.add(real_path.replace(os.sep, '/'))
|
||||
|
||||
# Short timeout (e.g. 5 seconds) is sufficient to ignore the CREATE event
|
||||
timeout = 5
|
||||
|
||||
self.loop.call_later(
|
||||
timeout,
|
||||
self._ignore_paths.discard,
|
||||
real_path.replace(os.sep, '/')
|
||||
)
|
||||
|
||||
def on_created(self, event):
|
||||
if event.is_directory:
|
||||
return
|
||||
|
||||
# Handle appropriate files based on extensions
|
||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
if file_ext in self.file_extensions:
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
|
||||
# Process this file directly and ignore subsequent modifications
|
||||
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
||||
if normalized_path not in self.scheduled_files:
|
||||
logger.info(f"File created: {event.src_path}")
|
||||
self.scheduled_files.add(normalized_path)
|
||||
self._schedule_update('add', event.src_path)
|
||||
|
||||
# Ignore modifications for a short period after creation
|
||||
self.loop.call_later(
|
||||
self.debounce_delay * 2,
|
||||
self.scheduled_files.discard,
|
||||
normalized_path
|
||||
)
|
||||
|
||||
def on_modified(self, event):
|
||||
if event.is_directory:
|
||||
return
|
||||
|
||||
# Only process files with supported extensions
|
||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
if file_ext in self.file_extensions:
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
|
||||
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
||||
|
||||
# Skip if this file is already scheduled for processing
|
||||
if normalized_path in self.scheduled_files:
|
||||
return
|
||||
|
||||
# Update the timestamp for this file
|
||||
self.modified_files[normalized_path] = time.time()
|
||||
|
||||
# Cancel any existing timer
|
||||
if self.debounce_timer:
|
||||
self.debounce_timer.cancel()
|
||||
|
||||
# Set a new timer to process modified files after debounce period
|
||||
self.debounce_timer = self.loop.call_later(
|
||||
self.debounce_delay,
|
||||
self.loop.call_soon_threadsafe,
|
||||
self._process_modified_files
|
||||
)
|
||||
|
||||
def _process_modified_files(self):
|
||||
"""Process files that have been modified after debounce period"""
|
||||
current_time = time.time()
|
||||
files_to_process = []
|
||||
|
||||
# Find files that haven't been modified for debounce_delay seconds
|
||||
for file_path, last_modified in list(self.modified_files.items()):
|
||||
if current_time - last_modified >= self.debounce_delay:
|
||||
# Only process if not already scheduled
|
||||
if file_path not in self.scheduled_files:
|
||||
files_to_process.append(file_path)
|
||||
self.scheduled_files.add(file_path)
|
||||
|
||||
# Auto-remove from scheduled list after reasonable time
|
||||
self.loop.call_later(
|
||||
self.debounce_delay * 2,
|
||||
self.scheduled_files.discard,
|
||||
file_path
|
||||
)
|
||||
|
||||
del self.modified_files[file_path]
|
||||
|
||||
# Process stable files
|
||||
for file_path in files_to_process:
|
||||
logger.info(f"Processing modified file: {file_path}")
|
||||
self._schedule_update('add', file_path)
|
||||
|
||||
def on_deleted(self, event):
|
||||
if event.is_directory:
|
||||
return
|
||||
|
||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
if file_ext not in self.file_extensions:
|
||||
return
|
||||
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
|
||||
# Remove from scheduled files if present
|
||||
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
||||
self.scheduled_files.discard(normalized_path)
|
||||
|
||||
logger.info(f"File deleted: {event.src_path}")
|
||||
self._schedule_update('remove', event.src_path)
|
||||
|
||||
def on_moved(self, event):
|
||||
"""Handle file move/rename events"""
|
||||
|
||||
src_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
dest_ext = os.path.splitext(event.dest_path)[1].lower()
|
||||
|
||||
# If destination has supported extension, treat as new file
|
||||
if dest_ext in self.file_extensions:
|
||||
if self._should_ignore(event.dest_path):
|
||||
return
|
||||
|
||||
normalized_path = os.path.realpath(event.dest_path).replace(os.sep, '/')
|
||||
|
||||
# Only process if not already scheduled
|
||||
if normalized_path not in self.scheduled_files:
|
||||
logger.info(f"File renamed/moved to: {event.dest_path}")
|
||||
self.scheduled_files.add(normalized_path)
|
||||
self._schedule_update('add', event.dest_path)
|
||||
|
||||
# Auto-remove from scheduled list after reasonable time
|
||||
self.loop.call_later(
|
||||
self.debounce_delay * 2,
|
||||
self.scheduled_files.discard,
|
||||
normalized_path
|
||||
)
|
||||
|
||||
# If source was a supported file, treat it as deleted
|
||||
if src_ext in self.file_extensions:
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
|
||||
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
||||
self.scheduled_files.discard(normalized_path)
|
||||
|
||||
logger.info(f"File moved/renamed from: {event.src_path}")
|
||||
self._schedule_update('remove', event.src_path)
|
||||
|
||||
def _schedule_update(self, action: str, file_path: str):
|
||||
"""Schedule a cache update"""
|
||||
with self.lock:
|
||||
# Use config method to map path
|
||||
mapped_path = config.map_path_to_link(file_path)
|
||||
normalized_path = mapped_path.replace(os.sep, '/')
|
||||
self.pending_changes.add((action, normalized_path))
|
||||
|
||||
self.loop.call_soon_threadsafe(self._create_update_task)
|
||||
|
||||
def _create_update_task(self):
|
||||
"""Create update task in the event loop"""
|
||||
if self.update_task is None or self.update_task.done():
|
||||
self.update_task = asyncio.create_task(self._process_changes())
|
||||
|
||||
async def _process_changes(self, delay: float = 2.0):
|
||||
"""Process pending changes with debouncing - should be implemented by subclasses"""
|
||||
raise NotImplementedError("Subclasses must implement _process_changes")
|
||||
|
||||
|
||||
class LoraFileHandler(BaseFileHandler):
|
||||
"""Handler for LoRA file system events"""
|
||||
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
||||
super().__init__(loop)
|
||||
# Set supported file extensions for LoRAs
|
||||
self.file_extensions = {'.safetensors'}
|
||||
|
||||
async def _process_changes(self, delay: float = 2.0):
|
||||
"""Process pending changes with debouncing"""
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
try:
|
||||
with self.lock:
|
||||
changes = self.pending_changes.copy()
|
||||
self.pending_changes.clear()
|
||||
|
||||
if not changes:
|
||||
return
|
||||
|
||||
logger.info(f"Processing {len(changes)} LoRA file changes")
|
||||
|
||||
# Get scanner through ServiceRegistry
|
||||
scanner = await ServiceRegistry.get_lora_scanner()
|
||||
cache = await scanner.get_cached_data()
|
||||
needs_resort = False
|
||||
new_folders = set()
|
||||
|
||||
for action, file_path in changes:
|
||||
try:
|
||||
if action == 'add':
|
||||
# Check if file already exists in cache
|
||||
existing = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
||||
if existing:
|
||||
logger.info(f"File {file_path} already in cache, skipping")
|
||||
continue
|
||||
|
||||
# Scan new file
|
||||
model_data = await scanner.scan_single_model(file_path)
|
||||
if model_data:
|
||||
# Update tags count
|
||||
for tag in model_data.get('tags', []):
|
||||
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
|
||||
|
||||
cache.raw_data.append(model_data)
|
||||
new_folders.add(model_data['folder'])
|
||||
# Update hash index
|
||||
if 'sha256' in model_data:
|
||||
scanner._hash_index.add_entry(
|
||||
model_data['sha256'],
|
||||
model_data['file_path']
|
||||
)
|
||||
needs_resort = True
|
||||
|
||||
elif action == 'remove':
|
||||
# Find the model to remove so we can update tags count
|
||||
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
||||
if model_to_remove:
|
||||
# Update tags count by reducing counts
|
||||
for tag in model_to_remove.get('tags', []):
|
||||
if tag in scanner._tags_count:
|
||||
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
|
||||
if scanner._tags_count[tag] == 0:
|
||||
del scanner._tags_count[tag]
|
||||
|
||||
# Remove from cache and hash index
|
||||
logger.info(f"Removing {file_path} from cache")
|
||||
scanner._hash_index.remove_by_path(file_path)
|
||||
cache.raw_data = [
|
||||
item for item in cache.raw_data
|
||||
if item['file_path'] != file_path
|
||||
]
|
||||
needs_resort = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {action} for {file_path}: {e}")
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort()
|
||||
|
||||
# Update folder list
|
||||
all_folders = set(cache.folders) | new_folders
|
||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in process_changes for LoRA: {e}")
|
||||
|
||||
|
||||
class CheckpointFileHandler(BaseFileHandler):
|
||||
"""Handler for checkpoint file system events"""
|
||||
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
||||
super().__init__(loop)
|
||||
# Set supported file extensions for checkpoints
|
||||
self.file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
|
||||
|
||||
async def _process_changes(self, delay: float = 2.0):
|
||||
"""Process pending changes with debouncing for checkpoint files"""
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
try:
|
||||
with self.lock:
|
||||
changes = self.pending_changes.copy()
|
||||
self.pending_changes.clear()
|
||||
|
||||
if not changes:
|
||||
return
|
||||
|
||||
logger.info(f"Processing {len(changes)} checkpoint file changes")
|
||||
|
||||
# Get scanner through ServiceRegistry
|
||||
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
cache = await scanner.get_cached_data()
|
||||
needs_resort = False
|
||||
new_folders = set()
|
||||
|
||||
for action, file_path in changes:
|
||||
try:
|
||||
if action == 'add':
|
||||
# Check if file already exists in cache
|
||||
existing = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
||||
if existing:
|
||||
logger.info(f"File {file_path} already in cache, skipping")
|
||||
continue
|
||||
|
||||
# Scan new file
|
||||
model_data = await scanner.scan_single_model(file_path)
|
||||
if model_data:
|
||||
# Update tags count if applicable
|
||||
for tag in model_data.get('tags', []):
|
||||
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
|
||||
|
||||
cache.raw_data.append(model_data)
|
||||
new_folders.add(model_data['folder'])
|
||||
# Update hash index
|
||||
if 'sha256' in model_data:
|
||||
scanner._hash_index.add_entry(
|
||||
model_data['sha256'],
|
||||
model_data['file_path']
|
||||
)
|
||||
needs_resort = True
|
||||
|
||||
elif action == 'remove':
|
||||
# Find the model to remove so we can update tags count
|
||||
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
||||
if model_to_remove:
|
||||
# Update tags count by reducing counts
|
||||
for tag in model_to_remove.get('tags', []):
|
||||
if tag in scanner._tags_count:
|
||||
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
|
||||
if scanner._tags_count[tag] == 0:
|
||||
del scanner._tags_count[tag]
|
||||
|
||||
# Remove from cache and hash index
|
||||
logger.info(f"Removing {file_path} from checkpoint cache")
|
||||
scanner._hash_index.remove_by_path(file_path)
|
||||
cache.raw_data = [
|
||||
item for item in cache.raw_data
|
||||
if item['file_path'] != file_path
|
||||
]
|
||||
needs_resort = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing checkpoint {action} for {file_path}: {e}")
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort()
|
||||
|
||||
# Update folder list
|
||||
all_folders = set(cache.folders) | new_folders
|
||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in process_changes for checkpoint: {e}")
|
||||
|
||||
|
||||
class BaseFileMonitor:
|
||||
"""Base class for file monitoring"""
|
||||
|
||||
def __init__(self, monitor_paths: List[str]):
|
||||
self.observer = Observer()
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.monitor_paths = set()
|
||||
|
||||
# Process monitor paths
|
||||
for path in monitor_paths:
|
||||
self.monitor_paths.add(os.path.realpath(path).replace(os.sep, '/'))
|
||||
|
||||
# Add mapped paths from config
|
||||
for target_path in config._path_mappings.keys():
|
||||
self.monitor_paths.add(target_path)
|
||||
|
||||
def start(self):
|
||||
"""Start file monitoring"""
|
||||
if not ENABLE_FILE_MONITORING:
|
||||
logger.debug("File monitoring is disabled via ENABLE_FILE_MONITORING setting")
|
||||
return
|
||||
|
||||
for path in self.monitor_paths:
|
||||
try:
|
||||
self.observer.schedule(self.handler, path, recursive=True)
|
||||
logger.info(f"Started monitoring: {path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error monitoring {path}: {e}")
|
||||
|
||||
self.observer.start()
|
||||
|
||||
def stop(self):
|
||||
"""Stop file monitoring"""
|
||||
if not ENABLE_FILE_MONITORING:
|
||||
return
|
||||
|
||||
self.observer.stop()
|
||||
self.observer.join()
|
||||
|
||||
def rescan_links(self):
|
||||
"""Rescan links when new ones are added"""
|
||||
if not ENABLE_FILE_MONITORING:
|
||||
return
|
||||
|
||||
# Find new paths not yet being monitored
|
||||
new_paths = set()
|
||||
for path in config._path_mappings.keys():
|
||||
real_path = os.path.realpath(path).replace(os.sep, '/')
|
||||
if real_path not in self.monitor_paths:
|
||||
new_paths.add(real_path)
|
||||
self.monitor_paths.add(real_path)
|
||||
|
||||
# Add new paths to monitoring
|
||||
for path in new_paths:
|
||||
try:
|
||||
self.observer.schedule(self.handler, path, recursive=True)
|
||||
logger.info(f"Added new monitoring path: {path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding new monitor for {path}: {e}")
|
||||
|
||||
|
||||
class LoraFileMonitor(BaseFileMonitor):
|
||||
"""Monitor for LoRA file changes"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls, monitor_paths=None):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, monitor_paths=None):
|
||||
if not hasattr(self, '_initialized'):
|
||||
if monitor_paths is None:
|
||||
from ..config import config
|
||||
monitor_paths = config.loras_roots
|
||||
|
||||
super().__init__(monitor_paths)
|
||||
self.handler = LoraFileHandler(self.loop)
|
||||
self._initialized = True
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
from ..config import config
|
||||
cls._instance = cls(config.loras_roots)
|
||||
return cls._instance
|
||||
|
||||
|
||||
class CheckpointFileMonitor(BaseFileMonitor):
|
||||
"""Monitor for checkpoint file changes"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls, monitor_paths=None):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, monitor_paths=None):
|
||||
if not hasattr(self, '_initialized'):
|
||||
if monitor_paths is None:
|
||||
# Get checkpoint roots from scanner
|
||||
monitor_paths = []
|
||||
# We'll initialize monitor paths later when scanner is available
|
||||
|
||||
super().__init__(monitor_paths or [])
|
||||
self.handler = CheckpointFileHandler(self.loop)
|
||||
self._initialized = True
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls([])
|
||||
|
||||
# Now get checkpoint roots from scanner
|
||||
from .checkpoint_scanner import CheckpointScanner
|
||||
scanner = await CheckpointScanner.get_instance()
|
||||
monitor_paths = scanner.get_model_roots()
|
||||
|
||||
# Update monitor paths - but don't actually monitor them
|
||||
for path in monitor_paths:
|
||||
real_path = os.path.realpath(path).replace(os.sep, '/')
|
||||
cls._instance.monitor_paths.add(real_path)
|
||||
|
||||
return cls._instance
|
||||
|
||||
def start(self):
|
||||
"""Override start to check global enable flag"""
|
||||
if not ENABLE_FILE_MONITORING:
|
||||
logger.debug("Checkpoint file monitoring is disabled via ENABLE_FILE_MONITORING setting")
|
||||
return
|
||||
|
||||
logger.debug("Checkpoint file monitoring is temporarily disabled")
|
||||
# Skip the actual monitoring setup
|
||||
pass
|
||||
|
||||
async def initialize_paths(self):
|
||||
"""Initialize monitor paths from scanner - currently disabled"""
|
||||
if not ENABLE_FILE_MONITORING:
|
||||
logger.debug("Checkpoint path initialization skipped (monitoring disabled)")
|
||||
return
|
||||
|
||||
logger.debug("Checkpoint file path initialization skipped (monitoring disabled)")
|
||||
pass
|
||||
@@ -1,65 +0,0 @@
|
||||
import asyncio
|
||||
from typing import List, Dict
|
||||
from dataclasses import dataclass
|
||||
from operator import itemgetter
|
||||
from natsort import natsorted
|
||||
|
||||
@dataclass
|
||||
class LoraCache:
|
||||
"""Cache structure for LoRA data"""
|
||||
raw_data: List[Dict]
|
||||
sorted_by_name: List[Dict]
|
||||
sorted_by_date: List[Dict]
|
||||
folders: List[str]
|
||||
|
||||
def __post_init__(self):
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def resort(self, name_only: bool = False):
|
||||
"""Resort all cached data views"""
|
||||
async with self._lock:
|
||||
self.sorted_by_name = natsorted(
|
||||
self.raw_data,
|
||||
key=lambda x: x['model_name'].lower() # Case-insensitive sort
|
||||
)
|
||||
if not name_only:
|
||||
self.sorted_by_date = sorted(
|
||||
self.raw_data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=True
|
||||
)
|
||||
# Update folder list
|
||||
all_folders = set(l['folder'] for l in self.raw_data)
|
||||
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
|
||||
"""Update preview_url for a specific lora in all cached data
|
||||
|
||||
Args:
|
||||
file_path: The file path of the lora to update
|
||||
preview_url: The new preview URL
|
||||
|
||||
Returns:
|
||||
bool: True if the update was successful, False if the lora wasn't found
|
||||
"""
|
||||
async with self._lock:
|
||||
# Update in raw_data
|
||||
for item in self.raw_data:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
else:
|
||||
return False # Lora not found
|
||||
|
||||
# Update in sorted lists (references to the same dict objects)
|
||||
for item in self.sorted_by_name:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
|
||||
for item in self.sorted_by_date:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
|
||||
return True
|
||||
@@ -1,20 +1,10 @@
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
import shutil
|
||||
import time
|
||||
import re
|
||||
from typing import List, Dict, Optional, Set
|
||||
from typing import List
|
||||
|
||||
from ..utils.models import LoraMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
|
||||
from .settings_manager import settings
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..utils.utils import fuzzy_match
|
||||
from .service_registry import ServiceRegistry
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,430 +12,21 @@ logger = logging.getLogger(__name__)
|
||||
class LoraScanner(ModelScanner):
|
||||
"""Service for scanning and managing LoRA files"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
# Ensure initialization happens only once
|
||||
if not hasattr(self, '_initialized'):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors'}
|
||||
|
||||
# Initialize parent class with ModelHashIndex
|
||||
super().__init__(
|
||||
model_type="lora",
|
||||
model_class=LoraMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
||||
)
|
||||
self._initialized = True
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors'}
|
||||
|
||||
# Initialize parent class with ModelHashIndex
|
||||
super().__init__(
|
||||
model_type="lora",
|
||||
model_class=LoraMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
||||
)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get lora root directories"""
|
||||
return config.loras_roots
|
||||
|
||||
async def scan_all_models(self) -> List[Dict]:
|
||||
"""Scan all LoRA directories and return metadata"""
|
||||
all_loras = []
|
||||
|
||||
# Create scan tasks for each directory
|
||||
scan_tasks = []
|
||||
for lora_root in self.get_model_roots():
|
||||
task = asyncio.create_task(self._scan_directory(lora_root))
|
||||
scan_tasks.append(task)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for task in scan_tasks:
|
||||
try:
|
||||
loras = await task
|
||||
all_loras.extend(loras)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning directory: {e}")
|
||||
|
||||
return all_loras
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Scan a single directory for LoRA files"""
|
||||
loras = []
|
||||
original_root = root_path # Save original root path
|
||||
|
||||
async def scan_recursive(path: str, visited_paths: set):
|
||||
"""Recursively scan directory, avoiding circular symlinks"""
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
logger.debug(f"Skipping already visited path: {path}")
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
|
||||
# Use original path instead of real path
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
await self._process_single_file(file_path, original_root, loras)
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
# For directories, continue scanning with original path
|
||||
await scan_recursive(entry.path, visited_paths)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
await scan_recursive(root_path, set())
|
||||
return loras
|
||||
|
||||
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
|
||||
"""Process a single file and add to results list"""
|
||||
try:
|
||||
result = await self._process_model_file(file_path, root_path)
|
||||
if result:
|
||||
loras.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
||||
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
||||
base_models: list = None, tags: list = None,
|
||||
search_options: dict = None, hash_filters: dict = None,
|
||||
favorites_only: bool = False, first_letter: str = None) -> Dict:
|
||||
"""Get paginated and filtered lora data
|
||||
|
||||
Args:
|
||||
page: Current page number (1-based)
|
||||
page_size: Number of items per page
|
||||
sort_by: Sort method ('name' or 'date')
|
||||
folder: Filter by folder path
|
||||
search: Search term
|
||||
fuzzy_search: Use fuzzy matching for search
|
||||
base_models: List of base models to filter by
|
||||
tags: List of tags to filter by
|
||||
search_options: Dictionary with search options (filename, modelname, tags, recursive)
|
||||
hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes)
|
||||
favorites_only: Filter for favorite models only
|
||||
first_letter: Filter by first letter of model name
|
||||
"""
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# Get default search options if not provided
|
||||
if search_options is None:
|
||||
search_options = {
|
||||
'filename': True,
|
||||
'modelname': True,
|
||||
'tags': False,
|
||||
'recursive': False,
|
||||
}
|
||||
|
||||
# Get the base data set
|
||||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
||||
|
||||
# Apply hash filtering if provided (highest priority)
|
||||
if hash_filters:
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower() # Ensure lowercase for matching
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora.get('sha256', '').lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
|
||||
# Jump to pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
# Apply SFW filtering if enabled
|
||||
if settings.get('show_only_sfw', False):
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||
]
|
||||
|
||||
# Apply favorites filtering if enabled
|
||||
if favorites_only:
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora.get('favorite', False) is True
|
||||
]
|
||||
|
||||
# Apply first letter filtering
|
||||
if first_letter:
|
||||
filtered_data = self._filter_by_first_letter(filtered_data, first_letter)
|
||||
|
||||
# Apply folder filtering
|
||||
if folder is not None:
|
||||
if search_options.get('recursive', False):
|
||||
# Recursive folder filtering - include all subfolders
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora['folder'].startswith(folder)
|
||||
]
|
||||
else:
|
||||
# Exact folder filtering
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora['folder'] == folder
|
||||
]
|
||||
|
||||
# Apply base model filtering
|
||||
if base_models and len(base_models) > 0:
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora.get('base_model') in base_models
|
||||
]
|
||||
|
||||
# Apply tag filtering
|
||||
if tags and len(tags) > 0:
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if any(tag in lora.get('tags', []) for tag in tags)
|
||||
]
|
||||
|
||||
# Apply search filtering
|
||||
if search:
|
||||
search_results = []
|
||||
search_opts = search_options or {}
|
||||
|
||||
for lora in filtered_data:
|
||||
# Search by file name
|
||||
if search_opts.get('filename', True):
|
||||
if fuzzy_match(lora.get('file_name', ''), search):
|
||||
search_results.append(lora)
|
||||
continue
|
||||
|
||||
# Search by model name
|
||||
if search_opts.get('modelname', True):
|
||||
if fuzzy_match(lora.get('model_name', ''), search):
|
||||
search_results.append(lora)
|
||||
continue
|
||||
|
||||
# Search by tags
|
||||
if search_opts.get('tags', False) and 'tags' in lora:
|
||||
if any(fuzzy_match(tag, search) for tag in lora['tags']):
|
||||
search_results.append(lora)
|
||||
continue
|
||||
|
||||
filtered_data = search_results
|
||||
|
||||
# Calculate pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def _filter_by_first_letter(self, data, letter):
|
||||
"""Filter data by first letter of model name
|
||||
|
||||
Special handling:
|
||||
- '#': Numbers (0-9)
|
||||
- '@': Special characters (not alphanumeric)
|
||||
- '漢': CJK characters
|
||||
"""
|
||||
filtered_data = []
|
||||
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if letter == '#' and first_char.isdigit():
|
||||
filtered_data.append(lora)
|
||||
elif letter == '@' and not first_char.isalnum():
|
||||
# Special characters (not alphanumeric)
|
||||
filtered_data.append(lora)
|
||||
elif letter == '漢' and self._is_cjk_character(first_char):
|
||||
# CJK characters
|
||||
filtered_data.append(lora)
|
||||
elif letter.upper() == first_char:
|
||||
# Regular alphabet matching
|
||||
filtered_data.append(lora)
|
||||
|
||||
return filtered_data
|
||||
|
||||
def _is_cjk_character(self, char):
|
||||
"""Check if character is a CJK character"""
|
||||
# Define Unicode ranges for CJK characters
|
||||
cjk_ranges = [
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
||||
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
||||
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
||||
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
||||
(0x3300, 0x33FF), # CJK Compatibility
|
||||
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
||||
(0x3100, 0x312F), # Bopomofo
|
||||
(0x31A0, 0x31BF), # Bopomofo Extended
|
||||
(0x3040, 0x309F), # Hiragana
|
||||
(0x30A0, 0x30FF), # Katakana
|
||||
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
||||
(0xAC00, 0xD7AF), # Hangul Syllables
|
||||
(0x1100, 0x11FF), # Hangul Jamo
|
||||
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
||||
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
||||
]
|
||||
|
||||
code_point = ord(char)
|
||||
return any(start <= code_point <= end for start, end in cjk_ranges)
|
||||
|
||||
async def get_letter_counts(self):
|
||||
"""Get count of models for each letter of the alphabet"""
|
||||
cache = await self.get_cached_data()
|
||||
data = cache.sorted_by_name
|
||||
|
||||
# Define letter categories
|
||||
letters = {
|
||||
'#': 0, # Numbers
|
||||
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
|
||||
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
|
||||
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
|
||||
'Y': 0, 'Z': 0,
|
||||
'@': 0, # Special characters
|
||||
'漢': 0 # CJK characters
|
||||
}
|
||||
|
||||
# Count models for each letter
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if first_char.isdigit():
|
||||
letters['#'] += 1
|
||||
elif first_char in letters:
|
||||
letters[first_char] += 1
|
||||
elif self._is_cjk_character(first_char):
|
||||
letters['漢'] += 1
|
||||
elif not first_char.isalnum():
|
||||
letters['@'] += 1
|
||||
|
||||
return letters
|
||||
|
||||
async def _update_metadata_paths(self, metadata_path: str, lora_path: str) -> Dict:
|
||||
"""Update file paths in metadata file"""
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
# Update file_path
|
||||
metadata['file_path'] = lora_path.replace(os.sep, '/')
|
||||
|
||||
# Update preview_url if exists
|
||||
if 'preview_url' in metadata:
|
||||
preview_dir = os.path.dirname(lora_path)
|
||||
preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
|
||||
preview_ext = os.path.splitext(metadata['preview_url'])[1]
|
||||
new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
|
||||
metadata['preview_url'] = new_preview_path.replace(os.sep, '/')
|
||||
|
||||
# Save updated metadata
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
|
||||
return metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating metadata paths: {e}", exc_info=True)
|
||||
|
||||
# Lora-specific hash index functionality
|
||||
def has_lora_hash(self, sha256: str) -> bool:
|
||||
"""Check if a LoRA with given hash exists"""
|
||||
return self.has_hash(sha256)
|
||||
|
||||
def get_lora_path_by_hash(self, sha256: str) -> Optional[str]:
|
||||
"""Get file path for a LoRA by its hash"""
|
||||
return self.get_path_by_hash(sha256)
|
||||
|
||||
def get_lora_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a LoRA by its file path"""
|
||||
return self.get_hash_by_path(file_path)
|
||||
|
||||
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||
"""Get top tags sorted by count"""
|
||||
# Make sure cache is initialized
|
||||
await self.get_cached_data()
|
||||
|
||||
# Sort tags by count in descending order
|
||||
sorted_tags = sorted(
|
||||
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
|
||||
key=lambda x: x['count'],
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# Return limited number
|
||||
return sorted_tags[:limit]
|
||||
|
||||
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||
"""Get base models used in loras sorted by frequency"""
|
||||
# Make sure cache is initialized
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# Count base model occurrences
|
||||
base_model_counts = {}
|
||||
for lora in cache.raw_data:
|
||||
if 'base_model' in lora and lora['base_model']:
|
||||
base_model = lora['base_model']
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
|
||||
# Sort base models by count
|
||||
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
|
||||
sorted_models.sort(key=lambda x: x['count'], reverse=True)
|
||||
|
||||
# Return limited number
|
||||
return sorted_models[:limit]
|
||||
|
||||
async def diagnose_hash_index(self):
|
||||
"""Diagnostic method to verify hash index functionality"""
|
||||
@@ -482,19 +63,3 @@ class LoraScanner(ModelScanner):
|
||||
test_hash_result = self._hash_index.get_hash(test_path)
|
||||
print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)
|
||||
|
||||
async def get_lora_info_by_name(self, name):
|
||||
"""Get LoRA information by name"""
|
||||
try:
|
||||
# Get cached data
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# Find the LoRA by name
|
||||
for lora in cache.raw_data:
|
||||
if lora.get("file_name") == name:
|
||||
return lora
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting LoRA info by name: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
181
py/services/lora_service.py
Normal file
181
py/services/lora_service.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import LoraMetadata
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraService(BaseModelService):
|
||||
"""LoRA-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize LoRA service
|
||||
|
||||
Args:
|
||||
scanner: LoRA scanner instance
|
||||
"""
|
||||
super().__init__("lora", scanner, LoraMetadata)
|
||||
|
||||
async def format_response(self, lora_data: Dict) -> Dict:
|
||||
"""Format LoRA data for API response"""
|
||||
return {
|
||||
"model_name": lora_data["model_name"],
|
||||
"file_name": lora_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
|
||||
"base_model": lora_data.get("base_model", ""),
|
||||
"folder": lora_data["folder"],
|
||||
"sha256": lora_data.get("sha256", ""),
|
||||
"file_path": lora_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": lora_data.get("size", 0),
|
||||
"modified": lora_data.get("modified", ""),
|
||||
"tags": lora_data.get("tags", []),
|
||||
"from_civitai": lora_data.get("from_civitai", True),
|
||||
"usage_tips": lora_data.get("usage_tips", ""),
|
||||
"notes": lora_data.get("notes", ""),
|
||||
"favorite": lora_data.get("favorite", False),
|
||||
"civitai": self.filter_civitai_data(lora_data.get("civitai", {}), minimal=True)
|
||||
}
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
"""Apply LoRA-specific filters"""
|
||||
# Handle first_letter filter for LoRAs
|
||||
first_letter = kwargs.get('first_letter')
|
||||
if first_letter:
|
||||
data = self._filter_by_first_letter(data, first_letter)
|
||||
|
||||
return data
|
||||
|
||||
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
|
||||
"""Filter data by first letter of model name
|
||||
|
||||
Special handling:
|
||||
- '#': Numbers (0-9)
|
||||
- '@': Special characters (not alphanumeric)
|
||||
- '漢': CJK characters
|
||||
"""
|
||||
filtered_data = []
|
||||
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if letter == '#' and first_char.isdigit():
|
||||
filtered_data.append(lora)
|
||||
elif letter == '@' and not first_char.isalnum():
|
||||
# Special characters (not alphanumeric)
|
||||
filtered_data.append(lora)
|
||||
elif letter == '漢' and self._is_cjk_character(first_char):
|
||||
# CJK characters
|
||||
filtered_data.append(lora)
|
||||
elif letter.upper() == first_char:
|
||||
# Regular alphabet matching
|
||||
filtered_data.append(lora)
|
||||
|
||||
return filtered_data
|
||||
|
||||
def _is_cjk_character(self, char: str) -> bool:
|
||||
"""Check if character is a CJK character"""
|
||||
# Define Unicode ranges for CJK characters
|
||||
cjk_ranges = [
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
||||
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
||||
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
||||
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
||||
(0x3300, 0x33FF), # CJK Compatibility
|
||||
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
||||
(0x3100, 0x312F), # Bopomofo
|
||||
(0x31A0, 0x31BF), # Bopomofo Extended
|
||||
(0x3040, 0x309F), # Hiragana
|
||||
(0x30A0, 0x30FF), # Katakana
|
||||
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
||||
(0xAC00, 0xD7AF), # Hangul Syllables
|
||||
(0x1100, 0x11FF), # Hangul Jamo
|
||||
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
||||
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
||||
]
|
||||
|
||||
code_point = ord(char)
|
||||
return any(start <= code_point <= end for start, end in cjk_ranges)
|
||||
|
||||
# LoRA-specific methods
|
||||
async def get_letter_counts(self) -> Dict[str, int]:
|
||||
"""Get count of LoRAs for each letter of the alphabet"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
data = cache.raw_data
|
||||
|
||||
# Define letter categories
|
||||
letters = {
|
||||
'#': 0, # Numbers
|
||||
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
|
||||
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
|
||||
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
|
||||
'Y': 0, 'Z': 0,
|
||||
'@': 0, # Special characters
|
||||
'漢': 0 # CJK characters
|
||||
}
|
||||
|
||||
# Count models for each letter
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if first_char.isdigit():
|
||||
letters['#'] += 1
|
||||
elif first_char in letters:
|
||||
letters[first_char] += 1
|
||||
elif self._is_cjk_character(first_char):
|
||||
letters['漢'] += 1
|
||||
elif not first_char.isalnum():
|
||||
letters['@'] += 1
|
||||
|
||||
return letters
|
||||
|
||||
async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
|
||||
"""Get trigger words for a specific LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
civitai_data = lora.get('civitai', {})
|
||||
return civitai_data.get('trainedWords', [])
|
||||
|
||||
return []
|
||||
|
||||
async def get_lora_usage_tips_by_relative_path(self, relative_path: str) -> Optional[str]:
|
||||
"""Get usage tips for a LoRA by its relative path"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
file_path = lora.get('file_path', '')
|
||||
if file_path:
|
||||
# Convert to forward slashes and extract relative path
|
||||
file_path_normalized = file_path.replace('\\', '/')
|
||||
relative_path = relative_path.replace('\\', '/')
|
||||
# Find the relative path part by looking for the relative_path in the full path
|
||||
if file_path_normalized.endswith(relative_path) or relative_path in file_path_normalized:
|
||||
return lora.get('usage_tips', '')
|
||||
|
||||
return None
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find LoRAs with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find LoRAs with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
151
py/services/metadata_archive_manager.py
Normal file
151
py/services/metadata_archive_manager.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import zipfile
|
||||
import logging
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from .downloader import get_downloader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MetadataArchiveManager:
|
||||
"""Manages downloading and extracting Civitai metadata archive database"""
|
||||
|
||||
DOWNLOAD_URLS = [
|
||||
"https://github.com/willmiao/civitai-metadata-archive-db/releases/download/db-2025-08-08/civitai.zip",
|
||||
"https://huggingface.co/datasets/willmiao/civitai-metadata-archive-db/blob/main/civitai.zip"
|
||||
]
|
||||
|
||||
def __init__(self, base_path: str):
|
||||
"""Initialize with base path where files will be stored"""
|
||||
self.base_path = Path(base_path)
|
||||
self.civitai_folder = self.base_path / "civitai"
|
||||
self.archive_path = self.base_path / "civitai.zip"
|
||||
self.db_path = self.civitai_folder / "civitai.sqlite"
|
||||
|
||||
def is_database_available(self) -> bool:
|
||||
"""Check if the SQLite database is available and valid"""
|
||||
return self.db_path.exists() and self.db_path.stat().st_size > 0
|
||||
|
||||
def get_database_path(self) -> Optional[str]:
|
||||
"""Get the path to the SQLite database if available"""
|
||||
if self.is_database_available():
|
||||
return str(self.db_path)
|
||||
return None
|
||||
|
||||
async def download_and_extract_database(self, progress_callback=None) -> bool:
|
||||
"""Download and extract the metadata archive database
|
||||
|
||||
Args:
|
||||
progress_callback: Optional callback function to report progress
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Create directories if they don't exist
|
||||
self.base_path.mkdir(parents=True, exist_ok=True)
|
||||
self.civitai_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Download the archive
|
||||
if not await self._download_archive(progress_callback):
|
||||
return False
|
||||
|
||||
# Extract the archive
|
||||
if not await self._extract_archive(progress_callback):
|
||||
return False
|
||||
|
||||
# Clean up the archive file
|
||||
if self.archive_path.exists():
|
||||
self.archive_path.unlink()
|
||||
|
||||
logger.info(f"Successfully downloaded and extracted metadata database to {self.db_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading and extracting metadata database: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def _download_archive(self, progress_callback=None) -> bool:
|
||||
"""Download the zip archive from one of the available URLs"""
|
||||
downloader = await get_downloader()
|
||||
|
||||
for url in self.DOWNLOAD_URLS:
|
||||
try:
|
||||
logger.info(f"Attempting to download from {url}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback("download", f"Downloading from {url}")
|
||||
|
||||
# Custom progress callback to report download progress
|
||||
async def download_progress(progress):
|
||||
if progress_callback:
|
||||
progress_callback("download", f"Downloading archive... {progress:.1f}%")
|
||||
|
||||
success, result = await downloader.download_file(
|
||||
url=url,
|
||||
save_path=str(self.archive_path),
|
||||
progress_callback=download_progress,
|
||||
use_auth=False, # Public download, no auth needed
|
||||
allow_resume=True
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Successfully downloaded archive from {url}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Failed to download from {url}: {result}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error downloading from {url}: {e}")
|
||||
continue
|
||||
|
||||
logger.error("Failed to download archive from any URL")
|
||||
return False
|
||||
|
||||
async def _extract_archive(self, progress_callback=None) -> bool:
|
||||
"""Extract the zip archive to the civitai folder"""
|
||||
try:
|
||||
if progress_callback:
|
||||
progress_callback("extract", "Extracting archive...")
|
||||
|
||||
# Run extraction in thread pool to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, self._extract_zip_sync)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback("extract", "Extraction completed")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting archive: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def _extract_zip_sync(self):
|
||||
"""Synchronous zip extraction (runs in thread pool)"""
|
||||
with zipfile.ZipFile(self.archive_path, 'r') as archive:
|
||||
archive.extractall(path=self.base_path)
|
||||
|
||||
async def remove_database(self) -> bool:
|
||||
"""Remove the metadata database and folder"""
|
||||
try:
|
||||
if self.civitai_folder.exists():
|
||||
# Remove all files in the civitai folder
|
||||
for file_path in self.civitai_folder.iterdir():
|
||||
if file_path.is_file():
|
||||
file_path.unlink()
|
||||
|
||||
# Remove the folder itself
|
||||
self.civitai_folder.rmdir()
|
||||
|
||||
# Also remove the archive file if it exists
|
||||
if self.archive_path.exists():
|
||||
self.archive_path.unlink()
|
||||
|
||||
logger.info("Successfully removed metadata database")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing metadata database: {e}", exc_info=True)
|
||||
return False
|
||||
117
py/services/metadata_service.py
Normal file
117
py/services/metadata_service.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import os
|
||||
import logging
|
||||
from .model_metadata_provider import (
|
||||
ModelMetadataProviderManager,
|
||||
SQLiteModelMetadataProvider,
|
||||
CivitaiModelMetadataProvider,
|
||||
FallbackMetadataProvider
|
||||
)
|
||||
from .settings_manager import settings
|
||||
from .metadata_archive_manager import MetadataArchiveManager
|
||||
from .service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def initialize_metadata_providers():
|
||||
"""Initialize and configure all metadata providers based on settings"""
|
||||
provider_manager = await ModelMetadataProviderManager.get_instance()
|
||||
|
||||
# Clear existing providers to allow reinitialization
|
||||
provider_manager.providers.clear()
|
||||
provider_manager.default_provider = None
|
||||
|
||||
# Get settings
|
||||
enable_archive_db = settings.get('enable_metadata_archive_db', False)
|
||||
|
||||
providers = []
|
||||
|
||||
# Initialize archive database provider if enabled
|
||||
if enable_archive_db:
|
||||
try:
|
||||
# Initialize archive manager
|
||||
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
archive_manager = MetadataArchiveManager(base_path)
|
||||
|
||||
db_path = archive_manager.get_database_path()
|
||||
if db_path and os.path.exists(db_path):
|
||||
sqlite_provider = SQLiteModelMetadataProvider(db_path)
|
||||
provider_manager.register_provider('sqlite', sqlite_provider)
|
||||
providers.append(('sqlite', sqlite_provider))
|
||||
logger.info(f"SQLite metadata provider registered with database: {db_path}")
|
||||
else:
|
||||
logger.warning("Metadata archive database is enabled but database file not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize SQLite metadata provider: {e}")
|
||||
|
||||
# Initialize Civitai API provider (always available as fallback)
|
||||
try:
|
||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
civitai_provider = CivitaiModelMetadataProvider(civitai_client)
|
||||
provider_manager.register_provider('civitai_api', civitai_provider)
|
||||
providers.append(('civitai_api', civitai_provider))
|
||||
logger.debug("Civitai API metadata provider registered")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Civitai API metadata provider: {e}")
|
||||
|
||||
# Register CivArchive provider, but do NOT add to fallback providers
|
||||
try:
|
||||
from .model_metadata_provider import CivArchiveModelMetadataProvider
|
||||
civarchive_provider = CivArchiveModelMetadataProvider()
|
||||
provider_manager.register_provider('civarchive', civarchive_provider)
|
||||
logger.debug("CivArchive metadata provider registered (not included in fallback)")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize CivArchive metadata provider: {e}")
|
||||
|
||||
# Set up fallback provider based on available providers
|
||||
if len(providers) > 1:
|
||||
# Always use Civitai API first, then Archive DB
|
||||
ordered_providers = []
|
||||
ordered_providers.extend([p[1] for p in providers if p[0] == 'civitai_api'])
|
||||
ordered_providers.extend([p[1] for p in providers if p[0] == 'sqlite'])
|
||||
|
||||
if ordered_providers:
|
||||
fallback_provider = FallbackMetadataProvider(ordered_providers)
|
||||
provider_manager.register_provider('fallback', fallback_provider, is_default=True)
|
||||
logger.info(f"Fallback metadata provider registered with {len(ordered_providers)} providers, Civitai API first")
|
||||
elif len(providers) == 1:
|
||||
# Only one provider available, set it as default
|
||||
provider_name, provider = providers[0]
|
||||
provider_manager.register_provider(provider_name, provider, is_default=True)
|
||||
logger.debug(f"Single metadata provider registered as default: {provider_name}")
|
||||
else:
|
||||
logger.warning("No metadata providers available - this may cause metadata lookup failures")
|
||||
|
||||
return provider_manager
|
||||
|
||||
async def update_metadata_providers():
|
||||
"""Update metadata providers based on current settings"""
|
||||
try:
|
||||
# Get current settings
|
||||
enable_archive_db = settings.get('enable_metadata_archive_db', False)
|
||||
|
||||
# Reinitialize all providers with new settings
|
||||
provider_manager = await initialize_metadata_providers()
|
||||
|
||||
logger.info(f"Updated metadata providers, archive_db enabled: {enable_archive_db}")
|
||||
return provider_manager
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update metadata providers: {e}")
|
||||
return await ModelMetadataProviderManager.get_instance()
|
||||
|
||||
async def get_metadata_archive_manager():
|
||||
"""Get metadata archive manager instance"""
|
||||
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
return MetadataArchiveManager(base_path)
|
||||
|
||||
async def get_metadata_provider(provider_name: str = None):
|
||||
"""Get a specific metadata provider or default provider"""
|
||||
provider_manager = await ModelMetadataProviderManager.get_instance()
|
||||
|
||||
if provider_name:
|
||||
return provider_manager._get_provider(provider_name)
|
||||
|
||||
return provider_manager._get_provider()
|
||||
|
||||
async def get_default_metadata_provider():
|
||||
"""Get the default metadata provider (fallback or single provider)"""
|
||||
return await get_metadata_provider()
|
||||
355
py/services/metadata_sync_service.py
Normal file
355
py/services/metadata_sync_service.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""Services for synchronising metadata with remote providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
|
||||
|
||||
from ..services.settings_manager import SettingsManager
|
||||
from ..utils.model_utils import determine_base_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetadataProviderProtocol:
|
||||
"""Subset of metadata provider interface consumed by the sync service."""
|
||||
|
||||
async def get_model_by_hash(self, sha256: str) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||
...
|
||||
|
||||
async def get_model_version(
|
||||
self, model_id: int, model_version_id: Optional[int]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
|
||||
class MetadataSyncService:
|
||||
"""High level orchestration for metadata synchronisation flows."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
metadata_manager,
|
||||
preview_service,
|
||||
settings: SettingsManager,
|
||||
default_metadata_provider_factory: Callable[[], Awaitable[MetadataProviderProtocol]],
|
||||
metadata_provider_selector: Callable[[str], Awaitable[MetadataProviderProtocol]],
|
||||
) -> None:
|
||||
self._metadata_manager = metadata_manager
|
||||
self._preview_service = preview_service
|
||||
self._settings = settings
|
||||
self._get_default_provider = default_metadata_provider_factory
|
||||
self._get_provider = metadata_provider_selector
|
||||
|
||||
async def load_local_metadata(self, metadata_path: str) -> Dict[str, Any]:
|
||||
"""Load metadata JSON from disk, returning an empty structure when missing."""
|
||||
|
||||
if not os.path.exists(metadata_path):
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(metadata_path, "r", encoding="utf-8") as handle:
|
||||
return json.load(handle)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Error loading metadata from %s: %s", metadata_path, exc)
|
||||
return {}
|
||||
|
||||
async def mark_not_found_on_civitai(
|
||||
self, metadata_path: str, local_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Persist the not-found flag for a metadata payload."""
|
||||
|
||||
local_metadata["from_civitai"] = False
|
||||
await self._metadata_manager.save_metadata(metadata_path, local_metadata)
|
||||
|
||||
@staticmethod
|
||||
def is_civitai_api_metadata(meta: Dict[str, Any]) -> bool:
|
||||
"""Determine if the metadata originated from the CivitAI public API."""
|
||||
|
||||
if not isinstance(meta, dict):
|
||||
return False
|
||||
files = meta.get("files")
|
||||
images = meta.get("images")
|
||||
source = meta.get("source")
|
||||
return bool(files) and bool(images) and source != "archive_db"
|
||||
|
||||
async def update_model_metadata(
|
||||
self,
|
||||
metadata_path: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
civitai_metadata: Dict[str, Any],
|
||||
metadata_provider: Optional[MetadataProviderProtocol] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Merge remote metadata into the local record and persist the result."""
|
||||
|
||||
existing_civitai = local_metadata.get("civitai") or {}
|
||||
|
||||
if (
|
||||
civitai_metadata.get("source") == "archive_db"
|
||||
and self.is_civitai_api_metadata(existing_civitai)
|
||||
):
|
||||
logger.info(
|
||||
"Skip civitai update for %s (%s)",
|
||||
local_metadata.get("model_name", ""),
|
||||
existing_civitai.get("name", ""),
|
||||
)
|
||||
else:
|
||||
merged_civitai = existing_civitai.copy()
|
||||
merged_civitai.update(civitai_metadata)
|
||||
|
||||
if civitai_metadata.get("source") == "archive_db":
|
||||
model_name = civitai_metadata.get("model", {}).get("name", "")
|
||||
version_name = civitai_metadata.get("name", "")
|
||||
logger.info(
|
||||
"Recovered metadata from archive_db for deleted model: %s (%s)",
|
||||
model_name,
|
||||
version_name,
|
||||
)
|
||||
|
||||
if "trainedWords" in existing_civitai:
|
||||
existing_trained = existing_civitai.get("trainedWords", [])
|
||||
new_trained = civitai_metadata.get("trainedWords", [])
|
||||
merged_trained = list(set(existing_trained + new_trained))
|
||||
merged_civitai["trainedWords"] = merged_trained
|
||||
|
||||
local_metadata["civitai"] = merged_civitai
|
||||
|
||||
if "model" in civitai_metadata and civitai_metadata["model"]:
|
||||
model_data = civitai_metadata["model"]
|
||||
|
||||
if model_data.get("name"):
|
||||
local_metadata["model_name"] = model_data["name"]
|
||||
|
||||
if not local_metadata.get("modelDescription") and model_data.get("description"):
|
||||
local_metadata["modelDescription"] = model_data["description"]
|
||||
|
||||
if not local_metadata.get("tags") and model_data.get("tags"):
|
||||
local_metadata["tags"] = model_data["tags"]
|
||||
|
||||
if model_data.get("creator") and not local_metadata.get("civitai", {}).get(
|
||||
"creator"
|
||||
):
|
||||
local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"]
|
||||
|
||||
local_metadata["base_model"] = determine_base_model(
|
||||
civitai_metadata.get("baseModel")
|
||||
)
|
||||
|
||||
await self._preview_service.ensure_preview_for_metadata(
|
||||
metadata_path, local_metadata, civitai_metadata.get("images", [])
|
||||
)
|
||||
|
||||
await self._metadata_manager.save_metadata(metadata_path, local_metadata)
|
||||
return local_metadata
|
||||
|
||||
async def fetch_and_update_model(
|
||||
self,
|
||||
*,
|
||||
sha256: str,
|
||||
file_path: str,
|
||||
model_data: Dict[str, Any],
|
||||
update_cache_func: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
"""Fetch metadata for a model and update both disk and cache state."""
|
||||
|
||||
if not isinstance(model_data, dict):
|
||||
error = f"Invalid model_data type: {type(model_data)}"
|
||||
logger.error(error)
|
||||
return False, error
|
||||
|
||||
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||
enable_archive = self._settings.get("enable_metadata_archive_db", False)
|
||||
|
||||
try:
|
||||
if model_data.get("civitai_deleted") is True:
|
||||
if not enable_archive or model_data.get("db_checked") is True:
|
||||
return (
|
||||
False,
|
||||
"CivitAI model is deleted and metadata archive DB is not enabled",
|
||||
)
|
||||
metadata_provider = await self._get_provider("sqlite")
|
||||
else:
|
||||
metadata_provider = await self._get_default_provider()
|
||||
|
||||
civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256)
|
||||
if not civitai_metadata:
|
||||
if error == "Model not found":
|
||||
model_data["from_civitai"] = False
|
||||
model_data["civitai_deleted"] = True
|
||||
model_data["db_checked"] = enable_archive
|
||||
model_data["last_checked_at"] = datetime.now().timestamp()
|
||||
|
||||
data_to_save = model_data.copy()
|
||||
data_to_save.pop("folder", None)
|
||||
await self._metadata_manager.save_metadata(file_path, data_to_save)
|
||||
|
||||
error_msg = (
|
||||
f"Error fetching metadata: {error} (model_name={model_data.get('model_name', '')})"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
model_data["from_civitai"] = True
|
||||
model_data["civitai_deleted"] = civitai_metadata.get("source") == "archive_db"
|
||||
model_data["db_checked"] = enable_archive
|
||||
model_data["last_checked_at"] = datetime.now().timestamp()
|
||||
|
||||
local_metadata = model_data.copy()
|
||||
local_metadata.pop("folder", None)
|
||||
|
||||
await self.update_model_metadata(
|
||||
metadata_path,
|
||||
local_metadata,
|
||||
civitai_metadata,
|
||||
metadata_provider,
|
||||
)
|
||||
|
||||
update_payload = {
|
||||
"model_name": local_metadata.get("model_name"),
|
||||
"preview_url": local_metadata.get("preview_url"),
|
||||
"civitai": local_metadata.get("civitai"),
|
||||
}
|
||||
model_data.update(update_payload)
|
||||
|
||||
await update_cache_func(file_path, file_path, local_metadata)
|
||||
return True, None
|
||||
except KeyError as exc:
|
||||
error_msg = f"Error fetching metadata - Missing key: {exc} in model_data={model_data}"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
except Exception as exc: # pragma: no cover - error path
|
||||
error_msg = f"Error fetching metadata: {exc}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return False, error_msg
|
||||
|
||||
async def fetch_metadata_by_sha(
|
||||
self, sha256: str, metadata_provider: Optional[MetadataProviderProtocol] = None
|
||||
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||
"""Fetch metadata for a SHA256 hash from the configured provider."""
|
||||
|
||||
provider = metadata_provider or await self._get_default_provider()
|
||||
return await provider.get_model_by_hash(sha256)
|
||||
|
||||
async def relink_metadata(
|
||||
self,
|
||||
*,
|
||||
file_path: str,
|
||||
metadata: Dict[str, Any],
|
||||
model_id: int,
|
||||
model_version_id: Optional[int],
|
||||
) -> Dict[str, Any]:
|
||||
"""Relink a local metadata record to a specific CivitAI model version."""
|
||||
|
||||
provider = await self._get_default_provider()
|
||||
civitai_metadata = await provider.get_model_version(model_id, model_version_id)
|
||||
if not civitai_metadata:
|
||||
raise ValueError(
|
||||
f"Model version not found on CivitAI for ID: {model_id}"
|
||||
+ (f" with version: {model_version_id}" if model_version_id else "")
|
||||
)
|
||||
|
||||
primary_model_file: Optional[Dict[str, Any]] = None
|
||||
for file_info in civitai_metadata.get("files", []):
|
||||
if file_info.get("primary", False) and file_info.get("type") == "Model":
|
||||
primary_model_file = file_info
|
||||
break
|
||||
|
||||
if primary_model_file and primary_model_file.get("hashes", {}).get("SHA256"):
|
||||
metadata["sha256"] = primary_model_file["hashes"]["SHA256"].lower()
|
||||
|
||||
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||
await self.update_model_metadata(
|
||||
metadata_path,
|
||||
metadata,
|
||||
civitai_metadata,
|
||||
provider,
|
||||
)
|
||||
|
||||
return metadata
|
||||
|
||||
async def save_metadata_updates(
|
||||
self,
|
||||
*,
|
||||
file_path: str,
|
||||
updates: Dict[str, Any],
|
||||
metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]],
|
||||
update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply metadata updates and persist to disk and cache."""
|
||||
|
||||
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||
metadata = await metadata_loader(metadata_path)
|
||||
|
||||
for key, value in updates.items():
|
||||
if isinstance(value, dict) and isinstance(metadata.get(key), dict):
|
||||
metadata[key].update(value)
|
||||
else:
|
||||
metadata[key] = value
|
||||
|
||||
await self._metadata_manager.save_metadata(file_path, metadata)
|
||||
await update_cache(file_path, file_path, metadata)
|
||||
|
||||
if "model_name" in updates:
|
||||
logger.debug("Metadata update touched model_name; cache resort required")
|
||||
|
||||
return metadata
|
||||
|
||||
async def verify_duplicate_hashes(
|
||||
self,
|
||||
*,
|
||||
file_paths: Iterable[str],
|
||||
metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]],
|
||||
hash_calculator: Callable[[str], Awaitable[str]],
|
||||
update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
|
||||
) -> Dict[str, Any]:
|
||||
"""Verify a collection of files share the same SHA256 hash."""
|
||||
|
||||
file_paths = list(file_paths)
|
||||
if not file_paths:
|
||||
raise ValueError("No file paths provided for verification")
|
||||
|
||||
results = {
|
||||
"verified_as_duplicates": True,
|
||||
"mismatched_files": [],
|
||||
"new_hash_map": {},
|
||||
}
|
||||
|
||||
expected_hash: Optional[str] = None
|
||||
first_metadata_path = os.path.splitext(file_paths[0])[0] + ".metadata.json"
|
||||
first_metadata = await metadata_loader(first_metadata_path)
|
||||
if first_metadata and "sha256" in first_metadata:
|
||||
expected_hash = first_metadata["sha256"].lower()
|
||||
|
||||
for path in file_paths:
|
||||
if not os.path.exists(path):
|
||||
continue
|
||||
|
||||
try:
|
||||
actual_hash = await hash_calculator(path)
|
||||
metadata_path = os.path.splitext(path)[0] + ".metadata.json"
|
||||
metadata = await metadata_loader(metadata_path)
|
||||
stored_hash = metadata.get("sha256", "").lower()
|
||||
|
||||
if not expected_hash:
|
||||
expected_hash = stored_hash
|
||||
|
||||
if actual_hash != expected_hash:
|
||||
results["verified_as_duplicates"] = False
|
||||
results["mismatched_files"].append(path)
|
||||
results["new_hash_map"][path] = actual_hash
|
||||
|
||||
if actual_hash != stored_hash:
|
||||
metadata["sha256"] = actual_hash
|
||||
await self._metadata_manager.save_metadata(path, metadata)
|
||||
await update_cache(path, path, metadata)
|
||||
except Exception as exc: # pragma: no cover - defensive path
|
||||
logger.error("Error verifying hash for %s: %s", path, exc)
|
||||
results["mismatched_files"].append(path)
|
||||
results["new_hash_map"][path] = "error_calculating_hash"
|
||||
results["verified_as_duplicates"] = False
|
||||
|
||||
return results
|
||||
|
||||
@@ -1,43 +1,92 @@
|
||||
import asyncio
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Tuple
|
||||
from dataclasses import dataclass
|
||||
from operator import itemgetter
|
||||
from natsort import natsorted
|
||||
|
||||
# Supported sort modes: (sort_key, order)
|
||||
# order: 'asc' for ascending, 'desc' for descending
|
||||
SUPPORTED_SORT_MODES = [
|
||||
('name', 'asc'),
|
||||
('name', 'desc'),
|
||||
('date', 'asc'),
|
||||
('date', 'desc'),
|
||||
('size', 'asc'),
|
||||
('size', 'desc'),
|
||||
]
|
||||
|
||||
@dataclass
|
||||
class ModelCache:
|
||||
"""Cache structure for model data"""
|
||||
"""Cache structure for model data with extensible sorting"""
|
||||
raw_data: List[Dict]
|
||||
sorted_by_name: List[Dict]
|
||||
sorted_by_date: List[Dict]
|
||||
folders: List[str]
|
||||
|
||||
def __post_init__(self):
|
||||
self._lock = asyncio.Lock()
|
||||
# Cache for last sort: (sort_key, order) -> sorted list
|
||||
self._last_sort: Tuple[str, str] = (None, None)
|
||||
self._last_sorted_data: List[Dict] = []
|
||||
# Default sort on init
|
||||
asyncio.create_task(self.resort())
|
||||
|
||||
async def resort(self, name_only: bool = False):
|
||||
"""Resort all cached data views"""
|
||||
async def resort(self):
|
||||
"""Resort cached data according to last sort mode if set"""
|
||||
async with self._lock:
|
||||
self.sorted_by_name = natsorted(
|
||||
self.raw_data,
|
||||
key=lambda x: x['model_name'].lower() # Case-insensitive sort
|
||||
)
|
||||
if not name_only:
|
||||
self.sorted_by_date = sorted(
|
||||
self.raw_data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=True
|
||||
)
|
||||
# Update folder list
|
||||
if self._last_sort != (None, None):
|
||||
sort_key, order = self._last_sort
|
||||
sorted_data = self._sort_data(self.raw_data, sort_key, order)
|
||||
self._last_sorted_data = sorted_data
|
||||
# Update folder list
|
||||
# else: do nothing
|
||||
|
||||
all_folders = set(l['folder'] for l in self.raw_data)
|
||||
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
|
||||
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
|
||||
"""Sort data by sort_key and order"""
|
||||
reverse = (order == 'desc')
|
||||
if sort_key == 'name':
|
||||
# Natural sort by model_name, case-insensitive
|
||||
return natsorted(
|
||||
data,
|
||||
key=lambda x: x['model_name'].lower(),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'date':
|
||||
# Sort by modified timestamp
|
||||
return sorted(
|
||||
data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'size':
|
||||
# Sort by file size
|
||||
return sorted(
|
||||
data,
|
||||
key=itemgetter('size'),
|
||||
reverse=reverse
|
||||
)
|
||||
else:
|
||||
# Fallback: no sort
|
||||
return list(data)
|
||||
|
||||
async def get_sorted_data(self, sort_key: str = 'name', order: str = 'asc') -> List[Dict]:
|
||||
"""Get sorted data by sort_key and order, using cache if possible"""
|
||||
async with self._lock:
|
||||
if (sort_key, order) == self._last_sort:
|
||||
return self._last_sorted_data
|
||||
sorted_data = self._sort_data(self.raw_data, sort_key, order)
|
||||
self._last_sort = (sort_key, order)
|
||||
self._last_sorted_data = sorted_data
|
||||
return sorted_data
|
||||
|
||||
async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool:
|
||||
"""Update preview_url for a specific model in all cached data
|
||||
|
||||
Args:
|
||||
file_path: The file path of the model to update
|
||||
preview_url: The new preview URL
|
||||
preview_nsfw_level: The NSFW level of the preview
|
||||
|
||||
Returns:
|
||||
bool: True if the update was successful, False if the model wasn't found
|
||||
@@ -47,19 +96,9 @@ class ModelCache:
|
||||
for item in self.raw_data:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
item['preview_nsfw_level'] = preview_nsfw_level
|
||||
break
|
||||
else:
|
||||
return False # Model not found
|
||||
|
||||
# Update in sorted lists (references to the same dict objects)
|
||||
for item in self.sorted_by_name:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
|
||||
for item in self.sorted_by_date:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
|
||||
return True
|
||||
463
py/services/model_file_service.py
Normal file
463
py/services/model_file_service.py
Normal file
@@ -0,0 +1,463 @@
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
from typing import List, Dict, Optional, Any, Set
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ..utils.utils import calculate_relative_path_for_model, remove_empty_dirs
|
||||
from ..utils.constants import AUTO_ORGANIZE_BATCH_SIZE
|
||||
from ..services.settings_manager import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProgressCallback(ABC):
|
||||
"""Abstract callback interface for progress reporting"""
|
||||
|
||||
@abstractmethod
|
||||
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
|
||||
"""Called when progress is updated"""
|
||||
pass
|
||||
|
||||
|
||||
class AutoOrganizeResult:
|
||||
"""Result object for auto-organize operations"""
|
||||
|
||||
def __init__(self):
|
||||
self.total: int = 0
|
||||
self.processed: int = 0
|
||||
self.success_count: int = 0
|
||||
self.failure_count: int = 0
|
||||
self.skipped_count: int = 0
|
||||
self.operation_type: str = 'unknown'
|
||||
self.cleanup_counts: Dict[str, int] = {}
|
||||
self.results: List[Dict[str, Any]] = []
|
||||
self.results_truncated: bool = False
|
||||
self.sample_results: List[Dict[str, Any]] = []
|
||||
self.is_flat_structure: bool = False
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert result to dictionary"""
|
||||
result = {
|
||||
'success': True,
|
||||
'message': f'Auto-organize {self.operation_type} completed: {self.success_count} moved, {self.skipped_count} skipped, {self.failure_count} failed out of {self.total} total',
|
||||
'summary': {
|
||||
'total': self.total,
|
||||
'success': self.success_count,
|
||||
'skipped': self.skipped_count,
|
||||
'failures': self.failure_count,
|
||||
'organization_type': 'flat' if self.is_flat_structure else 'structured',
|
||||
'cleaned_dirs': self.cleanup_counts,
|
||||
'operation_type': self.operation_type
|
||||
}
|
||||
}
|
||||
|
||||
if self.results_truncated:
|
||||
result['results_truncated'] = True
|
||||
result['sample_results'] = self.sample_results
|
||||
else:
|
||||
result['results'] = self.results
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ModelFileService:
|
||||
"""Service for handling model file operations and organization"""
|
||||
|
||||
def __init__(self, scanner, model_type: str):
|
||||
"""Initialize the service
|
||||
|
||||
Args:
|
||||
scanner: Model scanner instance
|
||||
model_type: Type of model (e.g., 'lora', 'checkpoint')
|
||||
"""
|
||||
self.scanner = scanner
|
||||
self.model_type = model_type
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get model root directories"""
|
||||
return self.scanner.get_model_roots()
|
||||
|
||||
async def auto_organize_models(
|
||||
self,
|
||||
file_paths: Optional[List[str]] = None,
|
||||
progress_callback: Optional[ProgressCallback] = None
|
||||
) -> AutoOrganizeResult:
|
||||
"""Auto-organize models based on current settings
|
||||
|
||||
Args:
|
||||
file_paths: Optional list of specific file paths to organize.
|
||||
If None, organizes all models.
|
||||
progress_callback: Optional callback for progress updates
|
||||
|
||||
Returns:
|
||||
AutoOrganizeResult object with operation results
|
||||
"""
|
||||
result = AutoOrganizeResult()
|
||||
source_directories: Set[str] = set()
|
||||
|
||||
try:
|
||||
# Get all models from cache
|
||||
cache = await self.scanner.get_cached_data()
|
||||
all_models = cache.raw_data
|
||||
|
||||
# Filter models if specific file paths are provided
|
||||
if file_paths:
|
||||
all_models = [model for model in all_models if model.get('file_path') in file_paths]
|
||||
result.operation_type = 'bulk'
|
||||
else:
|
||||
result.operation_type = 'all'
|
||||
|
||||
# Get model roots for this scanner
|
||||
model_roots = self.get_model_roots()
|
||||
if not model_roots:
|
||||
raise ValueError('No model roots configured')
|
||||
|
||||
# Check if flat structure is configured for this model type
|
||||
path_template = settings.get_download_path_template(self.model_type)
|
||||
result.is_flat_structure = not path_template
|
||||
|
||||
# Initialize tracking
|
||||
result.total = len(all_models)
|
||||
|
||||
# Send initial progress
|
||||
if progress_callback:
|
||||
await progress_callback.on_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'started',
|
||||
'total': result.total,
|
||||
'processed': 0,
|
||||
'success': 0,
|
||||
'failures': 0,
|
||||
'skipped': 0,
|
||||
'operation_type': result.operation_type
|
||||
})
|
||||
|
||||
# Process models in batches
|
||||
await self._process_models_in_batches(
|
||||
all_models,
|
||||
model_roots,
|
||||
result,
|
||||
progress_callback,
|
||||
source_directories # Pass the set to track source directories
|
||||
)
|
||||
|
||||
# Send cleanup progress
|
||||
if progress_callback:
|
||||
await progress_callback.on_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'cleaning',
|
||||
'total': result.total,
|
||||
'processed': result.processed,
|
||||
'success': result.success_count,
|
||||
'failures': result.failure_count,
|
||||
'skipped': result.skipped_count,
|
||||
'message': 'Cleaning up empty directories...',
|
||||
'operation_type': result.operation_type
|
||||
})
|
||||
|
||||
# Clean up empty directories - only in affected directories for bulk operations
|
||||
cleanup_paths = list(source_directories) if result.operation_type == 'bulk' else model_roots
|
||||
result.cleanup_counts = await self._cleanup_empty_directories(cleanup_paths)
|
||||
|
||||
# Send completion message
|
||||
if progress_callback:
|
||||
await progress_callback.on_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'completed',
|
||||
'total': result.total,
|
||||
'processed': result.processed,
|
||||
'success': result.success_count,
|
||||
'failures': result.failure_count,
|
||||
'skipped': result.skipped_count,
|
||||
'cleanup': result.cleanup_counts,
|
||||
'operation_type': result.operation_type
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in auto_organize_models: {e}", exc_info=True)
|
||||
|
||||
# Send error message
|
||||
if progress_callback:
|
||||
await progress_callback.on_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'error',
|
||||
'error': str(e),
|
||||
'operation_type': result.operation_type
|
||||
})
|
||||
|
||||
raise e
|
||||
|
||||
async def _process_models_in_batches(
|
||||
self,
|
||||
all_models: List[Dict[str, Any]],
|
||||
model_roots: List[str],
|
||||
result: AutoOrganizeResult,
|
||||
progress_callback: Optional[ProgressCallback],
|
||||
source_directories: Optional[Set[str]] = None
|
||||
) -> None:
|
||||
"""Process models in batches to avoid overwhelming the system"""
|
||||
|
||||
for i in range(0, result.total, AUTO_ORGANIZE_BATCH_SIZE):
|
||||
batch = all_models[i:i + AUTO_ORGANIZE_BATCH_SIZE]
|
||||
|
||||
for model in batch:
|
||||
await self._process_single_model(model, model_roots, result, source_directories)
|
||||
result.processed += 1
|
||||
|
||||
# Send progress update after each batch
|
||||
if progress_callback:
|
||||
await progress_callback.on_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'processing',
|
||||
'total': result.total,
|
||||
'processed': result.processed,
|
||||
'success': result.success_count,
|
||||
'failures': result.failure_count,
|
||||
'skipped': result.skipped_count,
|
||||
'operation_type': result.operation_type
|
||||
})
|
||||
|
||||
# Small delay between batches
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _process_single_model(
|
||||
self,
|
||||
model: Dict[str, Any],
|
||||
model_roots: List[str],
|
||||
result: AutoOrganizeResult,
|
||||
source_directories: Optional[Set[str]] = None
|
||||
) -> None:
|
||||
"""Process a single model for organization"""
|
||||
try:
|
||||
file_path = model.get('file_path')
|
||||
model_name = model.get('model_name', 'Unknown')
|
||||
|
||||
if not file_path:
|
||||
self._add_result(result, model_name, False, "No file path found")
|
||||
result.failure_count += 1
|
||||
return
|
||||
|
||||
# Find which model root this file belongs to
|
||||
current_root = self._find_model_root(file_path, model_roots)
|
||||
if not current_root:
|
||||
self._add_result(result, model_name, False,
|
||||
"Model file not found in any configured root directory")
|
||||
result.failure_count += 1
|
||||
return
|
||||
|
||||
# Determine target directory
|
||||
target_dir = await self._calculate_target_directory(
|
||||
model, current_root, result.is_flat_structure
|
||||
)
|
||||
|
||||
if target_dir is None:
|
||||
self._add_result(result, model_name, False,
|
||||
"Skipped - insufficient metadata for organization")
|
||||
result.skipped_count += 1
|
||||
return
|
||||
|
||||
current_dir = os.path.dirname(file_path)
|
||||
|
||||
# Skip if already in correct location
|
||||
if current_dir.replace(os.sep, '/') == target_dir.replace(os.sep, '/'):
|
||||
result.skipped_count += 1
|
||||
return
|
||||
|
||||
# Check for conflicts
|
||||
file_name = os.path.basename(file_path)
|
||||
target_file_path = os.path.join(target_dir, file_name)
|
||||
|
||||
if os.path.exists(target_file_path):
|
||||
self._add_result(result, model_name, False,
|
||||
f"Target file already exists: {target_file_path}")
|
||||
result.failure_count += 1
|
||||
return
|
||||
|
||||
# Store the source directory for potential cleanup
|
||||
if source_directories is not None:
|
||||
source_directories.add(current_dir)
|
||||
|
||||
# Perform the move
|
||||
success = await self.scanner.move_model(file_path, target_dir)
|
||||
|
||||
if success:
|
||||
result.success_count += 1
|
||||
else:
|
||||
self._add_result(result, model_name, False, "Failed to move model")
|
||||
result.failure_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing model {model.get('model_name', 'Unknown')}: {e}", exc_info=True)
|
||||
self._add_result(result, model.get('model_name', 'Unknown'), False, f"Error: {str(e)}")
|
||||
result.failure_count += 1
|
||||
|
||||
def _find_model_root(self, file_path: str, model_roots: List[str]) -> Optional[str]:
|
||||
"""Find which model root the file belongs to"""
|
||||
for root in model_roots:
|
||||
# Normalize paths for comparison
|
||||
normalized_root = os.path.normpath(root).replace(os.sep, '/')
|
||||
normalized_file = os.path.normpath(file_path).replace(os.sep, '/')
|
||||
|
||||
if normalized_file.startswith(normalized_root):
|
||||
return root
|
||||
return None
|
||||
|
||||
async def _calculate_target_directory(
|
||||
self,
|
||||
model: Dict[str, Any],
|
||||
current_root: str,
|
||||
is_flat_structure: bool
|
||||
) -> Optional[str]:
|
||||
"""Calculate the target directory for a model"""
|
||||
if is_flat_structure:
|
||||
file_path = model.get('file_path')
|
||||
current_dir = os.path.dirname(file_path)
|
||||
|
||||
# Check if already in root directory
|
||||
if os.path.normpath(current_dir) == os.path.normpath(current_root):
|
||||
return None # Signal to skip
|
||||
|
||||
return current_root
|
||||
else:
|
||||
# Calculate new relative path based on settings
|
||||
new_relative_path = calculate_relative_path_for_model(model, self.model_type)
|
||||
|
||||
if not new_relative_path:
|
||||
return None # Signal to skip
|
||||
|
||||
return os.path.join(current_root, new_relative_path).replace(os.sep, '/')
|
||||
|
||||
def _add_result(
|
||||
self,
|
||||
result: AutoOrganizeResult,
|
||||
model_name: str,
|
||||
success: bool,
|
||||
message: str
|
||||
) -> None:
|
||||
"""Add a result entry if under the limit"""
|
||||
if len(result.results) < 100: # Limit detailed results
|
||||
result.results.append({
|
||||
"model": model_name,
|
||||
"success": success,
|
||||
"message": message
|
||||
})
|
||||
elif len(result.results) == 100:
|
||||
# Mark as truncated and save sample
|
||||
result.results_truncated = True
|
||||
result.sample_results = result.results[:50]
|
||||
|
||||
async def _cleanup_empty_directories(self, paths: List[str]) -> Dict[str, int]:
|
||||
"""Clean up empty directories after organizing
|
||||
|
||||
Args:
|
||||
paths: List of paths to check for empty directories
|
||||
|
||||
Returns:
|
||||
Dictionary with counts of removed directories by root path
|
||||
"""
|
||||
cleanup_counts = {}
|
||||
for path in paths:
|
||||
removed = remove_empty_dirs(path)
|
||||
cleanup_counts[path] = removed
|
||||
return cleanup_counts
|
||||
|
||||
|
||||
class ModelMoveService:
|
||||
"""Service for handling individual model moves"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize the service
|
||||
|
||||
Args:
|
||||
scanner: Model scanner instance
|
||||
"""
|
||||
self.scanner = scanner
|
||||
|
||||
async def move_model(self, file_path: str, target_path: str) -> Dict[str, Any]:
|
||||
"""Move a single model file
|
||||
|
||||
Args:
|
||||
file_path: Source file path
|
||||
target_path: Target directory path
|
||||
|
||||
Returns:
|
||||
Dictionary with move result
|
||||
"""
|
||||
try:
|
||||
source_dir = os.path.dirname(file_path)
|
||||
if os.path.normpath(source_dir) == os.path.normpath(target_path):
|
||||
logger.info(f"Source and target directories are the same: {source_dir}")
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Source and target directories are the same',
|
||||
'original_file_path': file_path,
|
||||
'new_file_path': file_path
|
||||
}
|
||||
|
||||
new_file_path = await self.scanner.move_model(file_path, target_path)
|
||||
if new_file_path:
|
||||
return {
|
||||
'success': True,
|
||||
'original_file_path': file_path,
|
||||
'new_file_path': new_file_path
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'success': False,
|
||||
'error': 'Failed to move model',
|
||||
'original_file_path': file_path,
|
||||
'new_file_path': None
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving model: {e}", exc_info=True)
|
||||
return {
|
||||
'success': False,
|
||||
'error': str(e),
|
||||
'original_file_path': file_path,
|
||||
'new_file_path': None
|
||||
}
|
||||
|
||||
async def move_models_bulk(self, file_paths: List[str], target_path: str) -> Dict[str, Any]:
|
||||
"""Move multiple model files
|
||||
|
||||
Args:
|
||||
file_paths: List of source file paths
|
||||
target_path: Target directory path
|
||||
|
||||
Returns:
|
||||
Dictionary with bulk move results
|
||||
"""
|
||||
try:
|
||||
results = []
|
||||
|
||||
for file_path in file_paths:
|
||||
result = await self.move_model(file_path, target_path)
|
||||
results.append({
|
||||
"original_file_path": file_path,
|
||||
"new_file_path": result.get('new_file_path'),
|
||||
"success": result['success'],
|
||||
"message": result.get('message', result.get('error', 'Unknown'))
|
||||
})
|
||||
|
||||
success_count = sum(1 for r in results if r["success"])
|
||||
failure_count = len(results) - success_count
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': f'Moved {success_count} of {len(file_paths)} models',
|
||||
'results': results,
|
||||
'success_count': success_count,
|
||||
'failure_count': failure_count
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving models in bulk: {e}", exc_info=True)
|
||||
return {
|
||||
'success': False,
|
||||
'error': str(e),
|
||||
'results': [],
|
||||
'success_count': 0,
|
||||
'failure_count': len(file_paths)
|
||||
}
|
||||
@@ -31,29 +31,34 @@ class ModelHashIndex:
|
||||
if file_path not in self._duplicate_hashes.get(sha256, []):
|
||||
self._duplicate_hashes.setdefault(sha256, []).append(file_path)
|
||||
|
||||
# Track duplicates by filename
|
||||
# Track duplicates by filename - FIXED LOGIC
|
||||
if filename in self._filename_to_hash:
|
||||
old_hash = self._filename_to_hash[filename]
|
||||
if old_hash != sha256: # Different models with the same name
|
||||
old_path = self._hash_to_path.get(old_hash)
|
||||
if old_path:
|
||||
if filename not in self._duplicate_filenames:
|
||||
self._duplicate_filenames[filename] = [old_path]
|
||||
if file_path not in self._duplicate_filenames.get(filename, []):
|
||||
self._duplicate_filenames.setdefault(filename, []).append(file_path)
|
||||
existing_hash = self._filename_to_hash[filename]
|
||||
existing_path = self._hash_to_path.get(existing_hash)
|
||||
|
||||
# If this is a different file with the same filename
|
||||
if existing_path and existing_path != file_path:
|
||||
# Initialize duplicates tracking if needed
|
||||
if filename not in self._duplicate_filenames:
|
||||
self._duplicate_filenames[filename] = [existing_path]
|
||||
|
||||
# Add current file to duplicates if not already present
|
||||
if file_path not in self._duplicate_filenames[filename]:
|
||||
self._duplicate_filenames[filename].append(file_path)
|
||||
|
||||
# Remove old path mapping if hash exists
|
||||
if sha256 in self._hash_to_path:
|
||||
old_path = self._hash_to_path[sha256]
|
||||
old_filename = self._get_filename_from_path(old_path)
|
||||
if old_filename in self._filename_to_hash:
|
||||
if old_filename in self._filename_to_hash and self._filename_to_hash[old_filename] == sha256:
|
||||
del self._filename_to_hash[old_filename]
|
||||
|
||||
# Remove old hash mapping if filename exists
|
||||
# Remove old hash mapping if filename exists and points to different hash
|
||||
if filename in self._filename_to_hash:
|
||||
old_hash = self._filename_to_hash[filename]
|
||||
if old_hash in self._hash_to_path:
|
||||
del self._hash_to_path[old_hash]
|
||||
if old_hash != sha256 and old_hash in self._hash_to_path:
|
||||
# Don't delete the old hash mapping, just update filename mapping
|
||||
pass
|
||||
|
||||
# Add new mappings
|
||||
self._hash_to_path[sha256] = file_path
|
||||
@@ -63,16 +68,16 @@ class ModelHashIndex:
|
||||
"""Extract filename without extension from path"""
|
||||
return os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
def remove_by_path(self, file_path: str) -> None:
|
||||
def remove_by_path(self, file_path: str, hash_val: str = None) -> None:
|
||||
"""Remove entry by file path"""
|
||||
filename = self._get_filename_from_path(file_path)
|
||||
hash_val = None
|
||||
|
||||
# Find the hash for this file path
|
||||
for h, p in self._hash_to_path.items():
|
||||
if p == file_path:
|
||||
hash_val = h
|
||||
break
|
||||
if hash_val is None:
|
||||
for h, p in self._hash_to_path.items():
|
||||
if p == file_path:
|
||||
hash_val = h
|
||||
break
|
||||
|
||||
# If we didn't find a hash, nothing to do
|
||||
if not hash_val:
|
||||
@@ -199,8 +204,6 @@ class ModelHashIndex:
|
||||
|
||||
def get_hash_by_filename(self, filename: str) -> Optional[str]:
|
||||
"""Get hash for a filename without extension"""
|
||||
# Strip extension if present to make the function more flexible
|
||||
filename = os.path.splitext(filename)[0]
|
||||
return self._filename_to_hash.get(filename)
|
||||
|
||||
def clear(self) -> None:
|
||||
@@ -219,7 +222,7 @@ class ModelHashIndex:
|
||||
return set(self._filename_to_hash.keys())
|
||||
|
||||
def get_duplicate_hashes(self) -> Dict[str, List[str]]:
|
||||
"""Get dictionary of duplicate hashes and their paths"""
|
||||
"""Get dictionary of duplicate hashes and their paths"""
|
||||
return self._duplicate_hashes
|
||||
|
||||
def get_duplicate_filenames(self) -> Dict[str, List[str]]:
|
||||
|
||||
245
py/services/model_lifecycle_service.py
Normal file
245
py/services/model_lifecycle_service.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Service routines for model lifecycle mutations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Awaitable, Callable, Dict, Iterable, List, Optional
|
||||
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]:
|
||||
"""Delete the primary model artefacts within ``target_dir``."""
|
||||
|
||||
patterns = [
|
||||
f"{file_name}.safetensors",
|
||||
f"{file_name}.metadata.json",
|
||||
]
|
||||
for ext in PREVIEW_EXTENSIONS:
|
||||
patterns.append(f"{file_name}{ext}")
|
||||
|
||||
deleted: List[str] = []
|
||||
main_file = patterns[0]
|
||||
main_path = os.path.join(target_dir, main_file).replace(os.sep, "/")
|
||||
|
||||
if os.path.exists(main_path):
|
||||
os.remove(main_path)
|
||||
deleted.append(main_path)
|
||||
else:
|
||||
logger.warning("Model file not found: %s", main_file)
|
||||
|
||||
for pattern in patterns[1:]:
|
||||
path = os.path.join(target_dir, pattern)
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
os.remove(path)
|
||||
deleted.append(pattern)
|
||||
except Exception as exc: # pragma: no cover - defensive path
|
||||
logger.warning("Failed to delete %s: %s", pattern, exc)
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
class ModelLifecycleService:
|
||||
"""Co-ordinate destructive and mutating model operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
scanner,
|
||||
metadata_manager,
|
||||
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
||||
recipe_scanner_factory: Callable[[], Awaitable] | None = None,
|
||||
) -> None:
|
||||
self._scanner = scanner
|
||||
self._metadata_manager = metadata_manager
|
||||
self._metadata_loader = metadata_loader
|
||||
self._recipe_scanner_factory = (
|
||||
recipe_scanner_factory or ServiceRegistry.get_recipe_scanner
|
||||
)
|
||||
|
||||
async def delete_model(self, file_path: str) -> Dict[str, object]:
|
||||
"""Delete a model file and associated artefacts."""
|
||||
|
||||
if not file_path:
|
||||
raise ValueError("Model path is required")
|
||||
|
||||
target_dir = os.path.dirname(file_path)
|
||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
deleted_files = await delete_model_artifacts(target_dir, file_name)
|
||||
|
||||
cache = await self._scanner.get_cached_data()
|
||||
cache.raw_data = [item for item in cache.raw_data if item["file_path"] != file_path]
|
||||
await cache.resort()
|
||||
|
||||
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
|
||||
self._scanner._hash_index.remove_by_path(file_path)
|
||||
|
||||
return {"success": True, "deleted_files": deleted_files}
|
||||
|
||||
async def exclude_model(self, file_path: str) -> Dict[str, object]:
|
||||
"""Mark a model as excluded and prune cache references."""
|
||||
|
||||
if not file_path:
|
||||
raise ValueError("Model path is required")
|
||||
|
||||
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||
metadata = await self._metadata_loader(metadata_path)
|
||||
metadata["exclude"] = True
|
||||
|
||||
await self._metadata_manager.save_metadata(file_path, metadata)
|
||||
|
||||
cache = await self._scanner.get_cached_data()
|
||||
model_to_remove = next(
|
||||
(item for item in cache.raw_data if item["file_path"] == file_path),
|
||||
None,
|
||||
)
|
||||
|
||||
if model_to_remove:
|
||||
for tag in model_to_remove.get("tags", []):
|
||||
if tag in getattr(self._scanner, "_tags_count", {}):
|
||||
self._scanner._tags_count[tag] = max(
|
||||
0, self._scanner._tags_count[tag] - 1
|
||||
)
|
||||
if self._scanner._tags_count[tag] == 0:
|
||||
del self._scanner._tags_count[tag]
|
||||
|
||||
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
|
||||
self._scanner._hash_index.remove_by_path(file_path)
|
||||
|
||||
cache.raw_data = [
|
||||
item for item in cache.raw_data if item["file_path"] != file_path
|
||||
]
|
||||
await cache.resort()
|
||||
|
||||
excluded = getattr(self._scanner, "_excluded_models", None)
|
||||
if isinstance(excluded, list):
|
||||
excluded.append(file_path)
|
||||
|
||||
message = f"Model {os.path.basename(file_path)} excluded"
|
||||
return {"success": True, "message": message}
|
||||
|
||||
async def bulk_delete_models(self, file_paths: Iterable[str]) -> Dict[str, object]:
|
||||
"""Delete a collection of models via the scanner bulk operation."""
|
||||
|
||||
file_paths = list(file_paths)
|
||||
if not file_paths:
|
||||
raise ValueError("No file paths provided for deletion")
|
||||
|
||||
return await self._scanner.bulk_delete_models(file_paths)
|
||||
|
||||
async def rename_model(
|
||||
self, *, file_path: str, new_file_name: str
|
||||
) -> Dict[str, object]:
|
||||
"""Rename a model and its companion artefacts."""
|
||||
|
||||
if not file_path or not new_file_name:
|
||||
raise ValueError("File path and new file name are required")
|
||||
|
||||
invalid_chars = {"/", "\\", ":", "*", "?", '"', "<", ">", "|"}
|
||||
if any(char in new_file_name for char in invalid_chars):
|
||||
raise ValueError("Invalid characters in file name")
|
||||
|
||||
target_dir = os.path.dirname(file_path)
|
||||
old_file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
new_file_path = os.path.join(target_dir, f"{new_file_name}.safetensors").replace(
|
||||
os.sep, "/"
|
||||
)
|
||||
|
||||
if os.path.exists(new_file_path):
|
||||
raise ValueError("A file with this name already exists")
|
||||
|
||||
patterns = [
|
||||
f"{old_file_name}.safetensors",
|
||||
f"{old_file_name}.metadata.json",
|
||||
f"{old_file_name}.metadata.json.bak",
|
||||
]
|
||||
for ext in PREVIEW_EXTENSIONS:
|
||||
patterns.append(f"{old_file_name}{ext}")
|
||||
|
||||
existing_files: List[tuple[str, str]] = []
|
||||
for pattern in patterns:
|
||||
path = os.path.join(target_dir, pattern)
|
||||
if os.path.exists(path):
|
||||
existing_files.append((path, pattern))
|
||||
|
||||
metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json")
|
||||
metadata: Optional[Dict[str, object]] = None
|
||||
hash_value: Optional[str] = None
|
||||
|
||||
if os.path.exists(metadata_path):
|
||||
metadata = await self._metadata_loader(metadata_path)
|
||||
hash_value = metadata.get("sha256") if isinstance(metadata, dict) else None
|
||||
|
||||
renamed_files: List[str] = []
|
||||
new_metadata_path: Optional[str] = None
|
||||
new_preview: Optional[str] = None
|
||||
|
||||
for old_path, pattern in existing_files:
|
||||
ext = self._get_multipart_ext(pattern)
|
||||
new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(
|
||||
os.sep, "/"
|
||||
)
|
||||
os.rename(old_path, new_path)
|
||||
renamed_files.append(new_path)
|
||||
|
||||
if ext == ".metadata.json":
|
||||
new_metadata_path = new_path
|
||||
|
||||
if metadata and new_metadata_path:
|
||||
metadata["file_name"] = new_file_name
|
||||
metadata["file_path"] = new_file_path
|
||||
|
||||
if metadata.get("preview_url"):
|
||||
old_preview = str(metadata["preview_url"])
|
||||
ext = self._get_multipart_ext(old_preview)
|
||||
new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(
|
||||
os.sep, "/"
|
||||
)
|
||||
metadata["preview_url"] = new_preview
|
||||
|
||||
await self._metadata_manager.save_metadata(new_file_path, metadata)
|
||||
|
||||
if metadata:
|
||||
await self._scanner.update_single_model_cache(
|
||||
file_path, new_file_path, metadata
|
||||
)
|
||||
|
||||
if hash_value and getattr(self._scanner, "model_type", "") == "lora":
|
||||
recipe_scanner = await self._recipe_scanner_factory()
|
||||
if recipe_scanner:
|
||||
try:
|
||||
await recipe_scanner.update_lora_filename_by_hash(
|
||||
hash_value, new_file_name
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error(
|
||||
"Error updating recipe references for %s: %s",
|
||||
file_path,
|
||||
exc,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"new_file_path": new_file_path,
|
||||
"new_preview_path": new_preview,
|
||||
"renamed_files": renamed_files,
|
||||
"reload_required": False,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_multipart_ext(filename: str) -> str:
|
||||
"""Return the extension for files with compound suffixes."""
|
||||
|
||||
parts = filename.split(".")
|
||||
if len(parts) == 3:
|
||||
return "." + ".".join(parts[-2:])
|
||||
if len(parts) >= 4:
|
||||
return "." + ".".join(parts[-3:])
|
||||
return os.path.splitext(filename)[1]
|
||||
|
||||
495
py/services/model_metadata_provider.py
Normal file
495
py/services/model_metadata_provider.py
Normal file
@@ -0,0 +1,495 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Tuple, Any
|
||||
from .downloader import get_downloader
|
||||
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
except ImportError as exc:
|
||||
BeautifulSoup = None # type: ignore[assignment]
|
||||
_BS4_IMPORT_ERROR = exc
|
||||
else:
|
||||
_BS4_IMPORT_ERROR = None
|
||||
|
||||
try:
|
||||
import aiosqlite
|
||||
except ImportError as exc:
|
||||
aiosqlite = None # type: ignore[assignment]
|
||||
_AIOSQLITE_IMPORT_ERROR = exc
|
||||
else:
|
||||
_AIOSQLITE_IMPORT_ERROR = None
|
||||
|
||||
def _require_beautifulsoup() -> Any:
|
||||
if BeautifulSoup is None:
|
||||
raise RuntimeError(
|
||||
"BeautifulSoup (bs4) is required for CivArchiveModelMetadataProvider. "
|
||||
"Install it with 'pip install beautifulsoup4'."
|
||||
) from _BS4_IMPORT_ERROR
|
||||
return BeautifulSoup
|
||||
|
||||
def _require_aiosqlite() -> Any:
|
||||
if aiosqlite is None:
|
||||
raise RuntimeError(
|
||||
"aiosqlite is required for SQLiteModelMetadataProvider. "
|
||||
"Install it with 'pip install aiosqlite'."
|
||||
) from _AIOSQLITE_IMPORT_ERROR
|
||||
return aiosqlite
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelMetadataProvider(ABC):
|
||||
"""Base abstract class for all model metadata providers"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Find model by hash value"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Get all versions of a model with their details"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
"""Get specific model version with additional metadata"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Fetch model version metadata"""
|
||||
pass
|
||||
|
||||
class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses Civitai API for metadata"""
|
||||
|
||||
def __init__(self, civitai_client):
|
||||
self.client = civitai_client
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
return await self.client.get_model_by_hash(model_hash)
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
return await self.client.get_model_versions(model_id)
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
return await self.client.get_model_version(model_id, version_id)
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
return await self.client.get_model_version_info(version_id)
|
||||
|
||||
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses CivArchive HTML page parsing for metadata"""
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Not supported by CivArchive provider"""
|
||||
return None, "CivArchive provider does not support hash lookup"
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Not supported by CivArchive provider"""
|
||||
return None
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
"""Get specific model version by parsing CivArchive HTML page"""
|
||||
if model_id is None or version_id is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Construct CivArchive URL
|
||||
url = f"https://civarchive.com/models/{model_id}?modelVersionId={version_id}"
|
||||
|
||||
downloader = await get_downloader()
|
||||
session = await downloader.session
|
||||
async with session.get(url) as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
|
||||
html_content = await response.text()
|
||||
|
||||
# Parse HTML to extract JSON data
|
||||
soup_parser = _require_beautifulsoup()
|
||||
soup = soup_parser(html_content, 'html.parser')
|
||||
script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'})
|
||||
|
||||
if not script_tag:
|
||||
return None
|
||||
|
||||
# Parse JSON content
|
||||
json_data = json.loads(script_tag.string)
|
||||
model_data = json_data.get('props', {}).get('pageProps', {}).get('model')
|
||||
|
||||
if not model_data or 'version' not in model_data:
|
||||
return None
|
||||
|
||||
# Extract version data as base
|
||||
version = model_data['version'].copy()
|
||||
|
||||
# Restructure stats
|
||||
if 'downloadCount' in version and 'ratingCount' in version and 'rating' in version:
|
||||
version['stats'] = {
|
||||
'downloadCount': version.pop('downloadCount'),
|
||||
'ratingCount': version.pop('ratingCount'),
|
||||
'rating': version.pop('rating')
|
||||
}
|
||||
|
||||
# Rename trigger to trainedWords
|
||||
if 'trigger' in version:
|
||||
version['trainedWords'] = version.pop('trigger')
|
||||
|
||||
# Transform files data to expected format
|
||||
if 'files' in version:
|
||||
transformed_files = []
|
||||
for file_data in version['files']:
|
||||
# Find first available mirror (deletedAt is null)
|
||||
available_mirror = None
|
||||
for mirror in file_data.get('mirrors', []):
|
||||
if mirror.get('deletedAt') is None:
|
||||
available_mirror = mirror
|
||||
break
|
||||
|
||||
# Create transformed file entry
|
||||
transformed_file = {
|
||||
'id': file_data.get('id'),
|
||||
'sizeKB': file_data.get('sizeKB'),
|
||||
'name': available_mirror.get('filename', file_data.get('name')) if available_mirror else file_data.get('name'),
|
||||
'type': file_data.get('type'),
|
||||
'downloadUrl': available_mirror.get('url') if available_mirror else None,
|
||||
'primary': True,
|
||||
'mirrors': file_data.get('mirrors', [])
|
||||
}
|
||||
|
||||
# Transform hash format
|
||||
if 'sha256' in file_data:
|
||||
transformed_file['hashes'] = {
|
||||
'SHA256': file_data['sha256'].upper()
|
||||
}
|
||||
|
||||
transformed_files.append(transformed_file)
|
||||
|
||||
version['files'] = transformed_files
|
||||
|
||||
# Add model information
|
||||
version['model'] = {
|
||||
'name': model_data.get('name'),
|
||||
'type': model_data.get('type'),
|
||||
'nsfw': model_data.get('is_nsfw', False),
|
||||
'description': model_data.get('description'),
|
||||
'tags': model_data.get('tags', [])
|
||||
}
|
||||
|
||||
version['creator'] = {
|
||||
'username': model_data.get('username'),
|
||||
'image': ''
|
||||
}
|
||||
|
||||
# Add source identifier
|
||||
version['source'] = 'civarchive'
|
||||
version['is_deleted'] = json_data.get('query', {}).get('is_deleted', False)
|
||||
|
||||
return version
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivArchive model version {model_id}/{version_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Not supported by CivArchive provider - requires both model_id and version_id"""
|
||||
return None, "CivArchive provider requires both model_id and version_id"
|
||||
|
||||
class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses SQLite database for metadata"""
|
||||
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self._aiosqlite = _require_aiosqlite()
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Find model by hash value from SQLite database"""
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
# Look up in model_files table to get model_id and version_id
|
||||
query = """
|
||||
SELECT model_id, version_id
|
||||
FROM model_files
|
||||
WHERE sha256 = ?
|
||||
LIMIT 1
|
||||
"""
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
cursor = await db.execute(query, (model_hash.upper(),))
|
||||
file_row = await cursor.fetchone()
|
||||
|
||||
if not file_row:
|
||||
return None, "Model not found"
|
||||
|
||||
# Get version details
|
||||
model_id = file_row['model_id']
|
||||
version_id = file_row['version_id']
|
||||
|
||||
# Build response in the same format as Civitai API
|
||||
result = await self._get_version_with_model_data(db, model_id, version_id)
|
||||
return result, None if result else "Error retrieving model data"
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Get all versions of a model from SQLite database"""
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
|
||||
# First check if model exists
|
||||
model_query = "SELECT * FROM models WHERE id = ?"
|
||||
cursor = await db.execute(model_query, (model_id,))
|
||||
model_row = await cursor.fetchone()
|
||||
|
||||
if not model_row:
|
||||
return None
|
||||
|
||||
model_data = json.loads(model_row['data'])
|
||||
model_type = model_row['type']
|
||||
model_name = model_row['name']
|
||||
|
||||
# Get all versions for this model
|
||||
versions_query = """
|
||||
SELECT id, name, base_model, data, position, published_at
|
||||
FROM model_versions
|
||||
WHERE model_id = ?
|
||||
ORDER BY position ASC
|
||||
"""
|
||||
cursor = await db.execute(versions_query, (model_id,))
|
||||
version_rows = await cursor.fetchall()
|
||||
|
||||
if not version_rows:
|
||||
return {'modelVersions': [], 'type': model_type}
|
||||
|
||||
# Format versions similar to Civitai API
|
||||
model_versions = []
|
||||
for row in version_rows:
|
||||
version_data = json.loads(row['data'])
|
||||
# Add fields from the row to ensure we have the basic fields
|
||||
version_entry = {
|
||||
'id': row['id'],
|
||||
'modelId': int(model_id),
|
||||
'name': row['name'],
|
||||
'baseModel': row['base_model'],
|
||||
'model': {
|
||||
'name': model_row['name'],
|
||||
'type': model_type,
|
||||
},
|
||||
'source': 'archive_db'
|
||||
}
|
||||
# Update with any additional data
|
||||
version_entry.update(version_data)
|
||||
model_versions.append(version_entry)
|
||||
|
||||
return {
|
||||
'modelVersions': model_versions,
|
||||
'type': model_type,
|
||||
'name': model_name
|
||||
}
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
"""Get specific model version with additional metadata from SQLite database"""
|
||||
if not model_id and not version_id:
|
||||
return None
|
||||
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
|
||||
# Case 1: Only version_id is provided
|
||||
if model_id is None and version_id is not None:
|
||||
# First get the version info to extract model_id
|
||||
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
||||
cursor = await db.execute(version_query, (version_id,))
|
||||
version_row = await cursor.fetchone()
|
||||
|
||||
if not version_row:
|
||||
return None
|
||||
|
||||
model_id = version_row['model_id']
|
||||
|
||||
# Case 2: model_id is provided but version_id is not
|
||||
elif model_id is not None and version_id is None:
|
||||
# Find the latest version
|
||||
version_query = """
|
||||
SELECT id FROM model_versions
|
||||
WHERE model_id = ?
|
||||
ORDER BY position ASC
|
||||
LIMIT 1
|
||||
"""
|
||||
cursor = await db.execute(version_query, (model_id,))
|
||||
version_row = await cursor.fetchone()
|
||||
|
||||
if not version_row:
|
||||
return None
|
||||
|
||||
version_id = version_row['id']
|
||||
|
||||
# Now we have both model_id and version_id, get the full data
|
||||
return await self._get_version_with_model_data(db, model_id, version_id)
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Fetch model version metadata from SQLite database"""
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
|
||||
# Get version details
|
||||
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
||||
cursor = await db.execute(version_query, (version_id,))
|
||||
version_row = await cursor.fetchone()
|
||||
|
||||
if not version_row:
|
||||
return None, "Model version not found"
|
||||
|
||||
model_id = version_row['model_id']
|
||||
|
||||
# Build complete version data with model info
|
||||
version_data = await self._get_version_with_model_data(db, model_id, version_id)
|
||||
return version_data, None
|
||||
|
||||
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
|
||||
"""Helper to build version data with model information"""
|
||||
# Get version details
|
||||
version_query = "SELECT name, base_model, data FROM model_versions WHERE id = ? AND model_id = ?"
|
||||
cursor = await db.execute(version_query, (version_id, model_id))
|
||||
version_row = await cursor.fetchone()
|
||||
|
||||
if not version_row:
|
||||
return None
|
||||
|
||||
# Get model details
|
||||
model_query = "SELECT name, type, data, username FROM models WHERE id = ?"
|
||||
cursor = await db.execute(model_query, (model_id,))
|
||||
model_row = await cursor.fetchone()
|
||||
|
||||
if not model_row:
|
||||
return None
|
||||
|
||||
# Parse JSON data
|
||||
try:
|
||||
version_data = json.loads(version_row['data'])
|
||||
model_data = json.loads(model_row['data'])
|
||||
|
||||
# Build response
|
||||
result = {
|
||||
"id": int(version_id),
|
||||
"modelId": int(model_id),
|
||||
"name": version_row['name'],
|
||||
"baseModel": version_row['base_model'],
|
||||
"model": {
|
||||
"name": model_row['name'],
|
||||
"description": model_data.get("description"),
|
||||
"type": model_row['type'],
|
||||
"tags": model_data.get("tags", [])
|
||||
},
|
||||
"creator": {
|
||||
"username": model_row['username'] or model_data.get("creator", {}).get("username"),
|
||||
"image": model_data.get("creator", {}).get("image")
|
||||
},
|
||||
"source": "archive_db"
|
||||
}
|
||||
|
||||
# Add any additional fields from version data
|
||||
result.update(version_data)
|
||||
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
class FallbackMetadataProvider(ModelMetadataProvider):
|
||||
"""Try providers in order, return first successful result."""
|
||||
def __init__(self, providers: list):
|
||||
self.providers = providers
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
for provider in self.providers:
|
||||
try:
|
||||
result, error = await provider.get_model_by_hash(model_hash)
|
||||
if result:
|
||||
return result, error
|
||||
except Exception as e:
|
||||
logger.debug(f"Provider failed for get_model_by_hash: {e}")
|
||||
continue
|
||||
return None, "Model not found"
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
for provider in self.providers:
|
||||
try:
|
||||
result = await provider.get_model_versions(model_id)
|
||||
if result:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.debug(f"Provider failed for get_model_versions: {e}")
|
||||
continue
|
||||
return None
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
for provider in self.providers:
|
||||
try:
|
||||
result = await provider.get_model_version(model_id, version_id)
|
||||
if result:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.debug(f"Provider failed for get_model_version: {e}")
|
||||
continue
|
||||
return None
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
for provider in self.providers:
|
||||
try:
|
||||
result, error = await provider.get_model_version_info(version_id)
|
||||
if result:
|
||||
return result, error
|
||||
except Exception as e:
|
||||
logger.debug(f"Provider failed for get_model_version_info: {e}")
|
||||
continue
|
||||
return None, "No provider could retrieve the data"
|
||||
|
||||
class ModelMetadataProviderManager:
|
||||
"""Manager for selecting and using model metadata providers"""
|
||||
|
||||
_instance = None
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance of ModelMetadataProviderManager"""
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
self.providers = {}
|
||||
self.default_provider = None
|
||||
|
||||
def register_provider(self, name: str, provider: ModelMetadataProvider, is_default: bool = False):
|
||||
"""Register a metadata provider"""
|
||||
self.providers[name] = provider
|
||||
if is_default or self.default_provider is None:
|
||||
self.default_provider = name
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Find model by hash using specified or default provider"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_by_hash(model_hash)
|
||||
|
||||
async def get_model_versions(self, model_id: str, provider_name: str = None) -> Optional[Dict]:
|
||||
"""Get model versions using specified or default provider"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_versions(model_id)
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None, provider_name: str = None) -> Optional[Dict]:
|
||||
"""Get specific model version using specified or default provider"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_version(model_id, version_id)
|
||||
|
||||
async def get_model_version_info(self, version_id: str, provider_name: str = None) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Fetch model version info using specified or default provider"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_version_info(version_id)
|
||||
|
||||
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
|
||||
"""Get provider by name or default provider"""
|
||||
if provider_name and provider_name in self.providers:
|
||||
return self.providers[provider_name]
|
||||
|
||||
if self.default_provider is None:
|
||||
raise ValueError("No default provider set and no valid provider specified")
|
||||
|
||||
return self.providers[self.default_provider]
|
||||
196
py/services/model_query.py
Normal file
196
py/services/model_query.py
Normal file
@@ -0,0 +1,196 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable
|
||||
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..utils.utils import fuzzy_match as default_fuzzy_match
|
||||
|
||||
|
||||
class SettingsProvider(Protocol):
|
||||
"""Protocol describing the SettingsManager contract used by query helpers."""
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SortParams:
|
||||
"""Normalized representation of sorting instructions."""
|
||||
|
||||
key: str
|
||||
order: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FilterCriteria:
|
||||
"""Container for model list filtering options."""
|
||||
|
||||
folder: Optional[str] = None
|
||||
base_models: Optional[Sequence[str]] = None
|
||||
tags: Optional[Sequence[str]] = None
|
||||
favorites_only: bool = False
|
||||
search_options: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ModelCacheRepository:
|
||||
"""Adapter around scanner cache access and sort normalisation."""
|
||||
|
||||
def __init__(self, scanner) -> None:
|
||||
self._scanner = scanner
|
||||
|
||||
async def get_cache(self):
|
||||
"""Return the underlying cache instance from the scanner."""
|
||||
return await self._scanner.get_cached_data()
|
||||
|
||||
async def fetch_sorted(self, params: SortParams) -> List[Dict[str, Any]]:
|
||||
"""Fetch cached data pre-sorted according to ``params``."""
|
||||
cache = await self.get_cache()
|
||||
return await cache.get_sorted_data(params.key, params.order)
|
||||
|
||||
@staticmethod
|
||||
def parse_sort(sort_by: str) -> SortParams:
|
||||
"""Parse an incoming sort string into key/order primitives."""
|
||||
if not sort_by:
|
||||
return SortParams(key="name", order="asc")
|
||||
|
||||
if ":" in sort_by:
|
||||
raw_key, raw_order = sort_by.split(":", 1)
|
||||
sort_key = raw_key.strip().lower() or "name"
|
||||
order = raw_order.strip().lower()
|
||||
else:
|
||||
sort_key = sort_by.strip().lower() or "name"
|
||||
order = "asc"
|
||||
|
||||
if order not in ("asc", "desc"):
|
||||
order = "asc"
|
||||
|
||||
return SortParams(key=sort_key, order=order)
|
||||
|
||||
|
||||
class ModelFilterSet:
|
||||
"""Applies common filtering rules to the model collection."""
|
||||
|
||||
def __init__(self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None) -> None:
|
||||
self._settings = settings
|
||||
self._nsfw_levels = nsfw_levels or NSFW_LEVELS
|
||||
|
||||
def apply(self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria) -> List[Dict[str, Any]]:
|
||||
"""Return items that satisfy the provided criteria."""
|
||||
items = list(data)
|
||||
|
||||
if self._settings.get("show_only_sfw", False):
|
||||
threshold = self._nsfw_levels.get("R", 0)
|
||||
items = [
|
||||
item for item in items
|
||||
if not item.get("preview_nsfw_level") or item.get("preview_nsfw_level") < threshold
|
||||
]
|
||||
|
||||
if criteria.favorites_only:
|
||||
items = [item for item in items if item.get("favorite", False)]
|
||||
|
||||
folder = criteria.folder
|
||||
options = criteria.search_options or {}
|
||||
recursive = bool(options.get("recursive", True))
|
||||
if folder is not None:
|
||||
if recursive:
|
||||
if folder:
|
||||
folder_with_sep = f"{folder}/"
|
||||
items = [
|
||||
item for item in items
|
||||
if item.get("folder") == folder or item.get("folder", "").startswith(folder_with_sep)
|
||||
]
|
||||
else:
|
||||
items = [item for item in items if item.get("folder") == folder]
|
||||
|
||||
base_models = criteria.base_models or []
|
||||
if base_models:
|
||||
base_model_set = set(base_models)
|
||||
items = [item for item in items if item.get("base_model") in base_model_set]
|
||||
|
||||
tags = criteria.tags or []
|
||||
if tags:
|
||||
tag_set = set(tags)
|
||||
items = [
|
||||
item for item in items
|
||||
if any(tag in tag_set for tag in item.get("tags", []))
|
||||
]
|
||||
|
||||
return items
|
||||
|
||||
|
||||
class SearchStrategy:
|
||||
"""Encapsulates text and fuzzy matching behaviour for model queries."""
|
||||
|
||||
DEFAULT_OPTIONS: Dict[str, Any] = {
|
||||
"filename": True,
|
||||
"modelname": True,
|
||||
"tags": False,
|
||||
"recursive": True,
|
||||
"creator": False,
|
||||
}
|
||||
|
||||
def __init__(self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None) -> None:
|
||||
self._fuzzy_match = fuzzy_matcher or default_fuzzy_match
|
||||
|
||||
def normalize_options(self, options: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Merge provided options with defaults without mutating input."""
|
||||
normalized = dict(self.DEFAULT_OPTIONS)
|
||||
if options:
|
||||
normalized.update(options)
|
||||
return normalized
|
||||
|
||||
def apply(
|
||||
self,
|
||||
data: Iterable[Dict[str, Any]],
|
||||
search_term: str,
|
||||
options: Dict[str, Any],
|
||||
fuzzy: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Return items matching the search term using the configured strategy."""
|
||||
if not search_term:
|
||||
return list(data)
|
||||
|
||||
search_lower = search_term.lower()
|
||||
results: List[Dict[str, Any]] = []
|
||||
|
||||
for item in data:
|
||||
if options.get("filename", True):
|
||||
candidate = item.get("file_name", "")
|
||||
if self._matches(candidate, search_term, search_lower, fuzzy):
|
||||
results.append(item)
|
||||
continue
|
||||
|
||||
if options.get("modelname", True):
|
||||
candidate = item.get("model_name", "")
|
||||
if self._matches(candidate, search_term, search_lower, fuzzy):
|
||||
results.append(item)
|
||||
continue
|
||||
|
||||
if options.get("tags", False):
|
||||
tags = item.get("tags", []) or []
|
||||
if any(self._matches(tag, search_term, search_lower, fuzzy) for tag in tags):
|
||||
results.append(item)
|
||||
continue
|
||||
|
||||
if options.get("creator", False):
|
||||
creator_username = ""
|
||||
civitai = item.get("civitai")
|
||||
if isinstance(civitai, dict):
|
||||
creator = civitai.get("creator")
|
||||
if isinstance(creator, dict):
|
||||
creator_username = creator.get("username", "")
|
||||
if creator_username and self._matches(creator_username, search_term, search_lower, fuzzy):
|
||||
results.append(item)
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
def _matches(self, candidate: str, search_term: str, search_lower: str, fuzzy: bool) -> bool:
|
||||
if not candidate:
|
||||
return False
|
||||
|
||||
candidate_lower = candidate.lower()
|
||||
if fuzzy:
|
||||
return self._fuzzy_match(candidate, search_term)
|
||||
return search_lower in candidate_lower
|
||||
@@ -5,30 +5,47 @@ import asyncio
|
||||
import time
|
||||
import shutil
|
||||
from typing import List, Dict, Optional, Type, Set
|
||||
import msgpack # Add MessagePack import for efficient serialization
|
||||
|
||||
from ..utils.models import BaseModelMetadata
|
||||
from ..config import config
|
||||
from ..utils.file_utils import load_metadata, get_file_info, find_preview_file, save_metadata
|
||||
from ..utils.file_utils import find_preview_file, get_preview_extension
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .model_cache import ModelCache
|
||||
from .model_hash_index import ModelHashIndex
|
||||
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||
from .model_lifecycle_service import delete_model_artifacts
|
||||
from .service_registry import ServiceRegistry
|
||||
from .websocket_manager import ws_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define cache version to handle future format changes
|
||||
# Version history:
|
||||
# 1 - Initial version
|
||||
# 2 - Added duplicate_filenames and duplicate_hashes tracking
|
||||
# 3 - Added _excluded_models list to cache
|
||||
CACHE_VERSION = 3
|
||||
|
||||
class ModelScanner:
|
||||
"""Base service for scanning and managing model files"""
|
||||
|
||||
_lock = asyncio.Lock()
|
||||
_instances = {} # Dictionary to store instances by class
|
||||
_locks = {} # Dictionary to store locks by class
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""Implement singleton pattern for each subclass"""
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super().__new__(cls)
|
||||
return cls._instances[cls]
|
||||
|
||||
@classmethod
|
||||
def _get_lock(cls):
|
||||
"""Get or create a lock for this class"""
|
||||
if cls not in cls._locks:
|
||||
cls._locks[cls] = asyncio.Lock()
|
||||
return cls._locks[cls]
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
lock = cls._get_lock()
|
||||
async with lock:
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = cls()
|
||||
return cls._instances[cls]
|
||||
|
||||
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
|
||||
"""Initialize the scanner
|
||||
@@ -39,6 +56,10 @@ class ModelScanner:
|
||||
file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
|
||||
hash_index: Hash index instance (optional)
|
||||
"""
|
||||
# Ensure initialization happens only once per instance
|
||||
if hasattr(self, '_initialized'):
|
||||
return
|
||||
|
||||
self.model_type = model_type
|
||||
self.model_class = model_class
|
||||
self.file_extensions = file_extensions
|
||||
@@ -47,192 +68,15 @@ class ModelScanner:
|
||||
self._tags_count = {} # Dictionary to store tag counts
|
||||
self._is_initializing = False # Flag to track initialization state
|
||||
self._excluded_models = [] # List to track excluded models
|
||||
self._dirs_last_modified = {} # Track directory modification times
|
||||
self._initialized = True
|
||||
|
||||
# Register this service
|
||||
asyncio.create_task(self._register_service())
|
||||
|
||||
|
||||
async def _register_service(self):
|
||||
"""Register this instance with the ServiceRegistry"""
|
||||
service_name = f"{self.model_type}_scanner"
|
||||
await ServiceRegistry.register_service(service_name, self)
|
||||
|
||||
def _get_cache_file_path(self) -> Optional[str]:
|
||||
"""Get the path to the cache file"""
|
||||
# Get the directory where this module is located
|
||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
|
||||
|
||||
# Create a cache directory within the project if it doesn't exist
|
||||
cache_dir = os.path.join(current_dir, "cache")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# Create filename based on model type
|
||||
cache_filename = f"lm_{self.model_type}_cache.msgpack"
|
||||
return os.path.join(cache_dir, cache_filename)
|
||||
|
||||
def _prepare_for_msgpack(self, data):
|
||||
"""Preprocess data to accommodate MessagePack serialization limitations
|
||||
|
||||
Converts integers exceeding safe range to strings
|
||||
|
||||
Args:
|
||||
data: Any type of data structure
|
||||
|
||||
Returns:
|
||||
Preprocessed data structure with large integers converted to strings
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
return {k: self._prepare_for_msgpack(v) for k, v in data.items()}
|
||||
elif isinstance(data, list):
|
||||
return [self._prepare_for_msgpack(item) for item in data]
|
||||
elif isinstance(data, int) and (data > 9007199254740991 or data < -9007199254740991):
|
||||
# Convert integers exceeding JavaScript's safe integer range (2^53-1) to strings
|
||||
return str(data)
|
||||
else:
|
||||
return data
|
||||
|
||||
async def _save_cache_to_disk(self) -> bool:
|
||||
"""Save cache data to disk using MessagePack"""
|
||||
if self._cache is None or not self._cache.raw_data:
|
||||
logger.debug(f"No {self.model_type} cache data to save")
|
||||
return False
|
||||
|
||||
cache_path = self._get_cache_file_path()
|
||||
if not cache_path:
|
||||
logger.warning(f"Cannot determine {self.model_type} cache file location")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Create cache data structure
|
||||
cache_data = {
|
||||
"version": CACHE_VERSION,
|
||||
"timestamp": time.time(),
|
||||
"model_type": self.model_type,
|
||||
"raw_data": self._cache.raw_data,
|
||||
"hash_index": {
|
||||
"hash_to_path": self._hash_index._hash_to_path,
|
||||
"filename_to_hash": self._hash_index._filename_to_hash, # Fix: changed from path_to_hash to filename_to_hash
|
||||
"duplicate_hashes": self._hash_index._duplicate_hashes,
|
||||
"duplicate_filenames": self._hash_index._duplicate_filenames
|
||||
},
|
||||
"tags_count": self._tags_count,
|
||||
"dirs_last_modified": self._get_dirs_last_modified(),
|
||||
"excluded_models": self._excluded_models # Add excluded_models to cache data
|
||||
}
|
||||
|
||||
# Preprocess data to handle large integers
|
||||
processed_cache_data = self._prepare_for_msgpack(cache_data)
|
||||
|
||||
# Write to temporary file first (atomic operation)
|
||||
temp_path = f"{cache_path}.tmp"
|
||||
with open(temp_path, 'wb') as f:
|
||||
msgpack.pack(processed_cache_data, f)
|
||||
|
||||
# Replace the old file with the new one
|
||||
if os.path.exists(cache_path):
|
||||
os.replace(temp_path, cache_path)
|
||||
else:
|
||||
os.rename(temp_path, cache_path)
|
||||
|
||||
logger.info(f"Saved {self.model_type} cache with {len(self._cache.raw_data)} models to {cache_path}")
|
||||
logger.debug(f"Hash index stats - hash_to_path: {len(self._hash_index._hash_to_path)}, filename_to_hash: {len(self._hash_index._filename_to_hash)}, duplicate_hashes: {len(self._hash_index._duplicate_hashes)}, duplicate_filenames: {len(self._hash_index._duplicate_filenames)}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving {self.model_type} cache to disk: {e}")
|
||||
# Try to clean up temp file if it exists
|
||||
if 'temp_path' in locals() and os.path.exists(temp_path):
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _get_dirs_last_modified(self) -> Dict[str, float]:
|
||||
"""Get last modified time for all model directories"""
|
||||
dirs_info = {}
|
||||
for root in self.get_model_roots():
|
||||
if os.path.exists(root):
|
||||
dirs_info[root] = os.path.getmtime(root)
|
||||
# Also check immediate subdirectories for changes
|
||||
try:
|
||||
with os.scandir(root) as it:
|
||||
for entry in it:
|
||||
if entry.is_dir(follow_symlinks=True):
|
||||
dirs_info[entry.path] = entry.stat().st_mtime
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting directory info for {root}: {e}")
|
||||
return dirs_info
|
||||
|
||||
def _is_cache_valid(self, cache_data: Dict) -> bool:
|
||||
"""Validate if the loaded cache is still valid"""
|
||||
if not cache_data or cache_data.get("version") != CACHE_VERSION:
|
||||
logger.info(f"Cache invalid - version mismatch. Got: {cache_data.get('version')}, Expected: {CACHE_VERSION}")
|
||||
return False
|
||||
|
||||
if cache_data.get("model_type") != self.model_type:
|
||||
logger.info(f"Cache invalid - model type mismatch. Got: {cache_data.get('model_type')}, Expected: {self.model_type}")
|
||||
return False
|
||||
|
||||
# Check if directories have changed
|
||||
# stored_dirs = cache_data.get("dirs_last_modified", {})
|
||||
# current_dirs = self._get_dirs_last_modified()
|
||||
|
||||
# If directory structure has changed, cache is invalid
|
||||
# if set(stored_dirs.keys()) != set(current_dirs.keys()):
|
||||
# logger.info(f"Cache invalid - directory structure changed. Stored: {set(stored_dirs.keys())}, Current: {set(current_dirs.keys())}")
|
||||
# return False
|
||||
|
||||
# Remove the modification time check to make cache validation less strict
|
||||
# This allows the cache to be valid even when files have changed
|
||||
# Users can explicitly refresh the cache when needed
|
||||
|
||||
return True
|
||||
|
||||
async def _load_cache_from_disk(self) -> bool:
|
||||
"""Load cache data from disk using MessagePack"""
|
||||
start_time = time.time()
|
||||
cache_path = self._get_cache_file_path()
|
||||
if not cache_path or not os.path.exists(cache_path):
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(cache_path, 'rb') as f:
|
||||
cache_data = msgpack.unpack(f)
|
||||
|
||||
# Validate cache data
|
||||
if not self._is_cache_valid(cache_data):
|
||||
logger.info(f"{self.model_type.capitalize()} cache file found but invalid or outdated")
|
||||
return False
|
||||
|
||||
# Load data into memory
|
||||
self._cache = ModelCache(
|
||||
raw_data=cache_data["raw_data"],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
# Load hash index
|
||||
hash_index_data = cache_data.get("hash_index", {})
|
||||
self._hash_index._hash_to_path = hash_index_data.get("hash_to_path", {})
|
||||
self._hash_index._filename_to_hash = hash_index_data.get("filename_to_hash", {}) # Fix: changed from path_to_hash to filename_to_hash
|
||||
self._hash_index._duplicate_hashes = hash_index_data.get("duplicate_hashes", {})
|
||||
self._hash_index._duplicate_filenames = hash_index_data.get("duplicate_filenames", {})
|
||||
|
||||
# Load tags count
|
||||
self._tags_count = cache_data.get("tags_count", {})
|
||||
|
||||
# Load excluded models
|
||||
self._excluded_models = cache_data.get("excluded_models", [])
|
||||
|
||||
# Resort the cache
|
||||
await self._cache.resort()
|
||||
|
||||
logger.info(f"Loaded {self.model_type} cache from disk with {len(self._cache.raw_data)} models in {time.time() - start_time:.2f} seconds")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading {self.model_type} cache from disk: {e}")
|
||||
return False
|
||||
|
||||
async def initialize_in_background(self) -> None:
|
||||
"""Initialize cache in background using thread pool"""
|
||||
@@ -241,8 +85,6 @@ class ModelScanner:
|
||||
if self._cache is None:
|
||||
self._cache = ModelCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
@@ -260,21 +102,6 @@ class ModelScanner:
|
||||
'scanner_type': self.model_type,
|
||||
'pageType': page_type
|
||||
})
|
||||
|
||||
cache_loaded = await self._load_cache_from_disk()
|
||||
|
||||
if cache_loaded:
|
||||
# Cache loaded successfully, broadcast complete message
|
||||
await ws_manager.broadcast_init_progress({
|
||||
'stage': 'finalizing',
|
||||
'progress': 100,
|
||||
'status': 'complete',
|
||||
'details': f"Loaded {len(self._cache.raw_data)} {self.model_type} files from cache.",
|
||||
'scanner_type': self.model_type,
|
||||
'pageType': page_type
|
||||
})
|
||||
self._is_initializing = False
|
||||
return
|
||||
|
||||
# If cache loading failed, proceed with full scan
|
||||
await ws_manager.broadcast_init_progress({
|
||||
@@ -321,9 +148,6 @@ class ModelScanner:
|
||||
|
||||
logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. Found {len(self._cache.raw_data)} models")
|
||||
|
||||
# Save the cache to disk after initialization
|
||||
await self._save_cache_to_disk()
|
||||
|
||||
# Send completion message
|
||||
await asyncio.sleep(0.5) # Small delay to ensure final progress message is sent
|
||||
await ws_manager.broadcast_init_progress({
|
||||
@@ -479,6 +303,13 @@ class ModelScanner:
|
||||
for tag in model_data['tags']:
|
||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||
|
||||
# Log duplicate filename warnings after building the index
|
||||
# duplicate_filenames = self._hash_index.get_duplicate_filenames()
|
||||
# if duplicate_filenames:
|
||||
# logger.warning(f"Found {len(duplicate_filenames)} filename(s) with duplicates during {self.model_type} cache build:")
|
||||
# for filename, paths in duplicate_filenames.items():
|
||||
# logger.warning(f" Duplicate filename '{filename}': {paths}")
|
||||
|
||||
# Update cache
|
||||
self._cache.raw_data = raw_data
|
||||
loop.run_until_complete(self._cache.resort())
|
||||
@@ -498,40 +329,21 @@ class ModelScanner:
|
||||
|
||||
Args:
|
||||
force_refresh: Whether to refresh the cache
|
||||
rebuild_cache: Whether to completely rebuild the cache by reloading from disk first
|
||||
rebuild_cache: Whether to completely rebuild the cache
|
||||
"""
|
||||
# If cache is not initialized, return an empty cache
|
||||
# Actual initialization should be done via initialize_in_background
|
||||
if self._cache is None and not force_refresh:
|
||||
return ModelCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
# If force refresh is requested, initialize the cache directly
|
||||
if force_refresh:
|
||||
# If rebuild_cache is True, try to reload from disk before reconciliation
|
||||
if rebuild_cache:
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Attempting to rebuild cache from disk...")
|
||||
cache_loaded = await self._load_cache_from_disk()
|
||||
if cache_loaded:
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Successfully reloaded cache from disk")
|
||||
else:
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Could not reload cache from disk, proceeding with complete rebuild")
|
||||
# If loading from disk failed, do a complete rebuild and save to disk
|
||||
await self._initialize_cache()
|
||||
await self._save_cache_to_disk()
|
||||
return self._cache
|
||||
|
||||
if self._cache is None:
|
||||
# For initial creation, do a full initialization
|
||||
await self._initialize_cache()
|
||||
# Save the newly built cache
|
||||
await self._save_cache_to_disk()
|
||||
else:
|
||||
# For subsequent refreshes, use fast reconciliation
|
||||
await self._reconcile_cache()
|
||||
|
||||
return self._cache
|
||||
@@ -563,11 +375,16 @@ class ModelScanner:
|
||||
for tag in model_data['tags']:
|
||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||
|
||||
# Log duplicate filename warnings after building the index
|
||||
# duplicate_filenames = self._hash_index.get_duplicate_filenames()
|
||||
# if duplicate_filenames:
|
||||
# logger.warning(f"Found {len(duplicate_filenames)} filename(s) with duplicates during {self.model_type} cache build:")
|
||||
# for filename, paths in duplicate_filenames.items():
|
||||
# logger.warning(f" Duplicate filename '{filename}': {paths}")
|
||||
|
||||
# Update cache
|
||||
self._cache = ModelCache(
|
||||
raw_data=raw_data,
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
@@ -581,8 +398,6 @@ class ModelScanner:
|
||||
if self._cache is None:
|
||||
self._cache = ModelCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
finally:
|
||||
@@ -659,26 +474,33 @@ class ModelScanner:
|
||||
batch = new_files[i:i+batch_size]
|
||||
for path in batch:
|
||||
try:
|
||||
model_data = await self.scan_single_model(path)
|
||||
if model_data:
|
||||
# Add to cache
|
||||
self._cache.raw_data.append(model_data)
|
||||
|
||||
# Update hash index if available
|
||||
if 'sha256' in model_data and 'file_path' in model_data:
|
||||
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
|
||||
|
||||
# Update tags count
|
||||
if 'tags' in model_data and model_data['tags']:
|
||||
for tag in model_data['tags']:
|
||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||
|
||||
total_added += 1
|
||||
# Find the appropriate root path for this file
|
||||
root_path = None
|
||||
for potential_root in self.get_model_roots():
|
||||
if path.startswith(potential_root):
|
||||
root_path = potential_root
|
||||
break
|
||||
|
||||
if root_path:
|
||||
model_data = await self._process_model_file(path, root_path)
|
||||
if model_data:
|
||||
# Add to cache
|
||||
self._cache.raw_data.append(model_data)
|
||||
|
||||
# Update hash index if available
|
||||
if 'sha256' in model_data and 'file_path' in model_data:
|
||||
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
|
||||
|
||||
# Update tags count
|
||||
if 'tags' in model_data and model_data['tags']:
|
||||
for tag in model_data['tags']:
|
||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||
|
||||
total_added += 1
|
||||
else:
|
||||
logger.error(f"Could not determine root path for {path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding {path} to cache: {e}")
|
||||
|
||||
# Yield control after each batch
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Find missing files (in cache but not in filesystem)
|
||||
missing_files = cached_paths - found_paths
|
||||
@@ -717,50 +539,78 @@ class ModelScanner:
|
||||
# Resort cache
|
||||
await self._cache.resort()
|
||||
|
||||
# Save updated cache to disk
|
||||
await self._save_cache_to_disk()
|
||||
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True)
|
||||
finally:
|
||||
self._is_initializing = False # Unset flag
|
||||
|
||||
# These methods should be implemented in child classes
|
||||
async def scan_all_models(self) -> List[Dict]:
|
||||
"""Scan all model directories and return metadata"""
|
||||
raise NotImplementedError("Subclasses must implement scan_all_models")
|
||||
all_models = []
|
||||
|
||||
# Create scan tasks for each directory
|
||||
scan_tasks = []
|
||||
for model_root in self.get_model_roots():
|
||||
task = asyncio.create_task(self._scan_directory(model_root))
|
||||
scan_tasks.append(task)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for task in scan_tasks:
|
||||
try:
|
||||
models = await task
|
||||
all_models.extend(models)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning directory: {e}")
|
||||
|
||||
return all_models
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Scan a single directory for model files"""
|
||||
models = []
|
||||
original_root = root_path # Save original root path
|
||||
|
||||
async def scan_recursive(path: str, visited_paths: set):
|
||||
"""Recursively scan directory, avoiding circular symlinks"""
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
logger.debug(f"Skipping already visited path: {path}")
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
result = await self._process_model_file(file_path, original_root)
|
||||
# Only add to models if result is not None (skip corrupted metadata)
|
||||
if result:
|
||||
models.append(result)
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
await scan_recursive(entry.path, visited_paths)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
await scan_recursive(root_path, set())
|
||||
return models
|
||||
|
||||
def is_initializing(self) -> bool:
|
||||
"""Check if the scanner is currently initializing"""
|
||||
return self._is_initializing
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get model root directories"""
|
||||
raise NotImplementedError("Subclasses must implement get_model_roots")
|
||||
|
||||
async def scan_single_model(self, file_path: str) -> Optional[Dict]:
|
||||
"""Scan a single model file and return its metadata"""
|
||||
try:
|
||||
if not os.path.exists(os.path.realpath(file_path)):
|
||||
return None
|
||||
|
||||
# Get basic file info
|
||||
metadata = await self._get_file_info(file_path)
|
||||
if not metadata:
|
||||
return None
|
||||
|
||||
folder = self._calculate_folder(file_path)
|
||||
|
||||
# Ensure folder field exists
|
||||
metadata_dict = metadata.to_dict()
|
||||
metadata_dict['folder'] = folder or ''
|
||||
|
||||
return metadata_dict
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {file_path}: {e}")
|
||||
return None
|
||||
|
||||
async def _get_file_info(self, file_path: str) -> Optional[BaseModelMetadata]:
|
||||
async def _create_default_metadata(self, file_path: str) -> Optional[BaseModelMetadata]:
|
||||
"""Get model file info and metadata (extensible for different model types)"""
|
||||
return await get_file_info(file_path, self.model_class)
|
||||
return await MetadataManager.create_default_metadata(file_path, self.model_class)
|
||||
|
||||
def _calculate_folder(self, file_path: str) -> str:
|
||||
"""Calculate the folder path for a model file"""
|
||||
@@ -770,10 +620,18 @@ class ModelScanner:
|
||||
return os.path.dirname(rel_path).replace(os.path.sep, '/')
|
||||
return ''
|
||||
|
||||
# Common methods shared between scanners
|
||||
def adjust_metadata(self, metadata, file_path, root_path):
|
||||
"""Hook for subclasses: adjust metadata during scanning"""
|
||||
return metadata
|
||||
|
||||
async def _process_model_file(self, file_path: str, root_path: str) -> Dict:
|
||||
"""Process a single model file and return its metadata"""
|
||||
metadata = await load_metadata(file_path, self.model_class)
|
||||
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.model_class)
|
||||
|
||||
if should_skip:
|
||||
# Metadata file exists but cannot be parsed - skip this model
|
||||
logger.warning(f"Skipping model {file_path} due to corrupted metadata file")
|
||||
return None
|
||||
|
||||
if metadata is None:
|
||||
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
|
||||
@@ -789,7 +647,7 @@ class ModelScanner:
|
||||
|
||||
metadata = self.model_class.from_civitai_info(version_info, file_info, file_path)
|
||||
metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path))
|
||||
await save_metadata(file_path, metadata)
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
logger.debug(f"Created metadata from .civitai.info for {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
|
||||
@@ -816,13 +674,16 @@ class ModelScanner:
|
||||
metadata.modelDescription = version_info['model']['description']
|
||||
|
||||
# Save the updated metadata
|
||||
await save_metadata(file_path, metadata)
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
logger.debug(f"Updated metadata with civitai info for {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error restoring civitai data from .civitai.info for {file_path}: {e}")
|
||||
|
||||
if metadata is None:
|
||||
metadata = await self._get_file_info(file_path)
|
||||
metadata = await self._create_default_metadata(file_path)
|
||||
|
||||
# Hook: allow subclasses to adjust metadata
|
||||
metadata = self.adjust_metadata(metadata, file_path, root_path)
|
||||
|
||||
model_data = metadata.to_dict()
|
||||
|
||||
@@ -830,113 +691,23 @@ class ModelScanner:
|
||||
if model_data.get('exclude', False):
|
||||
self._excluded_models.append(model_data['file_path'])
|
||||
return None
|
||||
|
||||
await self._fetch_missing_metadata(file_path, model_data)
|
||||
|
||||
# Check for duplicate filename before adding to hash index
|
||||
filename = os.path.splitext(os.path.basename(file_path))[0]
|
||||
existing_hash = self._hash_index.get_hash_by_filename(filename)
|
||||
if existing_hash and existing_hash != model_data.get('sha256', '').lower():
|
||||
existing_path = self._hash_index.get_path(existing_hash)
|
||||
if existing_path and existing_path != file_path:
|
||||
logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'")
|
||||
|
||||
rel_path = os.path.relpath(file_path, root_path)
|
||||
folder = os.path.dirname(rel_path)
|
||||
model_data['folder'] = folder.replace(os.path.sep, '/')
|
||||
|
||||
return model_data
|
||||
|
||||
async def _fetch_missing_metadata(self, file_path: str, model_data: Dict) -> None:
|
||||
"""Fetch missing description and tags from Civitai if needed"""
|
||||
try:
|
||||
if model_data.get('civitai_deleted', False):
|
||||
logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
|
||||
return
|
||||
|
||||
needs_metadata_update = False
|
||||
model_id = None
|
||||
|
||||
if model_data.get('civitai'):
|
||||
model_id = model_data['civitai'].get('modelId')
|
||||
|
||||
if model_id:
|
||||
model_id = str(model_id)
|
||||
tags_missing = not model_data.get('tags') or len(model_data.get('tags', [])) == 0
|
||||
desc_missing = not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")
|
||||
# TODO: not for now, but later we should check if the creator is missing
|
||||
# creator_missing = not model_data.get('civitai', {}).get('creator')
|
||||
creator_missing = False
|
||||
needs_metadata_update = tags_missing or desc_missing or creator_missing
|
||||
|
||||
if needs_metadata_update and model_id:
|
||||
logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
|
||||
from ..services.civitai_client import CivitaiClient
|
||||
client = CivitaiClient()
|
||||
|
||||
model_metadata, status_code = await client.get_model_metadata(model_id)
|
||||
await client.close()
|
||||
|
||||
if status_code == 404:
|
||||
logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
|
||||
model_data['civitai_deleted'] = True
|
||||
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(model_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
elif model_metadata:
|
||||
logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")
|
||||
|
||||
if model_metadata.get('tags') and (not model_data.get('tags') or len(model_data.get('tags', [])) == 0):
|
||||
model_data['tags'] = model_metadata['tags']
|
||||
|
||||
if model_metadata.get('description') and (not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")):
|
||||
model_data['modelDescription'] = model_metadata['description']
|
||||
|
||||
model_data['civitai']['creator'] = model_metadata['creator']
|
||||
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(model_data, f, indent=2, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Base implementation for directory scanning"""
|
||||
models = []
|
||||
original_root = root_path
|
||||
|
||||
async def scan_recursive(path: str, visited_paths: set):
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
logger.debug(f"Skipping already visited path: {path}")
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True):
|
||||
ext = os.path.splitext(entry.name)[1].lower()
|
||||
if ext in self.file_extensions:
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
await self._process_single_file(file_path, original_root, models)
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
await scan_recursive(entry.path, visited_paths)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
await scan_recursive(root_path, set())
|
||||
return models
|
||||
|
||||
async def _process_single_file(self, file_path: str, root_path: str, models_list: list):
|
||||
"""Process a single file and add to results list"""
|
||||
try:
|
||||
result = await self._process_model_file(file_path, root_path)
|
||||
if result:
|
||||
models_list.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
|
||||
async def add_model_to_cache(self, metadata_dict: Dict, folder: str = '') -> bool:
|
||||
"""Add a model to the cache and save to disk
|
||||
"""Add a model to the cache
|
||||
|
||||
Args:
|
||||
metadata_dict: The model metadata dictionary
|
||||
@@ -965,16 +736,21 @@ class ModelScanner:
|
||||
|
||||
# Update the hash index
|
||||
self._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
||||
|
||||
# Save to disk
|
||||
await self._save_cache_to_disk()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding model to cache: {e}")
|
||||
return False
|
||||
|
||||
async def move_model(self, source_path: str, target_path: str) -> bool:
|
||||
"""Move a model and its associated files to a new location"""
|
||||
async def move_model(self, source_path: str, target_path: str) -> Optional[str]:
|
||||
"""Move a model and its associated files to a new location
|
||||
|
||||
Args:
|
||||
source_path: Original file path
|
||||
target_path: Target directory path
|
||||
|
||||
Returns:
|
||||
Optional[str]: New file path if successful, None if failed
|
||||
"""
|
||||
try:
|
||||
source_path = source_path.replace(os.sep, '/')
|
||||
target_path = target_path.replace(os.sep, '/')
|
||||
@@ -983,38 +759,32 @@ class ModelScanner:
|
||||
|
||||
if not file_ext or file_ext.lower() not in self.file_extensions:
|
||||
logger.error(f"Invalid file extension for model: {file_ext}")
|
||||
return False
|
||||
return None
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(source_path))[0]
|
||||
source_dir = os.path.dirname(source_path)
|
||||
|
||||
os.makedirs(target_path, exist_ok=True)
|
||||
|
||||
target_file = os.path.join(target_path, f"{base_name}{file_ext}").replace(os.sep, '/')
|
||||
def get_source_hash():
|
||||
return self.get_hash_by_path(source_path)
|
||||
|
||||
# Check for filename conflicts and auto-rename if necessary
|
||||
from ..utils.models import BaseModelMetadata
|
||||
final_filename = BaseModelMetadata.generate_unique_filename(
|
||||
target_path, base_name, file_ext, get_source_hash
|
||||
)
|
||||
|
||||
target_file = os.path.join(target_path, final_filename).replace(os.sep, '/')
|
||||
final_base_name = os.path.splitext(final_filename)[0]
|
||||
|
||||
# Log if filename was changed due to conflict
|
||||
if final_filename != f"{base_name}{file_ext}":
|
||||
logger.info(f"Renamed {base_name}{file_ext} to {final_filename} to avoid filename conflict")
|
||||
|
||||
real_source = os.path.realpath(source_path)
|
||||
real_target = os.path.realpath(target_file)
|
||||
|
||||
file_size = os.path.getsize(real_source)
|
||||
|
||||
# Get the appropriate file monitor through ServiceRegistry
|
||||
if self.model_type == "lora":
|
||||
monitor = await ServiceRegistry.get_lora_monitor()
|
||||
elif self.model_type == "checkpoint":
|
||||
monitor = await ServiceRegistry.get_checkpoint_monitor()
|
||||
else:
|
||||
monitor = None
|
||||
|
||||
if monitor:
|
||||
monitor.handler.add_ignore_path(
|
||||
real_source,
|
||||
file_size
|
||||
)
|
||||
monitor.handler.add_ignore_path(
|
||||
real_target,
|
||||
file_size
|
||||
)
|
||||
|
||||
shutil.move(real_source, real_target)
|
||||
|
||||
# Move all associated files with the same base name
|
||||
@@ -1027,12 +797,17 @@ class ModelScanner:
|
||||
for file in os.listdir(source_dir):
|
||||
if file.startswith(base_name + ".") and file != os.path.basename(source_path):
|
||||
source_file_path = os.path.join(source_dir, file)
|
||||
# Generate new filename with the same base name as the model file
|
||||
file_suffix = file[len(base_name):] # Get the part after base_name (e.g., ".metadata.json", ".preview.png")
|
||||
new_associated_filename = f"{final_base_name}{file_suffix}"
|
||||
target_associated_path = os.path.join(target_path, new_associated_filename)
|
||||
|
||||
# Store metadata file path for special handling
|
||||
if file == f"{base_name}.metadata.json":
|
||||
source_metadata = source_file_path
|
||||
moved_metadata_path = os.path.join(target_path, file)
|
||||
moved_metadata_path = target_associated_path
|
||||
else:
|
||||
files_to_move.append((source_file_path, os.path.join(target_path, file)))
|
||||
files_to_move.append((source_file_path, target_associated_path))
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing files in {source_dir}: {e}")
|
||||
|
||||
@@ -1054,11 +829,11 @@ class ModelScanner:
|
||||
|
||||
await self.update_single_model_cache(source_path, target_file, metadata)
|
||||
|
||||
return True
|
||||
return target_file
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving model: {e}", exc_info=True)
|
||||
return False
|
||||
return None
|
||||
|
||||
async def _update_metadata_paths(self, metadata_path: str, model_path: str) -> Dict:
|
||||
"""Update file paths in metadata file"""
|
||||
@@ -1067,16 +842,18 @@ class ModelScanner:
|
||||
metadata = json.load(f)
|
||||
|
||||
metadata['file_path'] = model_path.replace(os.sep, '/')
|
||||
# Update file_name to match the new filename
|
||||
metadata['file_name'] = os.path.splitext(os.path.basename(model_path))[0]
|
||||
|
||||
if 'preview_url' in metadata:
|
||||
if 'preview_url' in metadata and metadata['preview_url']:
|
||||
preview_dir = os.path.dirname(model_path)
|
||||
preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
|
||||
preview_ext = os.path.splitext(metadata['preview_url'])[1]
|
||||
new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
|
||||
# Update preview filename to match the new base name
|
||||
new_base_name = os.path.splitext(os.path.basename(model_path))[0]
|
||||
preview_ext = get_preview_extension(metadata['preview_url'])
|
||||
new_preview_path = os.path.join(preview_dir, f"{new_base_name}{preview_ext}")
|
||||
metadata['preview_url'] = new_preview_path.replace(os.sep, '/')
|
||||
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
await MetadataManager.save_metadata(metadata_path, metadata)
|
||||
|
||||
return metadata
|
||||
|
||||
@@ -1128,9 +905,6 @@ class ModelScanner:
|
||||
|
||||
await cache.resort()
|
||||
|
||||
# Save the updated cache
|
||||
await self._save_cache_to_disk()
|
||||
|
||||
return True
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
@@ -1143,8 +917,16 @@ class ModelScanner:
|
||||
|
||||
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a model by its file path"""
|
||||
return self._hash_index.get_hash(file_path)
|
||||
if self._cache is None or not self._cache.raw_data:
|
||||
return None
|
||||
|
||||
# Iterate through cache data to find matching file path
|
||||
for model_data in self._cache.raw_data:
|
||||
if model_data.get('file_path') == file_path:
|
||||
return model_data.get('sha256')
|
||||
|
||||
return None
|
||||
|
||||
def get_hash_by_filename(self, filename: str) -> Optional[str]:
|
||||
"""Get hash for a model by its filename without path"""
|
||||
return self._hash_index.get_hash_by_filename(filename)
|
||||
@@ -1210,12 +992,13 @@ class ModelScanner:
|
||||
"""Get list of excluded model file paths"""
|
||||
return self._excluded_models.copy()
|
||||
|
||||
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
|
||||
async def update_preview_in_cache(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool:
|
||||
"""Update preview URL in cache for a specific lora
|
||||
|
||||
Args:
|
||||
file_path: The file path of the lora to update
|
||||
preview_url: The new preview URL
|
||||
preview_nsfw_level: The NSFW level of the preview
|
||||
|
||||
Returns:
|
||||
bool: True if the update was successful, False if cache doesn't exist or lora wasn't found
|
||||
@@ -1223,11 +1006,7 @@ class ModelScanner:
|
||||
if self._cache is None:
|
||||
return False
|
||||
|
||||
updated = await self._cache.update_preview_url(file_path, preview_url)
|
||||
if updated:
|
||||
# Save updated cache to disk
|
||||
await self._save_cache_to_disk()
|
||||
return updated
|
||||
return await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
|
||||
|
||||
async def bulk_delete_models(self, file_paths: List[str]) -> Dict:
|
||||
"""Delete multiple models and update cache in a batch operation
|
||||
@@ -1246,9 +1025,6 @@ class ModelScanner:
|
||||
'results': []
|
||||
}
|
||||
|
||||
# Get the file monitor
|
||||
file_monitor = getattr(self, 'file_monitor', None)
|
||||
|
||||
# Keep track of success and failures
|
||||
results = []
|
||||
total_deleted = 0
|
||||
@@ -1265,12 +1041,9 @@ class ModelScanner:
|
||||
target_dir = os.path.dirname(file_path)
|
||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
# Delete all associated files for the model
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
deleted_files = await ModelRouteUtils.delete_model_files(
|
||||
target_dir,
|
||||
file_name,
|
||||
file_monitor
|
||||
deleted_files = await delete_model_artifacts(
|
||||
target_dir,
|
||||
file_name
|
||||
)
|
||||
|
||||
if deleted_files:
|
||||
@@ -1352,7 +1125,7 @@ class ModelScanner:
|
||||
hash_val = model.get('sha256', '').lower()
|
||||
|
||||
# Remove from hash index
|
||||
self._hash_index.remove_by_path(file_path)
|
||||
self._hash_index.remove_by_path(file_path, hash_val)
|
||||
|
||||
# Check and clean up duplicates
|
||||
self._cleanup_duplicates_after_removal(hash_val, file_name)
|
||||
@@ -1363,9 +1136,6 @@ class ModelScanner:
|
||||
# Resort cache
|
||||
await self._cache.resort()
|
||||
|
||||
# Save updated cache to disk
|
||||
await self._save_cache_to_disk()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -1391,3 +1161,56 @@ class ModelScanner:
|
||||
if file_name in self._hash_index._duplicate_filenames:
|
||||
if len(self._hash_index._duplicate_filenames[file_name]) <= 1:
|
||||
del self._hash_index._duplicate_filenames[file_name]
|
||||
|
||||
async def check_model_version_exists(self, model_version_id: int) -> bool:
|
||||
"""Check if a specific model version exists in the cache
|
||||
|
||||
Args:
|
||||
model_version_id: Civitai model version ID
|
||||
|
||||
Returns:
|
||||
bool: True if the model version exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
cache = await self.get_cached_data()
|
||||
if not cache or not cache.raw_data:
|
||||
return False
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('civitai') and item['civitai'].get('id') == model_version_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking model version existence: {e}")
|
||||
return False
|
||||
|
||||
async def get_model_versions_by_id(self, model_id: int) -> List[Dict]:
|
||||
"""Get all versions of a model by its ID
|
||||
|
||||
Args:
|
||||
model_id: Civitai model ID
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of version information dictionaries
|
||||
"""
|
||||
try:
|
||||
cache = await self.get_cached_data()
|
||||
if not cache or not cache.raw_data:
|
||||
return []
|
||||
|
||||
versions = []
|
||||
for item in cache.raw_data:
|
||||
if (item.get('civitai') and
|
||||
item['civitai'].get('modelId') == model_id and
|
||||
item['civitai'].get('id')):
|
||||
versions.append({
|
||||
'versionId': item['civitai'].get('id'),
|
||||
'name': item['civitai'].get('name'),
|
||||
'fileName': item.get('file_name', '')
|
||||
})
|
||||
|
||||
return versions
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model versions: {e}")
|
||||
return []
|
||||
|
||||
142
py/services/model_service_factory.py
Normal file
142
py/services/model_service_factory.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from typing import Dict, Type, Any
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelServiceFactory:
|
||||
"""Factory for managing model services and routes"""
|
||||
|
||||
_services: Dict[str, Type] = {}
|
||||
_routes: Dict[str, Type] = {}
|
||||
_initialized_services: Dict[str, Any] = {}
|
||||
_initialized_routes: Dict[str, Any] = {}
|
||||
|
||||
@classmethod
|
||||
def register_model_type(cls, model_type: str, service_class: Type, route_class: Type):
|
||||
"""Register a new model type with its service and route classes
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier (e.g., 'lora', 'checkpoint')
|
||||
service_class: The service class for this model type
|
||||
route_class: The route class for this model type
|
||||
"""
|
||||
cls._services[model_type] = service_class
|
||||
cls._routes[model_type] = route_class
|
||||
logger.info(f"Registered model type '{model_type}' with service {service_class.__name__} and routes {route_class.__name__}")
|
||||
|
||||
@classmethod
|
||||
def get_service_class(cls, model_type: str) -> Type:
|
||||
"""Get service class for a model type
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
The service class for the model type
|
||||
|
||||
Raises:
|
||||
ValueError: If model type is not registered
|
||||
"""
|
||||
if model_type not in cls._services:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
return cls._services[model_type]
|
||||
|
||||
@classmethod
|
||||
def get_route_class(cls, model_type: str) -> Type:
|
||||
"""Get route class for a model type
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
The route class for the model type
|
||||
|
||||
Raises:
|
||||
ValueError: If model type is not registered
|
||||
"""
|
||||
if model_type not in cls._routes:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
return cls._routes[model_type]
|
||||
|
||||
@classmethod
|
||||
def get_route_instance(cls, model_type: str):
|
||||
"""Get or create route instance for a model type
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
The route instance for the model type
|
||||
"""
|
||||
if model_type not in cls._initialized_routes:
|
||||
route_class = cls.get_route_class(model_type)
|
||||
cls._initialized_routes[model_type] = route_class()
|
||||
return cls._initialized_routes[model_type]
|
||||
|
||||
@classmethod
|
||||
def setup_all_routes(cls, app):
|
||||
"""Setup routes for all registered model types
|
||||
|
||||
Args:
|
||||
app: The aiohttp application instance
|
||||
"""
|
||||
logger.info(f"Setting up routes for {len(cls._services)} registered model types")
|
||||
|
||||
for model_type in cls._services.keys():
|
||||
try:
|
||||
routes_instance = cls.get_route_instance(model_type)
|
||||
routes_instance.setup_routes(app)
|
||||
logger.info(f"Successfully set up routes for {model_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup routes for {model_type}: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
def get_registered_types(cls) -> list:
|
||||
"""Get list of all registered model types
|
||||
|
||||
Returns:
|
||||
List of registered model type identifiers
|
||||
"""
|
||||
return list(cls._services.keys())
|
||||
|
||||
@classmethod
|
||||
def is_registered(cls, model_type: str) -> bool:
|
||||
"""Check if a model type is registered
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
True if the model type is registered, False otherwise
|
||||
"""
|
||||
return model_type in cls._services
|
||||
|
||||
@classmethod
|
||||
def clear_registrations(cls):
|
||||
"""Clear all registrations - mainly for testing purposes"""
|
||||
cls._services.clear()
|
||||
cls._routes.clear()
|
||||
cls._initialized_services.clear()
|
||||
cls._initialized_routes.clear()
|
||||
logger.info("Cleared all model type registrations")
|
||||
|
||||
|
||||
def register_default_model_types():
|
||||
"""Register the default model types (LoRA, Checkpoint, and Embedding)"""
|
||||
from ..services.lora_service import LoraService
|
||||
from ..services.checkpoint_service import CheckpointService
|
||||
from ..services.embedding_service import EmbeddingService
|
||||
from ..routes.lora_routes import LoraRoutes
|
||||
from ..routes.checkpoint_routes import CheckpointRoutes
|
||||
from ..routes.embedding_routes import EmbeddingRoutes
|
||||
|
||||
# Register LoRA model type
|
||||
ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes)
|
||||
|
||||
# Register Checkpoint model type
|
||||
ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes)
|
||||
|
||||
# Register Embedding model type
|
||||
ModelServiceFactory.register_model_type('embedding', EmbeddingService, EmbeddingRoutes)
|
||||
|
||||
logger.info("Registered default model types: lora, checkpoint, embedding")
|
||||
168
py/services/preview_asset_service.py
Normal file
168
py/services/preview_asset_service.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Service for processing preview assets for models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Awaitable, Callable, Dict, Optional, Sequence
|
||||
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreviewAssetService:
|
||||
"""Manage fetching and persisting preview assets."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
metadata_manager,
|
||||
downloader_factory: Callable[[], Awaitable],
|
||||
exif_utils,
|
||||
) -> None:
|
||||
self._metadata_manager = metadata_manager
|
||||
self._downloader_factory = downloader_factory
|
||||
self._exif_utils = exif_utils
|
||||
|
||||
async def ensure_preview_for_metadata(
|
||||
self,
|
||||
metadata_path: str,
|
||||
local_metadata: Dict[str, object],
|
||||
images: Sequence[Dict[str, object]] | None,
|
||||
) -> None:
|
||||
"""Ensure preview assets exist for the supplied metadata entry."""
|
||||
|
||||
if local_metadata.get("preview_url") and os.path.exists(
|
||||
str(local_metadata["preview_url"])
|
||||
):
|
||||
return
|
||||
|
||||
if not images:
|
||||
return
|
||||
|
||||
first_preview = images[0]
|
||||
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
|
||||
preview_dir = os.path.dirname(metadata_path)
|
||||
is_video = first_preview.get("type") == "video"
|
||||
|
||||
if is_video:
|
||||
extension = ".mp4"
|
||||
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||
downloader = await self._downloader_factory()
|
||||
success, result = await downloader.download_file(
|
||||
first_preview["url"], preview_path, use_auth=False
|
||||
)
|
||||
if success:
|
||||
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||
else:
|
||||
extension = ".webp"
|
||||
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||
downloader = await self._downloader_factory()
|
||||
success, content, _headers = await downloader.download_to_memory(
|
||||
first_preview["url"], use_auth=False
|
||||
)
|
||||
if not success:
|
||||
return
|
||||
|
||||
try:
|
||||
optimized_data, _ = self._exif_utils.optimize_image(
|
||||
image_data=content,
|
||||
target_width=CARD_PREVIEW_WIDTH,
|
||||
format="webp",
|
||||
quality=85,
|
||||
preserve_metadata=False,
|
||||
)
|
||||
with open(preview_path, "wb") as handle:
|
||||
handle.write(optimized_data)
|
||||
except Exception as exc: # pragma: no cover - defensive path
|
||||
logger.error("Error optimizing preview image: %s", exc)
|
||||
try:
|
||||
with open(preview_path, "wb") as handle:
|
||||
handle.write(content)
|
||||
except Exception as save_exc:
|
||||
logger.error("Error saving preview image: %s", save_exc)
|
||||
return
|
||||
|
||||
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||
|
||||
async def replace_preview(
|
||||
self,
|
||||
*,
|
||||
model_path: str,
|
||||
preview_data: bytes,
|
||||
content_type: str,
|
||||
original_filename: Optional[str],
|
||||
nsfw_level: int,
|
||||
update_preview_in_cache: Callable[[str, str, int], Awaitable[bool]],
|
||||
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
||||
) -> Dict[str, object]:
|
||||
"""Replace an existing preview asset for a model."""
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(model_path))[0]
|
||||
folder = os.path.dirname(model_path)
|
||||
|
||||
extension, optimized_data = await self._convert_preview(
|
||||
preview_data, content_type, original_filename
|
||||
)
|
||||
|
||||
for ext in PREVIEW_EXTENSIONS:
|
||||
existing_preview = os.path.join(folder, base_name + ext)
|
||||
if os.path.exists(existing_preview):
|
||||
try:
|
||||
os.remove(existing_preview)
|
||||
except Exception as exc: # pragma: no cover - defensive path
|
||||
logger.warning(
|
||||
"Failed to delete existing preview %s: %s", existing_preview, exc
|
||||
)
|
||||
|
||||
preview_path = os.path.join(folder, base_name + extension).replace(os.sep, "/")
|
||||
with open(preview_path, "wb") as handle:
|
||||
handle.write(optimized_data)
|
||||
|
||||
metadata_path = os.path.splitext(model_path)[0] + ".metadata.json"
|
||||
metadata = await metadata_loader(metadata_path)
|
||||
metadata["preview_url"] = preview_path
|
||||
metadata["preview_nsfw_level"] = nsfw_level
|
||||
await self._metadata_manager.save_metadata(model_path, metadata)
|
||||
|
||||
await update_preview_in_cache(model_path, preview_path, nsfw_level)
|
||||
|
||||
return {"preview_path": preview_path, "preview_nsfw_level": nsfw_level}
|
||||
|
||||
async def _convert_preview(
|
||||
self, data: bytes, content_type: str, original_filename: Optional[str]
|
||||
) -> tuple[str, bytes]:
|
||||
"""Convert preview bytes to the persisted representation."""
|
||||
|
||||
if content_type.startswith("video/"):
|
||||
extension = self._resolve_video_extension(content_type, original_filename)
|
||||
return extension, data
|
||||
|
||||
original_ext = (original_filename or "").lower()
|
||||
if original_ext.endswith(".gif") or content_type.lower() == "image/gif":
|
||||
return ".gif", data
|
||||
|
||||
optimized_data, _ = self._exif_utils.optimize_image(
|
||||
image_data=data,
|
||||
target_width=CARD_PREVIEW_WIDTH,
|
||||
format="webp",
|
||||
quality=85,
|
||||
preserve_metadata=False,
|
||||
)
|
||||
return ".webp", optimized_data
|
||||
|
||||
def _resolve_video_extension(self, content_type: str, original_filename: Optional[str]) -> str:
|
||||
"""Infer the best extension for a video preview."""
|
||||
|
||||
if original_filename:
|
||||
extension = os.path.splitext(original_filename)[1].lower()
|
||||
if extension in {".mp4", ".webm", ".mov", ".avi"}:
|
||||
return extension
|
||||
|
||||
if "webm" in content_type:
|
||||
return ".webm"
|
||||
return ".mp4"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from typing import List, Dict
|
||||
from typing import Iterable, List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
from operator import itemgetter
|
||||
from natsort import natsorted
|
||||
@@ -10,77 +10,115 @@ class RecipeCache:
|
||||
raw_data: List[Dict]
|
||||
sorted_by_name: List[Dict]
|
||||
sorted_by_date: List[Dict]
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def resort(self, name_only: bool = False):
|
||||
"""Resort all cached data views"""
|
||||
async with self._lock:
|
||||
self.sorted_by_name = natsorted(
|
||||
self.raw_data,
|
||||
key=lambda x: x.get('title', '').lower() # Case-insensitive sort
|
||||
)
|
||||
if not name_only:
|
||||
self.sorted_by_date = sorted(
|
||||
self.raw_data,
|
||||
key=itemgetter('created_date', 'file_path'),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
async def update_recipe_metadata(self, recipe_id: str, metadata: Dict) -> bool:
|
||||
self._resort_locked(name_only=name_only)
|
||||
|
||||
async def update_recipe_metadata(self, recipe_id: str, metadata: Dict, *, resort: bool = True) -> bool:
|
||||
"""Update metadata for a specific recipe in all cached data
|
||||
|
||||
|
||||
Args:
|
||||
recipe_id: The ID of the recipe to update
|
||||
metadata: The new metadata
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if the update was successful, False if the recipe wasn't found
|
||||
"""
|
||||
async with self._lock:
|
||||
for item in self.raw_data:
|
||||
if str(item.get('id')) == str(recipe_id):
|
||||
item.update(metadata)
|
||||
if resort:
|
||||
self._resort_locked()
|
||||
return True
|
||||
return False # Recipe not found
|
||||
|
||||
async def add_recipe(self, recipe_data: Dict, *, resort: bool = False) -> None:
|
||||
"""Add a new recipe to the cache."""
|
||||
|
||||
# Update in raw_data
|
||||
for item in self.raw_data:
|
||||
if item.get('id') == recipe_id:
|
||||
item.update(metadata)
|
||||
break
|
||||
else:
|
||||
return False # Recipe not found
|
||||
|
||||
# Resort to reflect changes
|
||||
await self.resort()
|
||||
return True
|
||||
|
||||
async def add_recipe(self, recipe_data: Dict) -> None:
|
||||
"""Add a new recipe to the cache
|
||||
|
||||
Args:
|
||||
recipe_data: The recipe data to add
|
||||
"""
|
||||
async with self._lock:
|
||||
self.raw_data.append(recipe_data)
|
||||
await self.resort()
|
||||
if resort:
|
||||
self._resort_locked()
|
||||
|
||||
async def remove_recipe(self, recipe_id: str, *, resort: bool = False) -> Optional[Dict]:
|
||||
"""Remove a recipe from the cache by ID.
|
||||
|
||||
async def remove_recipe(self, recipe_id: str) -> bool:
|
||||
"""Remove a recipe from the cache by ID
|
||||
|
||||
Args:
|
||||
recipe_id: The ID of the recipe to remove
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if the recipe was found and removed, False otherwise
|
||||
The removed recipe data if found, otherwise ``None``.
|
||||
"""
|
||||
# Find the recipe in raw_data
|
||||
recipe_index = next((i for i, recipe in enumerate(self.raw_data)
|
||||
if recipe.get('id') == recipe_id), None)
|
||||
|
||||
if recipe_index is None:
|
||||
return False
|
||||
|
||||
# Remove from raw_data
|
||||
self.raw_data.pop(recipe_index)
|
||||
|
||||
# Resort to update sorted lists
|
||||
await self.resort()
|
||||
|
||||
return True
|
||||
|
||||
async with self._lock:
|
||||
for index, recipe in enumerate(self.raw_data):
|
||||
if str(recipe.get('id')) == str(recipe_id):
|
||||
removed = self.raw_data.pop(index)
|
||||
if resort:
|
||||
self._resort_locked()
|
||||
return removed
|
||||
return None
|
||||
|
||||
async def bulk_remove(self, recipe_ids: Iterable[str], *, resort: bool = False) -> List[Dict]:
|
||||
"""Remove multiple recipes from the cache."""
|
||||
|
||||
id_set = {str(recipe_id) for recipe_id in recipe_ids}
|
||||
if not id_set:
|
||||
return []
|
||||
|
||||
async with self._lock:
|
||||
removed = [item for item in self.raw_data if str(item.get('id')) in id_set]
|
||||
if not removed:
|
||||
return []
|
||||
|
||||
self.raw_data = [item for item in self.raw_data if str(item.get('id')) not in id_set]
|
||||
if resort:
|
||||
self._resort_locked()
|
||||
return removed
|
||||
|
||||
async def replace_recipe(self, recipe_id: str, new_data: Dict, *, resort: bool = False) -> bool:
|
||||
"""Replace cached data for a recipe."""
|
||||
|
||||
async with self._lock:
|
||||
for index, recipe in enumerate(self.raw_data):
|
||||
if str(recipe.get('id')) == str(recipe_id):
|
||||
self.raw_data[index] = new_data
|
||||
if resort:
|
||||
self._resort_locked()
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_recipe(self, recipe_id: str) -> Optional[Dict]:
|
||||
"""Return a shallow copy of a cached recipe."""
|
||||
|
||||
async with self._lock:
|
||||
for recipe in self.raw_data:
|
||||
if str(recipe.get('id')) == str(recipe_id):
|
||||
return dict(recipe)
|
||||
return None
|
||||
|
||||
async def snapshot(self) -> List[Dict]:
|
||||
"""Return a copy of all cached recipes."""
|
||||
|
||||
async with self._lock:
|
||||
return [dict(item) for item in self.raw_data]
|
||||
|
||||
def _resort_locked(self, *, name_only: bool = False) -> None:
|
||||
"""Sort cached views. Caller must hold ``_lock``."""
|
||||
|
||||
self.sorted_by_name = natsorted(
|
||||
self.raw_data,
|
||||
key=lambda x: x.get('title', '').lower()
|
||||
)
|
||||
if not name_only:
|
||||
self.sorted_by_date = sorted(
|
||||
self.raw_data,
|
||||
key=itemgetter('created_date', 'file_path'),
|
||||
reverse=True
|
||||
)
|
||||
@@ -3,12 +3,14 @@ import logging
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
||||
from ..config import config
|
||||
from .recipe_cache import RecipeCache
|
||||
from .service_registry import ServiceRegistry
|
||||
from .lora_scanner import LoraScanner
|
||||
from ..utils.utils import fuzzy_match
|
||||
from .metadata_service import get_default_metadata_provider
|
||||
from .recipes.errors import RecipeNotFoundError
|
||||
from ..utils.utils import calculate_recipe_fingerprint, fuzzy_match
|
||||
from natsort import natsorted
|
||||
import sys
|
||||
|
||||
@@ -45,6 +47,8 @@ class RecipeScanner:
|
||||
self._initialization_lock = asyncio.Lock()
|
||||
self._initialization_task: Optional[asyncio.Task] = None
|
||||
self._is_initializing = False
|
||||
self._mutation_lock = asyncio.Lock()
|
||||
self._resort_tasks: Set[asyncio.Task] = set()
|
||||
if lora_scanner:
|
||||
self._lora_scanner = lora_scanner
|
||||
self._initialized = True
|
||||
@@ -190,6 +194,22 @@ class RecipeScanner:
|
||||
# Clean up the event loop
|
||||
loop.close()
|
||||
|
||||
def _schedule_resort(self, *, name_only: bool = False) -> None:
|
||||
"""Schedule a background resort of the recipe cache."""
|
||||
|
||||
if not self._cache:
|
||||
return
|
||||
|
||||
async def _resort_wrapper() -> None:
|
||||
try:
|
||||
await self._cache.resort(name_only=name_only)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Recipe Scanner: error resorting cache: %s", exc, exc_info=True)
|
||||
|
||||
task = asyncio.create_task(_resort_wrapper())
|
||||
self._resort_tasks.add(task)
|
||||
task.add_done_callback(lambda finished: self._resort_tasks.discard(finished))
|
||||
|
||||
@property
|
||||
def recipes_dir(self) -> str:
|
||||
"""Get path to recipes directory"""
|
||||
@@ -254,7 +274,45 @@ class RecipeScanner:
|
||||
|
||||
# Return the cache (may be empty or partially initialized)
|
||||
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
|
||||
|
||||
|
||||
async def refresh_cache(self, force: bool = False) -> RecipeCache:
|
||||
"""Public helper to refresh or return the recipe cache."""
|
||||
|
||||
return await self.get_cached_data(force_refresh=force)
|
||||
|
||||
async def add_recipe(self, recipe_data: Dict[str, Any]) -> None:
|
||||
"""Add a recipe to the in-memory cache."""
|
||||
|
||||
if not recipe_data:
|
||||
return
|
||||
|
||||
cache = await self.get_cached_data()
|
||||
await cache.add_recipe(recipe_data, resort=False)
|
||||
self._schedule_resort()
|
||||
|
||||
async def remove_recipe(self, recipe_id: str) -> bool:
|
||||
"""Remove a recipe from the cache by ID."""
|
||||
|
||||
if not recipe_id:
|
||||
return False
|
||||
|
||||
cache = await self.get_cached_data()
|
||||
removed = await cache.remove_recipe(recipe_id, resort=False)
|
||||
if removed is None:
|
||||
return False
|
||||
|
||||
self._schedule_resort()
|
||||
return True
|
||||
|
||||
async def bulk_remove(self, recipe_ids: Iterable[str]) -> int:
|
||||
"""Remove multiple recipes from the cache."""
|
||||
|
||||
cache = await self.get_cached_data()
|
||||
removed = await cache.bulk_remove(recipe_ids, resort=False)
|
||||
if removed:
|
||||
self._schedule_resort()
|
||||
return len(removed)
|
||||
|
||||
async def scan_all_recipes(self) -> List[Dict]:
|
||||
"""Scan all recipe JSON files and return metadata"""
|
||||
recipes = []
|
||||
@@ -325,7 +383,6 @@ class RecipeScanner:
|
||||
|
||||
# Calculate and update fingerprint if missing
|
||||
if 'loras' in recipe_data and 'fingerprint' not in recipe_data:
|
||||
from ..utils.utils import calculate_recipe_fingerprint
|
||||
fingerprint = calculate_recipe_fingerprint(recipe_data['loras'])
|
||||
recipe_data['fingerprint'] = fingerprint
|
||||
|
||||
@@ -393,8 +450,8 @@ class RecipeScanner:
|
||||
if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']):
|
||||
hash_value = lora['hash']
|
||||
|
||||
if self._lora_scanner.has_lora_hash(hash_value):
|
||||
lora_path = self._lora_scanner.get_lora_path_by_hash(hash_value)
|
||||
if self._lora_scanner.has_hash(hash_value):
|
||||
lora_path = self._lora_scanner.get_path_by_hash(hash_value)
|
||||
if lora_path:
|
||||
file_name = os.path.splitext(os.path.basename(lora_path))[0]
|
||||
lora['file_name'] = file_name
|
||||
@@ -431,13 +488,13 @@ class RecipeScanner:
|
||||
async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]:
|
||||
"""Get hash from Civitai API"""
|
||||
try:
|
||||
# Get CivitaiClient from ServiceRegistry
|
||||
civitai_client = await self._get_civitai_client()
|
||||
if not civitai_client:
|
||||
logger.error("Failed to get CivitaiClient from ServiceRegistry")
|
||||
# Get metadata provider instead of civitai client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
if not metadata_provider:
|
||||
logger.error("Failed to get metadata provider")
|
||||
return None
|
||||
|
||||
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
|
||||
version_info, error_msg = await metadata_provider.get_model_version_info(model_version_id)
|
||||
|
||||
if not version_info:
|
||||
if error_msg and "model not found" in error_msg.lower():
|
||||
@@ -465,7 +522,7 @@ class RecipeScanner:
|
||||
# Count occurrences of each base model
|
||||
for lora in loras:
|
||||
if 'hash' in lora:
|
||||
lora_path = self._lora_scanner.get_lora_path_by_hash(lora['hash'])
|
||||
lora_path = self._lora_scanner.get_path_by_hash(lora['hash'])
|
||||
if lora_path:
|
||||
base_model = await self._get_base_model_for_lora(lora_path)
|
||||
if base_model:
|
||||
@@ -496,9 +553,36 @@ class RecipeScanner:
|
||||
logger.error(f"Error getting base model for lora: {e}")
|
||||
return None
|
||||
|
||||
def _enrich_lora_entry(self, lora: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Populate convenience fields for a LoRA entry."""
|
||||
|
||||
if not lora or not self._lora_scanner:
|
||||
return lora
|
||||
|
||||
hash_value = (lora.get('hash') or '').lower()
|
||||
if not hash_value:
|
||||
return lora
|
||||
|
||||
try:
|
||||
lora['inLibrary'] = self._lora_scanner.has_hash(hash_value)
|
||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(hash_value)
|
||||
lora['localPath'] = self._lora_scanner.get_path_by_hash(hash_value)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.debug("Error enriching lora entry %s: %s", hash_value, exc)
|
||||
|
||||
return lora
|
||||
|
||||
async def get_local_lora(self, name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Lookup a local LoRA model by name."""
|
||||
|
||||
if not self._lora_scanner or not name:
|
||||
return None
|
||||
|
||||
return await self._lora_scanner.get_model_info_by_name(name)
|
||||
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True):
|
||||
"""Get paginated and filtered recipe data
|
||||
|
||||
|
||||
Args:
|
||||
page: Current page number (1-based)
|
||||
page_size: Number of items per page
|
||||
@@ -597,16 +681,12 @@ class RecipeScanner:
|
||||
|
||||
# Get paginated items
|
||||
paginated_items = filtered_data[start_idx:end_idx]
|
||||
|
||||
|
||||
# Add inLibrary information for each lora
|
||||
for item in paginated_items:
|
||||
if 'loras' in item:
|
||||
for lora in item['loras']:
|
||||
if 'hash' in lora and lora['hash']:
|
||||
lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora['hash'].lower())
|
||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
|
||||
lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora['hash'].lower())
|
||||
|
||||
item['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in item['loras']]
|
||||
|
||||
result = {
|
||||
'items': paginated_items,
|
||||
'total': total_items,
|
||||
@@ -652,13 +732,8 @@ class RecipeScanner:
|
||||
|
||||
# Add lora metadata
|
||||
if 'loras' in formatted_recipe:
|
||||
for lora in formatted_recipe['loras']:
|
||||
if 'hash' in lora and lora['hash']:
|
||||
lora_hash = lora['hash'].lower()
|
||||
lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora_hash)
|
||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash)
|
||||
lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora_hash)
|
||||
|
||||
formatted_recipe['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in formatted_recipe['loras']]
|
||||
|
||||
return formatted_recipe
|
||||
|
||||
def _format_file_url(self, file_path: str) -> str:
|
||||
@@ -716,26 +791,159 @@ class RecipeScanner:
|
||||
# Save updated recipe
|
||||
with open(recipe_json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(recipe_data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
# Update the cache if it exists
|
||||
if self._cache is not None:
|
||||
await self._cache.update_recipe_metadata(recipe_id, metadata)
|
||||
|
||||
await self._cache.update_recipe_metadata(recipe_id, metadata, resort=False)
|
||||
self._schedule_resort()
|
||||
|
||||
# If the recipe has an image, update its EXIF metadata
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
image_path = recipe_data.get('file_path')
|
||||
if image_path and os.path.exists(image_path):
|
||||
ExifUtils.append_recipe_metadata(image_path, recipe_data)
|
||||
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).error(f"Error updating recipe metadata: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def update_lora_entry(
|
||||
self,
|
||||
recipe_id: str,
|
||||
lora_index: int,
|
||||
*,
|
||||
target_name: str,
|
||||
target_lora: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""Update a specific LoRA entry within a recipe.
|
||||
|
||||
Returns the updated recipe data and the refreshed LoRA metadata.
|
||||
"""
|
||||
|
||||
if target_name is None:
|
||||
raise ValueError("target_name must be provided")
|
||||
|
||||
recipe_json_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json")
|
||||
if not os.path.exists(recipe_json_path):
|
||||
raise RecipeNotFoundError("Recipe not found")
|
||||
|
||||
async with self._mutation_lock:
|
||||
with open(recipe_json_path, 'r', encoding='utf-8') as file_obj:
|
||||
recipe_data = json.load(file_obj)
|
||||
|
||||
loras = recipe_data.get('loras', [])
|
||||
if lora_index >= len(loras):
|
||||
raise RecipeNotFoundError("LoRA index out of range in recipe")
|
||||
|
||||
lora_entry = loras[lora_index]
|
||||
lora_entry['isDeleted'] = False
|
||||
lora_entry['exclude'] = False
|
||||
lora_entry['file_name'] = target_name
|
||||
|
||||
if target_lora is not None:
|
||||
sha_value = target_lora.get('sha256') or target_lora.get('sha')
|
||||
if sha_value:
|
||||
lora_entry['hash'] = sha_value.lower()
|
||||
|
||||
civitai_info = target_lora.get('civitai') or {}
|
||||
if civitai_info:
|
||||
lora_entry['modelName'] = civitai_info.get('model', {}).get('name', '')
|
||||
lora_entry['modelVersionName'] = civitai_info.get('name', '')
|
||||
lora_entry['modelVersionId'] = civitai_info.get('id')
|
||||
|
||||
recipe_data['fingerprint'] = calculate_recipe_fingerprint(recipe_data.get('loras', []))
|
||||
recipe_data['modified'] = time.time()
|
||||
|
||||
with open(recipe_json_path, 'w', encoding='utf-8') as file_obj:
|
||||
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
||||
|
||||
cache = await self.get_cached_data()
|
||||
replaced = await cache.replace_recipe(recipe_id, recipe_data, resort=False)
|
||||
if not replaced:
|
||||
await cache.add_recipe(recipe_data, resort=False)
|
||||
self._schedule_resort()
|
||||
|
||||
updated_lora = dict(lora_entry)
|
||||
if target_lora is not None:
|
||||
preview_url = target_lora.get('preview_url')
|
||||
if preview_url:
|
||||
updated_lora['preview_url'] = config.get_preview_static_url(preview_url)
|
||||
if target_lora.get('file_path'):
|
||||
updated_lora['localPath'] = target_lora['file_path']
|
||||
|
||||
updated_lora = self._enrich_lora_entry(updated_lora)
|
||||
return recipe_data, updated_lora
|
||||
|
||||
async def get_recipes_for_lora(self, lora_hash: str) -> List[Dict[str, Any]]:
|
||||
"""Return recipes that reference a given LoRA hash."""
|
||||
|
||||
if not lora_hash:
|
||||
return []
|
||||
|
||||
normalized_hash = lora_hash.lower()
|
||||
cache = await self.get_cached_data()
|
||||
matching_recipes: List[Dict[str, Any]] = []
|
||||
|
||||
for recipe in cache.raw_data:
|
||||
loras = recipe.get('loras', [])
|
||||
if any((entry.get('hash') or '').lower() == normalized_hash for entry in loras):
|
||||
recipe_copy = {**recipe}
|
||||
recipe_copy['loras'] = [self._enrich_lora_entry(dict(entry)) for entry in loras]
|
||||
recipe_copy['file_url'] = self._format_file_url(recipe.get('file_path'))
|
||||
matching_recipes.append(recipe_copy)
|
||||
|
||||
return matching_recipes
|
||||
|
||||
async def get_recipe_syntax_tokens(self, recipe_id: str) -> List[str]:
|
||||
"""Build LoRA syntax tokens for a recipe."""
|
||||
|
||||
cache = await self.get_cached_data()
|
||||
recipe = await cache.get_recipe(recipe_id)
|
||||
if recipe is None:
|
||||
raise RecipeNotFoundError("Recipe not found")
|
||||
|
||||
loras = recipe.get('loras', [])
|
||||
if not loras:
|
||||
return []
|
||||
|
||||
lora_cache = None
|
||||
if self._lora_scanner is not None:
|
||||
lora_cache = await self._lora_scanner.get_cached_data()
|
||||
|
||||
syntax_parts: List[str] = []
|
||||
for lora in loras:
|
||||
if lora.get('isDeleted', False):
|
||||
continue
|
||||
|
||||
file_name = None
|
||||
hash_value = (lora.get('hash') or '').lower()
|
||||
if hash_value and self._lora_scanner is not None and hasattr(self._lora_scanner, '_hash_index'):
|
||||
file_path = self._lora_scanner._hash_index.get_path(hash_value)
|
||||
if file_path:
|
||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
if not file_name and lora.get('modelVersionId') and lora_cache is not None:
|
||||
for cached_lora in getattr(lora_cache, 'raw_data', []):
|
||||
civitai_info = cached_lora.get('civitai')
|
||||
if civitai_info and civitai_info.get('id') == lora.get('modelVersionId'):
|
||||
cached_path = cached_lora.get('path') or cached_lora.get('file_path')
|
||||
if cached_path:
|
||||
file_name = os.path.splitext(os.path.basename(cached_path))[0]
|
||||
break
|
||||
|
||||
if not file_name:
|
||||
file_name = lora.get('file_name', 'unknown-lora')
|
||||
|
||||
strength = lora.get('strength', 1.0)
|
||||
syntax_parts.append(f"<lora:{file_name}:{strength}>")
|
||||
|
||||
return syntax_parts
|
||||
|
||||
async def update_lora_filename_by_hash(self, hash_value: str, new_file_name: str) -> Tuple[int, int]:
|
||||
"""Update file_name in all recipes that contain a LoRA with the specified hash.
|
||||
|
||||
|
||||
Args:
|
||||
hash_value: The SHA256 hash value of the LoRA
|
||||
new_file_name: The new file_name to set
|
||||
|
||||
23
py/services/recipes/__init__.py
Normal file
23
py/services/recipes/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Recipe service layer implementations."""
|
||||
|
||||
from .analysis_service import RecipeAnalysisService
|
||||
from .persistence_service import RecipePersistenceService
|
||||
from .sharing_service import RecipeSharingService
|
||||
from .errors import (
|
||||
RecipeServiceError,
|
||||
RecipeValidationError,
|
||||
RecipeNotFoundError,
|
||||
RecipeDownloadError,
|
||||
RecipeConflictError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RecipeAnalysisService",
|
||||
"RecipePersistenceService",
|
||||
"RecipeSharingService",
|
||||
"RecipeServiceError",
|
||||
"RecipeValidationError",
|
||||
"RecipeNotFoundError",
|
||||
"RecipeDownloadError",
|
||||
"RecipeConflictError",
|
||||
]
|
||||
289
py/services/recipes/analysis_service.py
Normal file
289
py/services/recipes/analysis_service.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""Services responsible for recipe metadata analysis."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ...utils.utils import calculate_recipe_fingerprint
|
||||
from .errors import (
|
||||
RecipeDownloadError,
|
||||
RecipeNotFoundError,
|
||||
RecipeServiceError,
|
||||
RecipeValidationError,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AnalysisResult:
|
||||
"""Return payload from analysis operations."""
|
||||
|
||||
payload: dict[str, Any]
|
||||
status: int = 200
|
||||
|
||||
|
||||
class RecipeAnalysisService:
|
||||
"""Extract recipe metadata from various image sources."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
exif_utils,
|
||||
recipe_parser_factory,
|
||||
downloader_factory: Callable[[], Any],
|
||||
metadata_collector: Optional[Callable[[], Any]] = None,
|
||||
metadata_processor_cls: Optional[type] = None,
|
||||
metadata_registry_cls: Optional[type] = None,
|
||||
standalone_mode: bool = False,
|
||||
logger,
|
||||
) -> None:
|
||||
self._exif_utils = exif_utils
|
||||
self._recipe_parser_factory = recipe_parser_factory
|
||||
self._downloader_factory = downloader_factory
|
||||
self._metadata_collector = metadata_collector
|
||||
self._metadata_processor_cls = metadata_processor_cls
|
||||
self._metadata_registry_cls = metadata_registry_cls
|
||||
self._standalone_mode = standalone_mode
|
||||
self._logger = logger
|
||||
|
||||
async def analyze_uploaded_image(
|
||||
self,
|
||||
*,
|
||||
image_bytes: bytes | None,
|
||||
recipe_scanner,
|
||||
) -> AnalysisResult:
|
||||
"""Analyze an uploaded image payload."""
|
||||
|
||||
if not image_bytes:
|
||||
raise RecipeValidationError("No image data provided")
|
||||
|
||||
temp_path = self._write_temp_file(image_bytes)
|
||||
try:
|
||||
metadata = self._exif_utils.extract_image_metadata(temp_path)
|
||||
if not metadata:
|
||||
return AnalysisResult({"error": "No metadata found in this image", "loras": []})
|
||||
|
||||
return await self._parse_metadata(
|
||||
metadata,
|
||||
recipe_scanner=recipe_scanner,
|
||||
image_path=None,
|
||||
include_image_base64=False,
|
||||
)
|
||||
finally:
|
||||
self._safe_cleanup(temp_path)
|
||||
|
||||
async def analyze_remote_image(
|
||||
self,
|
||||
*,
|
||||
url: str | None,
|
||||
recipe_scanner,
|
||||
civitai_client,
|
||||
) -> AnalysisResult:
|
||||
"""Analyze an image accessible via URL, including Civitai integration."""
|
||||
|
||||
if not url:
|
||||
raise RecipeValidationError("No URL provided")
|
||||
|
||||
if civitai_client is None:
|
||||
raise RecipeServiceError("Civitai client unavailable")
|
||||
|
||||
temp_path = self._create_temp_path()
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
try:
|
||||
civitai_match = re.match(r"https://civitai\.com/images/(\d+)", url)
|
||||
if civitai_match:
|
||||
image_info = await civitai_client.get_image_info(civitai_match.group(1))
|
||||
if not image_info:
|
||||
raise RecipeDownloadError("Failed to fetch image information from Civitai")
|
||||
image_url = image_info.get("url")
|
||||
if not image_url:
|
||||
raise RecipeDownloadError("No image URL found in Civitai response")
|
||||
await self._download_image(image_url, temp_path)
|
||||
metadata = image_info.get("meta") if "meta" in image_info else None
|
||||
else:
|
||||
await self._download_image(url, temp_path)
|
||||
|
||||
if metadata is None:
|
||||
metadata = self._exif_utils.extract_image_metadata(temp_path)
|
||||
|
||||
if not metadata:
|
||||
return self._metadata_not_found_response(temp_path)
|
||||
|
||||
return await self._parse_metadata(
|
||||
metadata,
|
||||
recipe_scanner=recipe_scanner,
|
||||
image_path=temp_path,
|
||||
include_image_base64=True,
|
||||
)
|
||||
finally:
|
||||
self._safe_cleanup(temp_path)
|
||||
|
||||
async def analyze_local_image(
|
||||
self,
|
||||
*,
|
||||
file_path: str | None,
|
||||
recipe_scanner,
|
||||
) -> AnalysisResult:
|
||||
"""Analyze a file already present on disk."""
|
||||
|
||||
if not file_path:
|
||||
raise RecipeValidationError("No file path provided")
|
||||
|
||||
normalized_path = os.path.normpath(file_path.strip('"').strip("'"))
|
||||
if not os.path.isfile(normalized_path):
|
||||
raise RecipeNotFoundError("File not found")
|
||||
|
||||
metadata = self._exif_utils.extract_image_metadata(normalized_path)
|
||||
if not metadata:
|
||||
return self._metadata_not_found_response(normalized_path)
|
||||
|
||||
return await self._parse_metadata(
|
||||
metadata,
|
||||
recipe_scanner=recipe_scanner,
|
||||
image_path=normalized_path,
|
||||
include_image_base64=True,
|
||||
)
|
||||
|
||||
async def analyze_widget_metadata(self, *, recipe_scanner) -> AnalysisResult:
|
||||
"""Analyse the most recent generation metadata for widget saves."""
|
||||
|
||||
if self._metadata_collector is None or self._metadata_processor_cls is None:
|
||||
raise RecipeValidationError("Metadata collection not available")
|
||||
|
||||
raw_metadata = self._metadata_collector()
|
||||
metadata_dict = self._metadata_processor_cls.to_dict(raw_metadata)
|
||||
if not metadata_dict:
|
||||
raise RecipeValidationError("No generation metadata found")
|
||||
|
||||
latest_image = None
|
||||
if not self._standalone_mode and self._metadata_registry_cls is not None:
|
||||
metadata_registry = self._metadata_registry_cls()
|
||||
latest_image = metadata_registry.get_first_decoded_image()
|
||||
|
||||
if latest_image is None:
|
||||
raise RecipeValidationError(
|
||||
"No recent images found to use for recipe. Try generating an image first."
|
||||
)
|
||||
|
||||
image_bytes = self._convert_tensor_to_png_bytes(latest_image)
|
||||
if image_bytes is None:
|
||||
raise RecipeValidationError("Cannot handle this data shape from metadata registry")
|
||||
|
||||
return AnalysisResult(
|
||||
{
|
||||
"metadata": metadata_dict,
|
||||
"image_bytes": image_bytes,
|
||||
}
|
||||
)
|
||||
|
||||
# Internal helpers -------------------------------------------------
|
||||
|
||||
async def _parse_metadata(
|
||||
self,
|
||||
metadata: dict[str, Any],
|
||||
*,
|
||||
recipe_scanner,
|
||||
image_path: Optional[str],
|
||||
include_image_base64: bool,
|
||||
) -> AnalysisResult:
|
||||
parser = self._recipe_parser_factory.create_parser(metadata)
|
||||
if parser is None:
|
||||
payload = {"error": "No parser found for this image", "loras": []}
|
||||
if include_image_base64 and image_path:
|
||||
payload["image_base64"] = self._encode_file(image_path)
|
||||
return AnalysisResult(payload)
|
||||
|
||||
result = await parser.parse_metadata(metadata, recipe_scanner=recipe_scanner)
|
||||
|
||||
if include_image_base64 and image_path:
|
||||
result["image_base64"] = self._encode_file(image_path)
|
||||
|
||||
if "error" in result and not result.get("loras"):
|
||||
return AnalysisResult(result)
|
||||
|
||||
fingerprint = calculate_recipe_fingerprint(result.get("loras", []))
|
||||
result["fingerprint"] = fingerprint
|
||||
|
||||
matching_recipes: list[str] = []
|
||||
if fingerprint:
|
||||
matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(fingerprint)
|
||||
result["matching_recipes"] = matching_recipes
|
||||
|
||||
return AnalysisResult(result)
|
||||
|
||||
async def _download_image(self, url: str, temp_path: str) -> None:
|
||||
downloader = await self._downloader_factory()
|
||||
success, result = await downloader.download_file(url, temp_path, use_auth=False)
|
||||
if not success:
|
||||
raise RecipeDownloadError(f"Failed to download image from URL: {result}")
|
||||
|
||||
def _metadata_not_found_response(self, path: str) -> AnalysisResult:
|
||||
payload: dict[str, Any] = {"error": "No metadata found in this image", "loras": []}
|
||||
if os.path.exists(path):
|
||||
payload["image_base64"] = self._encode_file(path)
|
||||
return AnalysisResult(payload)
|
||||
|
||||
def _write_temp_file(self, data: bytes) -> str:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
|
||||
temp_file.write(data)
|
||||
return temp_file.name
|
||||
|
||||
def _create_temp_path(self) -> str:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
|
||||
return temp_file.name
|
||||
|
||||
def _safe_cleanup(self, path: Optional[str]) -> None:
|
||||
if path and os.path.exists(path):
|
||||
try:
|
||||
os.unlink(path)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._logger.error("Error deleting temporary file: %s", exc)
|
||||
|
||||
def _encode_file(self, path: str) -> str:
|
||||
with open(path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
def _convert_tensor_to_png_bytes(self, latest_image: Any) -> Optional[bytes]:
|
||||
try:
|
||||
if isinstance(latest_image, tuple):
|
||||
tensor_image = latest_image[0] if latest_image else None
|
||||
if tensor_image is None:
|
||||
return None
|
||||
else:
|
||||
tensor_image = latest_image
|
||||
|
||||
if hasattr(tensor_image, "shape"):
|
||||
self._logger.debug(
|
||||
"Tensor shape: %s, dtype: %s", tensor_image.shape, getattr(tensor_image, "dtype", None)
|
||||
)
|
||||
|
||||
import torch # type: ignore[import-not-found]
|
||||
|
||||
if isinstance(tensor_image, torch.Tensor):
|
||||
image_np = tensor_image.cpu().numpy()
|
||||
else:
|
||||
image_np = np.array(tensor_image)
|
||||
|
||||
while len(image_np.shape) > 3:
|
||||
image_np = image_np[0]
|
||||
|
||||
if image_np.dtype in (np.float32, np.float64) and image_np.max() <= 1.0:
|
||||
image_np = (image_np * 255).astype(np.uint8)
|
||||
|
||||
if len(image_np.shape) == 3 and image_np.shape[2] == 3:
|
||||
pil_image = Image.fromarray(image_np)
|
||||
img_byte_arr = io.BytesIO()
|
||||
pil_image.save(img_byte_arr, format="PNG")
|
||||
return img_byte_arr.getvalue()
|
||||
except Exception as exc: # pragma: no cover - defensive logging path
|
||||
self._logger.error("Error processing image data: %s", exc, exc_info=True)
|
||||
return None
|
||||
|
||||
return None
|
||||
22
py/services/recipes/errors.py
Normal file
22
py/services/recipes/errors.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Shared exceptions for recipe services."""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class RecipeServiceError(Exception):
|
||||
"""Base exception for recipe service failures."""
|
||||
|
||||
|
||||
class RecipeValidationError(RecipeServiceError):
|
||||
"""Raised when a request payload fails validation."""
|
||||
|
||||
|
||||
class RecipeNotFoundError(RecipeServiceError):
|
||||
"""Raised when a recipe resource cannot be located."""
|
||||
|
||||
|
||||
class RecipeDownloadError(RecipeServiceError):
|
||||
"""Raised when remote recipe assets cannot be downloaded."""
|
||||
|
||||
|
||||
class RecipeConflictError(RecipeServiceError):
|
||||
"""Raised when a conflicting recipe state is detected."""
|
||||
400
py/services/recipes/persistence_service.py
Normal file
400
py/services/recipes/persistence_service.py
Normal file
@@ -0,0 +1,400 @@
|
||||
"""Services encapsulating recipe persistence workflows."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
|
||||
from ...config import config
|
||||
from ...utils.utils import calculate_recipe_fingerprint
|
||||
from .errors import RecipeNotFoundError, RecipeValidationError
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PersistenceResult:
|
||||
"""Return payload from persistence operations."""
|
||||
|
||||
payload: dict[str, Any]
|
||||
status: int = 200
|
||||
|
||||
|
||||
class RecipePersistenceService:
|
||||
"""Coordinate recipe persistence tasks across storage and caches."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
exif_utils,
|
||||
card_preview_width: int,
|
||||
logger,
|
||||
) -> None:
|
||||
self._exif_utils = exif_utils
|
||||
self._card_preview_width = card_preview_width
|
||||
self._logger = logger
|
||||
|
||||
async def save_recipe(
|
||||
self,
|
||||
*,
|
||||
recipe_scanner,
|
||||
image_bytes: bytes | None,
|
||||
image_base64: str | None,
|
||||
name: str | None,
|
||||
tags: Iterable[str],
|
||||
metadata: Optional[dict[str, Any]],
|
||||
) -> PersistenceResult:
|
||||
"""Persist a user uploaded recipe."""
|
||||
|
||||
missing_fields = []
|
||||
if not name:
|
||||
missing_fields.append("name")
|
||||
if metadata is None:
|
||||
missing_fields.append("metadata")
|
||||
if missing_fields:
|
||||
raise RecipeValidationError(
|
||||
f"Missing required fields: {', '.join(missing_fields)}"
|
||||
)
|
||||
|
||||
resolved_image_bytes = self._resolve_image_bytes(image_bytes, image_base64)
|
||||
recipes_dir = recipe_scanner.recipes_dir
|
||||
os.makedirs(recipes_dir, exist_ok=True)
|
||||
|
||||
recipe_id = str(uuid.uuid4())
|
||||
optimized_image, extension = self._exif_utils.optimize_image(
|
||||
image_data=resolved_image_bytes,
|
||||
target_width=self._card_preview_width,
|
||||
format="webp",
|
||||
quality=85,
|
||||
preserve_metadata=True,
|
||||
)
|
||||
image_filename = f"{recipe_id}{extension}"
|
||||
image_path = os.path.join(recipes_dir, image_filename)
|
||||
with open(image_path, "wb") as file_obj:
|
||||
file_obj.write(optimized_image)
|
||||
|
||||
current_time = time.time()
|
||||
loras_data = [self._normalise_lora_entry(lora) for lora in metadata.get("loras", [])]
|
||||
|
||||
gen_params = metadata.get("gen_params", {})
|
||||
if not gen_params and "raw_metadata" in metadata:
|
||||
raw_metadata = metadata.get("raw_metadata", {})
|
||||
gen_params = {
|
||||
"prompt": raw_metadata.get("prompt", ""),
|
||||
"negative_prompt": raw_metadata.get("negative_prompt", ""),
|
||||
"checkpoint": raw_metadata.get("checkpoint", {}),
|
||||
"steps": raw_metadata.get("steps", ""),
|
||||
"sampler": raw_metadata.get("sampler", ""),
|
||||
"cfg_scale": raw_metadata.get("cfg_scale", ""),
|
||||
"seed": raw_metadata.get("seed", ""),
|
||||
"size": raw_metadata.get("size", ""),
|
||||
"clip_skip": raw_metadata.get("clip_skip", ""),
|
||||
}
|
||||
|
||||
fingerprint = calculate_recipe_fingerprint(loras_data)
|
||||
recipe_data: Dict[str, Any] = {
|
||||
"id": recipe_id,
|
||||
"file_path": image_path,
|
||||
"title": name,
|
||||
"modified": current_time,
|
||||
"created_date": current_time,
|
||||
"base_model": metadata.get("base_model", ""),
|
||||
"loras": loras_data,
|
||||
"gen_params": gen_params,
|
||||
"fingerprint": fingerprint,
|
||||
}
|
||||
|
||||
tags_list = list(tags)
|
||||
if tags_list:
|
||||
recipe_data["tags"] = tags_list
|
||||
|
||||
if metadata.get("source_path"):
|
||||
recipe_data["source_path"] = metadata.get("source_path")
|
||||
|
||||
json_filename = f"{recipe_id}.recipe.json"
|
||||
json_path = os.path.join(recipes_dir, json_filename)
|
||||
with open(json_path, "w", encoding="utf-8") as file_obj:
|
||||
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
||||
|
||||
self._exif_utils.append_recipe_metadata(image_path, recipe_data)
|
||||
|
||||
matching_recipes = await self._find_matching_recipes(recipe_scanner, fingerprint, exclude_id=recipe_id)
|
||||
await recipe_scanner.add_recipe(recipe_data)
|
||||
|
||||
return PersistenceResult(
|
||||
{
|
||||
"success": True,
|
||||
"recipe_id": recipe_id,
|
||||
"image_path": image_path,
|
||||
"json_path": json_path,
|
||||
"matching_recipes": matching_recipes,
|
||||
}
|
||||
)
|
||||
|
||||
async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult:
|
||||
"""Delete an existing recipe."""
|
||||
|
||||
recipes_dir = recipe_scanner.recipes_dir
|
||||
if not recipes_dir or not os.path.exists(recipes_dir):
|
||||
raise RecipeNotFoundError("Recipes directory not found")
|
||||
|
||||
recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json")
|
||||
if not os.path.exists(recipe_json_path):
|
||||
raise RecipeNotFoundError("Recipe not found")
|
||||
|
||||
with open(recipe_json_path, "r", encoding="utf-8") as file_obj:
|
||||
recipe_data = json.load(file_obj)
|
||||
|
||||
image_path = recipe_data.get("file_path")
|
||||
os.remove(recipe_json_path)
|
||||
if image_path and os.path.exists(image_path):
|
||||
os.remove(image_path)
|
||||
|
||||
await recipe_scanner.remove_recipe(recipe_id)
|
||||
return PersistenceResult({"success": True, "message": "Recipe deleted successfully"})
|
||||
|
||||
async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: dict[str, Any]) -> PersistenceResult:
|
||||
"""Update persisted metadata for a recipe."""
|
||||
|
||||
if not any(key in updates for key in ("title", "tags", "source_path", "preview_nsfw_level")):
|
||||
raise RecipeValidationError(
|
||||
"At least one field to update must be provided (title or tags or source_path or preview_nsfw_level)"
|
||||
)
|
||||
|
||||
success = await recipe_scanner.update_recipe_metadata(recipe_id, updates)
|
||||
if not success:
|
||||
raise RecipeNotFoundError("Recipe not found or update failed")
|
||||
|
||||
return PersistenceResult({"success": True, "recipe_id": recipe_id, "updates": updates})
|
||||
|
||||
async def reconnect_lora(
|
||||
self,
|
||||
*,
|
||||
recipe_scanner,
|
||||
recipe_id: str,
|
||||
lora_index: int,
|
||||
target_name: str,
|
||||
) -> PersistenceResult:
|
||||
"""Reconnect a LoRA entry within an existing recipe."""
|
||||
|
||||
recipe_path = os.path.join(recipe_scanner.recipes_dir, f"{recipe_id}.recipe.json")
|
||||
if not os.path.exists(recipe_path):
|
||||
raise RecipeNotFoundError("Recipe not found")
|
||||
|
||||
target_lora = await recipe_scanner.get_local_lora(target_name)
|
||||
if not target_lora:
|
||||
raise RecipeNotFoundError(f"Local LoRA not found with name: {target_name}")
|
||||
|
||||
recipe_data, updated_lora = await recipe_scanner.update_lora_entry(
|
||||
recipe_id,
|
||||
lora_index,
|
||||
target_name=target_name,
|
||||
target_lora=target_lora,
|
||||
)
|
||||
|
||||
image_path = recipe_data.get("file_path")
|
||||
if image_path and os.path.exists(image_path):
|
||||
self._exif_utils.append_recipe_metadata(image_path, recipe_data)
|
||||
|
||||
matching_recipes = []
|
||||
if "fingerprint" in recipe_data:
|
||||
matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(recipe_data["fingerprint"])
|
||||
if recipe_id in matching_recipes:
|
||||
matching_recipes.remove(recipe_id)
|
||||
|
||||
return PersistenceResult(
|
||||
{
|
||||
"success": True,
|
||||
"recipe_id": recipe_id,
|
||||
"updated_lora": updated_lora,
|
||||
"matching_recipes": matching_recipes,
|
||||
}
|
||||
)
|
||||
|
||||
async def bulk_delete(
|
||||
self,
|
||||
*,
|
||||
recipe_scanner,
|
||||
recipe_ids: Iterable[str],
|
||||
) -> PersistenceResult:
|
||||
"""Delete multiple recipes in a single request."""
|
||||
|
||||
recipe_ids = list(recipe_ids)
|
||||
if not recipe_ids:
|
||||
raise RecipeValidationError("No recipe IDs provided")
|
||||
|
||||
recipes_dir = recipe_scanner.recipes_dir
|
||||
if not recipes_dir or not os.path.exists(recipes_dir):
|
||||
raise RecipeNotFoundError("Recipes directory not found")
|
||||
|
||||
deleted_recipes: list[str] = []
|
||||
failed_recipes: list[dict[str, Any]] = []
|
||||
|
||||
for recipe_id in recipe_ids:
|
||||
recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json")
|
||||
if not os.path.exists(recipe_json_path):
|
||||
failed_recipes.append({"id": recipe_id, "reason": "Recipe not found"})
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(recipe_json_path, "r", encoding="utf-8") as file_obj:
|
||||
recipe_data = json.load(file_obj)
|
||||
image_path = recipe_data.get("file_path")
|
||||
os.remove(recipe_json_path)
|
||||
if image_path and os.path.exists(image_path):
|
||||
os.remove(image_path)
|
||||
deleted_recipes.append(recipe_id)
|
||||
except Exception as exc:
|
||||
failed_recipes.append({"id": recipe_id, "reason": str(exc)})
|
||||
|
||||
if deleted_recipes:
|
||||
await recipe_scanner.bulk_remove(deleted_recipes)
|
||||
|
||||
return PersistenceResult(
|
||||
{
|
||||
"success": True,
|
||||
"deleted": deleted_recipes,
|
||||
"failed": failed_recipes,
|
||||
"total_deleted": len(deleted_recipes),
|
||||
"total_failed": len(failed_recipes),
|
||||
}
|
||||
)
|
||||
|
||||
async def save_recipe_from_widget(
|
||||
self,
|
||||
*,
|
||||
recipe_scanner,
|
||||
metadata: dict[str, Any],
|
||||
image_bytes: bytes,
|
||||
) -> PersistenceResult:
|
||||
"""Save a recipe constructed from widget metadata."""
|
||||
|
||||
if not metadata:
|
||||
raise RecipeValidationError("No generation metadata found")
|
||||
|
||||
recipes_dir = recipe_scanner.recipes_dir
|
||||
os.makedirs(recipes_dir, exist_ok=True)
|
||||
|
||||
recipe_id = str(uuid.uuid4())
|
||||
image_filename = f"{recipe_id}.png"
|
||||
image_path = os.path.join(recipes_dir, image_filename)
|
||||
with open(image_path, "wb") as file_obj:
|
||||
file_obj.write(image_bytes)
|
||||
|
||||
lora_stack = metadata.get("loras", "")
|
||||
lora_matches = re.findall(r"<lora:([^:]+):([^>]+)>", lora_stack)
|
||||
if not lora_matches:
|
||||
raise RecipeValidationError("No LoRAs found in the generation metadata")
|
||||
|
||||
loras_data = []
|
||||
base_model_counts: Dict[str, int] = {}
|
||||
|
||||
for name, strength in lora_matches:
|
||||
lora_info = await recipe_scanner.get_local_lora(name)
|
||||
lora_data = {
|
||||
"file_name": name,
|
||||
"strength": float(strength),
|
||||
"hash": (lora_info.get("sha256") or "").lower() if lora_info else "",
|
||||
"modelVersionId": lora_info.get("civitai", {}).get("id") if lora_info else 0,
|
||||
"modelName": lora_info.get("civitai", {}).get("model", {}).get("name") if lora_info else "",
|
||||
"modelVersionName": lora_info.get("civitai", {}).get("name") if lora_info else "",
|
||||
"isDeleted": False,
|
||||
"exclude": False,
|
||||
}
|
||||
loras_data.append(lora_data)
|
||||
|
||||
if lora_info and "base_model" in lora_info:
|
||||
base_model = lora_info["base_model"]
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
|
||||
recipe_name = self._derive_recipe_name(lora_matches)
|
||||
most_common_base_model = (
|
||||
max(base_model_counts.items(), key=lambda item: item[1])[0] if base_model_counts else ""
|
||||
)
|
||||
|
||||
recipe_data = {
|
||||
"id": recipe_id,
|
||||
"file_path": image_path,
|
||||
"title": recipe_name,
|
||||
"modified": time.time(),
|
||||
"created_date": time.time(),
|
||||
"base_model": most_common_base_model,
|
||||
"loras": loras_data,
|
||||
"checkpoint": metadata.get("checkpoint", ""),
|
||||
"gen_params": {
|
||||
key: value
|
||||
for key, value in metadata.items()
|
||||
if key not in ["checkpoint", "loras"]
|
||||
},
|
||||
"loras_stack": lora_stack,
|
||||
}
|
||||
|
||||
json_filename = f"{recipe_id}.recipe.json"
|
||||
json_path = os.path.join(recipes_dir, json_filename)
|
||||
with open(json_path, "w", encoding="utf-8") as file_obj:
|
||||
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
||||
|
||||
self._exif_utils.append_recipe_metadata(image_path, recipe_data)
|
||||
await recipe_scanner.add_recipe(recipe_data)
|
||||
|
||||
return PersistenceResult(
|
||||
{
|
||||
"success": True,
|
||||
"recipe_id": recipe_id,
|
||||
"image_path": image_path,
|
||||
"json_path": json_path,
|
||||
"recipe_name": recipe_name,
|
||||
}
|
||||
)
|
||||
|
||||
# Helper methods ---------------------------------------------------
|
||||
|
||||
def _resolve_image_bytes(self, image_bytes: bytes | None, image_base64: str | None) -> bytes:
|
||||
if image_bytes is not None:
|
||||
return image_bytes
|
||||
if image_base64:
|
||||
try:
|
||||
payload = image_base64.split(",", 1)[1] if "," in image_base64 else image_base64
|
||||
return base64.b64decode(payload)
|
||||
except Exception as exc: # pragma: no cover - validation guard
|
||||
raise RecipeValidationError(f"Invalid base64 image data: {exc}") from exc
|
||||
raise RecipeValidationError("No image data provided")
|
||||
|
||||
def _normalise_lora_entry(self, lora: dict[str, Any]) -> dict[str, Any]:
|
||||
return {
|
||||
"file_name": lora.get("file_name", "")
|
||||
or (
|
||||
os.path.splitext(os.path.basename(lora.get("localPath", "")))[0]
|
||||
if lora.get("localPath")
|
||||
else ""
|
||||
),
|
||||
"hash": (lora.get("hash") or "").lower(),
|
||||
"strength": float(lora.get("weight", 1.0)),
|
||||
"modelVersionId": lora.get("id", 0),
|
||||
"modelName": lora.get("name", ""),
|
||||
"modelVersionName": lora.get("version", ""),
|
||||
"isDeleted": lora.get("isDeleted", False),
|
||||
"exclude": lora.get("exclude", False),
|
||||
}
|
||||
|
||||
async def _find_matching_recipes(
|
||||
self,
|
||||
recipe_scanner,
|
||||
fingerprint: str | None,
|
||||
*,
|
||||
exclude_id: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
if not fingerprint:
|
||||
return []
|
||||
matches = await recipe_scanner.find_recipes_by_fingerprint(fingerprint)
|
||||
if exclude_id and exclude_id in matches:
|
||||
matches.remove(exclude_id)
|
||||
return matches
|
||||
|
||||
def _derive_recipe_name(self, lora_matches: list[tuple[str, str]]) -> str:
|
||||
recipe_name_parts = [f"{name.strip()}-{float(strength):.2f}" for name, strength in lora_matches[:3]]
|
||||
recipe_name = "_".join(recipe_name_parts)
|
||||
return recipe_name or "recipe"
|
||||
105
py/services/recipes/sharing_service.py
Normal file
105
py/services/recipes/sharing_service.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Services handling recipe sharing and downloads."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
from .errors import RecipeNotFoundError
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SharingResult:
|
||||
"""Return payload for share operations."""
|
||||
|
||||
payload: dict[str, Any]
|
||||
status: int = 200
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DownloadInfo:
|
||||
"""Information required to stream a shared recipe file."""
|
||||
|
||||
file_path: str
|
||||
download_filename: str
|
||||
|
||||
|
||||
class RecipeSharingService:
|
||||
"""Prepare temporary recipe downloads with TTL cleanup."""
|
||||
|
||||
def __init__(self, *, ttl_seconds: int = 300, logger) -> None:
|
||||
self._ttl_seconds = ttl_seconds
|
||||
self._logger = logger
|
||||
self._shared_recipes: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
async def share_recipe(self, *, recipe_scanner, recipe_id: str) -> SharingResult:
|
||||
"""Prepare a temporary downloadable copy of a recipe image."""
|
||||
|
||||
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
|
||||
if not recipe:
|
||||
raise RecipeNotFoundError("Recipe not found")
|
||||
|
||||
image_path = recipe.get("file_path")
|
||||
if not image_path or not os.path.exists(image_path):
|
||||
raise RecipeNotFoundError("Recipe image not found")
|
||||
|
||||
ext = os.path.splitext(image_path)[1]
|
||||
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
|
||||
shutil.copy2(image_path, temp_path)
|
||||
timestamp = int(time.time())
|
||||
self._shared_recipes[recipe_id] = {
|
||||
"path": temp_path,
|
||||
"timestamp": timestamp,
|
||||
"expires": time.time() + self._ttl_seconds,
|
||||
}
|
||||
self._cleanup_shared_recipes()
|
||||
|
||||
safe_title = recipe.get("title", "").replace(" ", "_").lower()
|
||||
filename = f"recipe_{safe_title}{ext}" if safe_title else f"recipe_{recipe_id}{ext}"
|
||||
url_path = f"/api/recipe/{recipe_id}/share/download?t={timestamp}"
|
||||
return SharingResult({"success": True, "download_url": url_path, "filename": filename})
|
||||
|
||||
async def prepare_download(self, *, recipe_scanner, recipe_id: str) -> DownloadInfo:
|
||||
"""Return file path and filename for a prepared shared recipe."""
|
||||
|
||||
shared_info = self._shared_recipes.get(recipe_id)
|
||||
if not shared_info or time.time() > shared_info.get("expires", 0):
|
||||
self._cleanup_entry(recipe_id)
|
||||
raise RecipeNotFoundError("Shared recipe not found or expired")
|
||||
|
||||
file_path = shared_info["path"]
|
||||
if not os.path.exists(file_path):
|
||||
self._cleanup_entry(recipe_id)
|
||||
raise RecipeNotFoundError("Shared recipe file not found")
|
||||
|
||||
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
|
||||
filename_base = (
|
||||
f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}" if recipe else recipe_id
|
||||
)
|
||||
ext = os.path.splitext(file_path)[1]
|
||||
download_filename = f"{filename_base}{ext}"
|
||||
return DownloadInfo(file_path=file_path, download_filename=download_filename)
|
||||
|
||||
def _cleanup_shared_recipes(self) -> None:
|
||||
for recipe_id in list(self._shared_recipes.keys()):
|
||||
shared = self._shared_recipes.get(recipe_id)
|
||||
if not shared:
|
||||
continue
|
||||
if time.time() > shared.get("expires", 0):
|
||||
self._cleanup_entry(recipe_id)
|
||||
|
||||
def _cleanup_entry(self, recipe_id: str) -> None:
|
||||
shared_info = self._shared_recipes.pop(recipe_id, None)
|
||||
if not shared_info:
|
||||
return
|
||||
file_path = shared_info.get("path")
|
||||
if file_path and os.path.exists(file_path):
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._logger.error("Error cleaning up shared recipe %s: %s", recipe_id, exc)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user