mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Compare commits
115 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
68c5f79a67 | ||
|
|
836a64e728 | ||
|
|
08ba0c9f42 | ||
|
|
6fcc6a5299 | ||
|
|
6dd58248c6 | ||
|
|
2786801b71 | ||
|
|
ea29cbeb7a | ||
|
|
3cf9121a8c | ||
|
|
381bd3938a | ||
|
|
e4ce384023 | ||
|
|
12d1857b13 | ||
|
|
0d9003dea4 | ||
|
|
1a3751acfa | ||
|
|
c5a3af2399 | ||
|
|
ea8a64fafc | ||
|
|
981e367bf1 | ||
|
|
a3d6e62035 | ||
|
|
7f205cdcc8 | ||
|
|
e587189880 | ||
|
|
206c1bd69f | ||
|
|
a7d9255c2c | ||
|
|
08265a85ec | ||
|
|
1ed5630464 | ||
|
|
c784615f11 | ||
|
|
26d51b1190 | ||
|
|
d83fad6abc | ||
|
|
692796db46 | ||
|
|
f15c6f33f9 | ||
|
|
dda9eb4d7c | ||
|
|
6f3aeb61e7 | ||
|
|
d6145e633f | ||
|
|
07014d98ce | ||
|
|
e8ccdabe6c | ||
|
|
cf9fd2d5c2 | ||
|
|
bf9aa9356b | ||
|
|
68d00ce289 | ||
|
|
5288021e4f | ||
|
|
4d38add291 | ||
|
|
804808da4a | ||
|
|
298a95432d | ||
|
|
a834fc4b30 | ||
|
|
2c6c9542dd | ||
|
|
a9a7f4c8ec | ||
|
|
ea9370443d | ||
|
|
c2e00b240e | ||
|
|
a2b81ea099 | ||
|
|
ee609e8eac | ||
|
|
e04ef671e9 | ||
|
|
0184dfd7eb | ||
|
|
eccfa0ca54 | ||
|
|
6d3feb4bef | ||
|
|
29d2b5ee4b | ||
|
|
c82fabb67f | ||
|
|
fcfc868e57 | ||
|
|
67b403f8ca | ||
|
|
de06c6b2f6 | ||
|
|
fa444dfb8a | ||
|
|
124002a472 | ||
|
|
0c883433c1 | ||
|
|
bcf3b2cf55 | ||
|
|
357c4e9c08 | ||
|
|
9edfc68e91 | ||
|
|
8c06cb3e80 | ||
|
|
144fa0a6d4 | ||
|
|
25d5a1541e | ||
|
|
a579d36389 | ||
|
|
d766dac341 | ||
|
|
b15ef1bbc6 | ||
|
|
3e52e00597 | ||
|
|
f749dd0d52 | ||
|
|
48a8a42108 | ||
|
|
db7f57a5a4 | ||
|
|
556381b983 | ||
|
|
158d7d5898 | ||
|
|
18844da95d | ||
|
|
7e0df4d718 | ||
|
|
0dbb76e8c8 | ||
|
|
f73b3422a6 | ||
|
|
bd95e802ec | ||
|
|
5de16a78c5 | ||
|
|
6f8e09fcde | ||
|
|
f54d480f03 | ||
|
|
e68b213fb3 | ||
|
|
132334d500 | ||
|
|
a6f04c6d7e | ||
|
|
854e8bf356 | ||
|
|
6ff883d2d3 | ||
|
|
849b97afba | ||
|
|
1bd2635864 | ||
|
|
79ab0f7b6c | ||
|
|
79011bd257 | ||
|
|
c692713ffb | ||
|
|
df9b554ce1 | ||
|
|
277a8e4682 | ||
|
|
acb52dba09 | ||
|
|
8f10765254 | ||
|
|
0653f59473 | ||
|
|
7a4b5a4667 | ||
|
|
49c4a4068b | ||
|
|
40ad590046 | ||
|
|
30374ae3e6 | ||
|
|
ab22d16bad | ||
|
|
971cd56a4a | ||
|
|
d7cb546c5f | ||
|
|
9d8b7344cd | ||
|
|
2d4f6ae7ce | ||
|
|
d9126807b0 | ||
|
|
cad5fb3fba | ||
|
|
afe23ad6b7 | ||
|
|
fc4327087b | ||
|
|
71762d788f | ||
|
|
6472e00fb0 | ||
|
|
4043846767 | ||
|
|
d3b2bc962c | ||
|
|
54f7b64821 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,5 +1,6 @@
|
||||
__pycache__/
|
||||
settings.json
|
||||
path_mappings.yaml
|
||||
output/*
|
||||
py/run_test.py
|
||||
.vscode/
|
||||
|
||||
25
README.md
25
README.md
@@ -10,18 +10,35 @@ A comprehensive toolset that streamlines organizing, downloading, and applying L
|
||||
|
||||

|
||||
|
||||
One-click Integration:
|
||||

|
||||
|
||||
## 📺 Tutorial: One-Click LoRA Integration
|
||||
Watch this quick tutorial to learn how to use the new one-click LoRA integration feature:
|
||||
|
||||
[](https://youtu.be/hvKw31YpE-U)
|
||||
[](https://youtu.be/hvKw31YpE-U)
|
||||
|
||||
## 🌐 Browser Extension
|
||||
Enhance your Civitai browsing experience with our companion browser extension! See which models you already have, download new ones with a single click, and manage your downloads efficiently.
|
||||
|
||||

|
||||
|
||||
<div>
|
||||
<a href="https://chromewebstore.google.com/detail/lm-civitai-extension/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb" style="display: inline-block; background-color: #4285F4; color: white; padding: 8px 16px; text-decoration: none; border-radius: 4px; font-weight: bold; margin: 10px 0;">
|
||||
<img src="https://www.google.com/chrome/static/images/chrome-logo.svg" width="20" style="vertical-align: middle; margin-right: 8px;"> Get Extension from Chrome Web Store
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div id="firefox-install" class="install-ok"><a href="https://github.com/willmiao/lm-civitai-extension-firefox/releases/latest/download/extension.xpi">📦 Install Firefox Extension (reviewed and verified by Mozilla)</a></div>
|
||||
|
||||
📚 [Learn More: Complete Tutorial](https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/LoRA-Manager-Civitai-Extension-(Chrome-Extension))
|
||||
|
||||
---
|
||||
|
||||
## Release Notes
|
||||
|
||||
### v0.8.20
|
||||
* **LM Civitai Extension** - Released [browser extension through Chrome Web Store](https://chromewebstore.google.com/detail/lm-civitai-extension/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb) that works seamlessly with LoRA Manager to enhance Civitai browsing experience, showing which models are already in your local library, enabling one-click downloads, and providing queue and parallel download support
|
||||
* **Enhanced Lora Loader** - Added support for nunchaku, improving convenience when working with ComfyUI-nunchaku workflows, plus new template workflows for quick onboarding
|
||||
* **WanVideo Integration** - Introduced WanVideo Lora Select (LoraManager) node compatible with ComfyUI-WanVideoWrapper for streamlined lora usage in video workflows, including a template workflow to help you get started quickly
|
||||
|
||||
### v0.8.19
|
||||
* **Analytics Dashboard** - Added new Statistics page providing comprehensive visual analysis of model collection and usage patterns for better library insights
|
||||
* **Target Node Selection** - Enhanced workflow integration with intelligent target choosing when sending LoRAs/recipes to workflows with multiple loader/stacker nodes; a visual selector now appears showing node color, type, ID, and title for precise targeting
|
||||
|
||||
@@ -4,6 +4,7 @@ from .py.nodes.trigger_word_toggle import TriggerWordToggle
|
||||
from .py.nodes.lora_stacker import LoraStacker
|
||||
from .py.nodes.save_image import SaveImage
|
||||
from .py.nodes.debug_metadata import DebugMetadata
|
||||
from .py.nodes.wanvideo_lora_select import WanVideoLoraSelect
|
||||
# Import metadata collector to install hooks on startup
|
||||
from .py.metadata_collector import init as init_metadata_collector
|
||||
|
||||
@@ -12,7 +13,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
TriggerWordToggle.NAME: TriggerWordToggle,
|
||||
LoraStacker.NAME: LoraStacker,
|
||||
SaveImage.NAME: SaveImage,
|
||||
DebugMetadata.NAME: DebugMetadata
|
||||
DebugMetadata.NAME: DebugMetadata,
|
||||
WanVideoLoraSelect.NAME: WanVideoLoraSelect
|
||||
}
|
||||
|
||||
WEB_DIRECTORY = "./web/comfyui"
|
||||
|
||||
BIN
example_workflows/nunchaku-flux.1-dev.jpg
Normal file
BIN
example_workflows/nunchaku-flux.1-dev.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 68 KiB |
1
example_workflows/nunchaku-flux.1-dev.json
Normal file
1
example_workflows/nunchaku-flux.1-dev.json
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
179
py/config.py
179
py/config.py
@@ -17,13 +17,17 @@ class Config:
|
||||
def __init__(self):
|
||||
self.templates_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'templates')
|
||||
self.static_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static')
|
||||
# 路径映射字典, target to link mapping
|
||||
# Path mapping dictionary, target to link mapping
|
||||
self._path_mappings = {}
|
||||
# 静态路由映射字典, target to route mapping
|
||||
# Static route mapping dictionary, target to route mapping
|
||||
self._route_mappings = {}
|
||||
self.loras_roots = self._init_lora_paths()
|
||||
self.checkpoints_roots = self._init_checkpoint_paths()
|
||||
# 在初始化时扫描符号链接
|
||||
self.checkpoints_roots = None
|
||||
self.unet_roots = None
|
||||
self.embeddings_roots = None
|
||||
self.base_models_roots = self._init_checkpoint_paths()
|
||||
self.embeddings_roots = self._init_embedding_paths()
|
||||
# Scan symbolic links during initialization
|
||||
self._scan_symbolic_links()
|
||||
|
||||
if not standalone_mode:
|
||||
@@ -33,34 +37,34 @@ class Config:
|
||||
def save_folder_paths_to_settings(self):
|
||||
"""Save folder paths to settings.json for standalone mode to use later"""
|
||||
try:
|
||||
# Check if we're running in ComfyUI mode (not standalone)
|
||||
if hasattr(folder_paths, "get_folder_paths") and not isinstance(folder_paths, type):
|
||||
# Get all relevant paths
|
||||
lora_paths = folder_paths.get_folder_paths("loras")
|
||||
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
diffuser_paths = folder_paths.get_folder_paths("diffusers")
|
||||
unet_paths = folder_paths.get_folder_paths("unet")
|
||||
|
||||
# Load existing settings
|
||||
settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
|
||||
settings = {}
|
||||
if os.path.exists(settings_path):
|
||||
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||
settings = json.load(f)
|
||||
|
||||
# Update settings with paths
|
||||
settings['folder_paths'] = {
|
||||
'loras': lora_paths,
|
||||
'checkpoints': checkpoint_paths,
|
||||
'diffusers': diffuser_paths,
|
||||
'unet': unet_paths
|
||||
}
|
||||
|
||||
# Save settings
|
||||
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(settings, f, indent=2)
|
||||
|
||||
logger.info("Saved folder paths to settings.json")
|
||||
# Check if we're running in ComfyUI mode (not standalone)
|
||||
# Load existing settings
|
||||
settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
|
||||
settings = {}
|
||||
if os.path.exists(settings_path):
|
||||
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||
settings = json.load(f)
|
||||
|
||||
# Update settings with paths
|
||||
settings['folder_paths'] = {
|
||||
'loras': self.loras_roots,
|
||||
'checkpoints': self.checkpoints_roots,
|
||||
'unet': self.unet_roots,
|
||||
'embeddings': self.embeddings_roots,
|
||||
}
|
||||
|
||||
# Add default roots if there's only one item and key doesn't exist
|
||||
if len(self.loras_roots) == 1 and "default_lora_root" not in settings:
|
||||
settings["default_lora_root"] = self.loras_roots[0]
|
||||
|
||||
if self.checkpoints_roots and len(self.checkpoints_roots) == 1 and "default_checkpoint_root" not in settings:
|
||||
settings["default_checkpoint_root"] = self.checkpoints_roots[0]
|
||||
|
||||
# Save settings
|
||||
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(settings, f, indent=2)
|
||||
|
||||
logger.info("Saved folder paths to settings.json")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save folder paths: {e}")
|
||||
|
||||
@@ -82,15 +86,18 @@ class Config:
|
||||
return False
|
||||
|
||||
def _scan_symbolic_links(self):
|
||||
"""扫描所有 LoRA 和 Checkpoint 根目录中的符号链接"""
|
||||
"""Scan all symbolic links in LoRA, Checkpoint, and Embedding root directories"""
|
||||
for root in self.loras_roots:
|
||||
self._scan_directory_links(root)
|
||||
|
||||
for root in self.checkpoints_roots:
|
||||
for root in self.base_models_roots:
|
||||
self._scan_directory_links(root)
|
||||
|
||||
for root in self.embeddings_roots:
|
||||
self._scan_directory_links(root)
|
||||
|
||||
def _scan_directory_links(self, root: str):
|
||||
"""递归扫描目录中的符号链接"""
|
||||
"""Recursively scan symbolic links in a directory"""
|
||||
try:
|
||||
with os.scandir(root) as it:
|
||||
for entry in it:
|
||||
@@ -105,40 +112,40 @@ class Config:
|
||||
logger.error(f"Error scanning links in {root}: {e}")
|
||||
|
||||
def add_path_mapping(self, link_path: str, target_path: str):
|
||||
"""添加符号链接路径映射
|
||||
target_path: 实际目标路径
|
||||
link_path: 符号链接路径
|
||||
"""Add a symbolic link path mapping
|
||||
target_path: actual target path
|
||||
link_path: symbolic link path
|
||||
"""
|
||||
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
|
||||
normalized_target = os.path.normpath(target_path).replace(os.sep, '/')
|
||||
# 保持原有的映射关系:目标路径 -> 链接路径
|
||||
# Keep the original mapping: target path -> link path
|
||||
self._path_mappings[normalized_target] = normalized_link
|
||||
logger.info(f"Added path mapping: {normalized_target} -> {normalized_link}")
|
||||
|
||||
def add_route_mapping(self, path: str, route: str):
|
||||
"""添加静态路由映射"""
|
||||
"""Add a static route mapping"""
|
||||
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
||||
self._route_mappings[normalized_path] = route
|
||||
# logger.info(f"Added route mapping: {normalized_path} -> {route}")
|
||||
|
||||
def map_path_to_link(self, path: str) -> str:
|
||||
"""将目标路径映射回符号链接路径"""
|
||||
"""Map a target path back to its symbolic link path"""
|
||||
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
||||
# 检查路径是否包含在任何映射的目标路径中
|
||||
# Check if the path is contained in any mapped target path
|
||||
for target_path, link_path in self._path_mappings.items():
|
||||
if normalized_path.startswith(target_path):
|
||||
# 如果路径以目标路径开头,则替换为链接路径
|
||||
# If the path starts with the target path, replace with link path
|
||||
mapped_path = normalized_path.replace(target_path, link_path, 1)
|
||||
return mapped_path
|
||||
return path
|
||||
|
||||
def map_link_to_path(self, link_path: str) -> str:
|
||||
"""将符号链接路径映射回实际路径"""
|
||||
"""Map a symbolic link path back to the actual path"""
|
||||
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
|
||||
# 检查路径是否包含在任何映射的目标路径中
|
||||
# Check if the path is contained in any mapped target path
|
||||
for target_path, link_path in self._path_mappings.items():
|
||||
if normalized_link.startswith(target_path):
|
||||
# 如果路径以目标路径开头,则替换为实际路径
|
||||
# If the path starts with the target path, replace with actual path
|
||||
mapped_path = normalized_link.replace(target_path, link_path, 1)
|
||||
return mapped_path
|
||||
return link_path
|
||||
@@ -177,35 +184,81 @@ class Config:
|
||||
"""Initialize and validate checkpoint paths from ComfyUI settings"""
|
||||
try:
|
||||
# Get checkpoint paths from folder_paths
|
||||
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
diffusion_paths = folder_paths.get_folder_paths("diffusers")
|
||||
unet_paths = folder_paths.get_folder_paths("unet")
|
||||
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
raw_unet_paths = folder_paths.get_folder_paths("unet")
|
||||
|
||||
# Combine all checkpoint-related paths
|
||||
all_paths = checkpoint_paths + diffusion_paths + unet_paths
|
||||
# Normalize and resolve symlinks for checkpoints, store mapping from resolved -> original
|
||||
checkpoint_map = {}
|
||||
for path in raw_checkpoint_paths:
|
||||
if os.path.exists(path):
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
checkpoint_map[real_path] = checkpoint_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
|
||||
|
||||
# Filter and normalize paths
|
||||
paths = sorted(set(path.replace(os.sep, "/")
|
||||
for path in all_paths
|
||||
if os.path.exists(path)), key=lambda p: p.lower())
|
||||
# Normalize and resolve symlinks for unet, store mapping from resolved -> original
|
||||
unet_map = {}
|
||||
for path in raw_unet_paths:
|
||||
if os.path.exists(path):
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
unet_map[real_path] = unet_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
|
||||
|
||||
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(paths) if paths else "[]"))
|
||||
# Now sort and use only the deduplicated real paths
|
||||
unique_checkpoint_paths = sorted(checkpoint_map.values(), key=lambda p: p.lower())
|
||||
unique_unet_paths = sorted(unet_map.values(), key=lambda p: p.lower())
|
||||
|
||||
if not paths:
|
||||
# Store individual paths in class properties
|
||||
self.checkpoints_roots = unique_checkpoint_paths
|
||||
self.unet_roots = unique_unet_paths
|
||||
|
||||
# Combine all checkpoint-related paths for return value
|
||||
all_paths = unique_checkpoint_paths + unique_unet_paths
|
||||
|
||||
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(all_paths) if all_paths else "[]"))
|
||||
|
||||
if not all_paths:
|
||||
logger.warning("No valid checkpoint folders found in ComfyUI configuration")
|
||||
return []
|
||||
|
||||
# 初始化路径映射,与 LoRA 路径处理方式相同
|
||||
for path in paths:
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
if real_path != path:
|
||||
self.add_path_mapping(path, real_path)
|
||||
# Initialize path mappings
|
||||
for original_path in all_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return paths
|
||||
return all_paths
|
||||
except Exception as e:
|
||||
logger.warning(f"Error initializing checkpoint paths: {e}")
|
||||
return []
|
||||
|
||||
def _init_embedding_paths(self) -> List[str]:
|
||||
"""Initialize and validate embedding paths from ComfyUI settings"""
|
||||
try:
|
||||
raw_paths = folder_paths.get_folder_paths("embeddings")
|
||||
|
||||
# Normalize and resolve symlinks, store mapping from resolved -> original
|
||||
path_map = {}
|
||||
for path in raw_paths:
|
||||
if os.path.exists(path):
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
path_map[real_path] = path_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
|
||||
|
||||
# Now sort and use only the deduplicated real paths
|
||||
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
|
||||
logger.info("Found embedding roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
|
||||
|
||||
if not unique_paths:
|
||||
logger.warning("No valid embeddings folders found in ComfyUI configuration")
|
||||
return []
|
||||
|
||||
for original_path in unique_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return unique_paths
|
||||
except Exception as e:
|
||||
logger.warning(f"Error initializing embedding paths: {e}")
|
||||
return []
|
||||
|
||||
def get_preview_static_url(self, preview_path: str) -> str:
|
||||
"""Convert local preview path to static URL"""
|
||||
if not preview_path:
|
||||
|
||||
@@ -6,10 +6,8 @@ from pathlib import Path
|
||||
from server import PromptServer # type: ignore
|
||||
|
||||
from .config import config
|
||||
from .routes.lora_routes import LoraRoutes
|
||||
from .routes.api_routes import ApiRoutes
|
||||
from .services.model_service_factory import ModelServiceFactory, register_default_model_types
|
||||
from .routes.recipe_routes import RecipeRoutes
|
||||
from .routes.checkpoints_routes import CheckpointsRoutes
|
||||
from .routes.stats_routes import StatsRoutes
|
||||
from .routes.update_routes import UpdateRoutes
|
||||
from .routes.misc_routes import MiscRoutes
|
||||
@@ -17,6 +15,7 @@ from .routes.example_images_routes import ExampleImagesRoutes
|
||||
from .services.service_registry import ServiceRegistry
|
||||
from .services.settings_manager import settings
|
||||
from .utils.example_images_migration import ExampleImagesMigration
|
||||
from .services.websocket_manager import ws_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,12 +27,28 @@ class LoraManager:
|
||||
|
||||
@classmethod
|
||||
def add_routes(cls):
|
||||
"""Initialize and register all routes"""
|
||||
"""Initialize and register all routes using the new refactored architecture"""
|
||||
app = PromptServer.instance.app
|
||||
|
||||
# Configure aiohttp access logger to be less verbose
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# Add specific suppression for connection reset errors
|
||||
class ConnectionResetFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
# Filter out connection reset errors that are not critical
|
||||
if "ConnectionResetError" in str(record.getMessage()):
|
||||
return False
|
||||
if "_call_connection_lost" in str(record.getMessage()):
|
||||
return False
|
||||
if "WinError 10054" in str(record.getMessage()):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Apply the filter to asyncio logger
|
||||
asyncio_logger = logging.getLogger("asyncio")
|
||||
asyncio_logger.addFilter(ConnectionResetFilter())
|
||||
|
||||
added_targets = set() # Track already added target paths
|
||||
|
||||
# Add static route for example images if the path exists in settings
|
||||
@@ -62,7 +77,7 @@ class LoraManager:
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Add static routes for each checkpoint root
|
||||
for idx, root in enumerate(config.checkpoints_roots, start=1):
|
||||
for idx, root in enumerate(config.base_models_roots, start=1):
|
||||
preview_path = f'/checkpoints_static/root{idx}/preview'
|
||||
|
||||
real_root = root
|
||||
@@ -79,21 +94,45 @@ class LoraManager:
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Add static routes for each embedding root
|
||||
for idx, root in enumerate(config.embeddings_roots, start=1):
|
||||
preview_path = f'/embeddings_static/root{idx}/preview'
|
||||
|
||||
real_root = root
|
||||
if root in config._path_mappings.values():
|
||||
for target, link in config._path_mappings.items():
|
||||
if link == root:
|
||||
real_root = target
|
||||
break
|
||||
# Add static route for original path
|
||||
app.router.add_static(preview_path, real_root)
|
||||
logger.info(f"Added static route {preview_path} -> {real_root}")
|
||||
|
||||
# Record route mapping
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Add static routes for symlink target paths
|
||||
link_idx = {
|
||||
'lora': 1,
|
||||
'checkpoint': 1
|
||||
'checkpoint': 1,
|
||||
'embedding': 1
|
||||
}
|
||||
|
||||
for target_path, link_path in config._path_mappings.items():
|
||||
if target_path not in added_targets:
|
||||
# Determine if this is a checkpoint or lora link based on path
|
||||
is_checkpoint = any(cp_root in link_path for cp_root in config.checkpoints_roots)
|
||||
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.checkpoints_roots)
|
||||
# Determine if this is a checkpoint, lora, or embedding link based on path
|
||||
is_checkpoint = any(cp_root in link_path for cp_root in config.base_models_roots)
|
||||
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.base_models_roots)
|
||||
is_embedding = any(emb_root in link_path for emb_root in config.embeddings_roots)
|
||||
is_embedding = is_embedding or any(emb_root in target_path for emb_root in config.embeddings_roots)
|
||||
|
||||
if is_checkpoint:
|
||||
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
|
||||
link_idx["checkpoint"] += 1
|
||||
elif is_embedding:
|
||||
route_path = f'/embeddings_static/link_{link_idx["embedding"]}/preview'
|
||||
link_idx["embedding"] += 1
|
||||
else:
|
||||
route_path = f'/loras_static/link_{link_idx["lora"]}/preview'
|
||||
link_idx["lora"] += 1
|
||||
@@ -110,35 +149,37 @@ class LoraManager:
|
||||
# Add static route for plugin assets
|
||||
app.router.add_static('/loras_static', config.static_path)
|
||||
|
||||
# Setup feature routes
|
||||
lora_routes = LoraRoutes()
|
||||
checkpoints_routes = CheckpointsRoutes()
|
||||
stats_routes = StatsRoutes()
|
||||
# Register default model types with the factory
|
||||
register_default_model_types()
|
||||
|
||||
# Initialize routes
|
||||
lora_routes.setup_routes(app)
|
||||
checkpoints_routes.setup_routes(app)
|
||||
stats_routes.setup_routes(app) # Add statistics routes
|
||||
ApiRoutes.setup_routes(app)
|
||||
# Setup all model routes using the factory
|
||||
ModelServiceFactory.setup_all_routes(app)
|
||||
|
||||
# Setup non-model-specific routes
|
||||
stats_routes = StatsRoutes()
|
||||
stats_routes.setup_routes(app)
|
||||
RecipeRoutes.setup_routes(app)
|
||||
UpdateRoutes.setup_routes(app)
|
||||
MiscRoutes.setup_routes(app) # Register miscellaneous routes
|
||||
ExampleImagesRoutes.setup_routes(app) # Register example images routes
|
||||
MiscRoutes.setup_routes(app)
|
||||
ExampleImagesRoutes.setup_routes(app)
|
||||
|
||||
# Setup WebSocket routes that are shared across all model types
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
|
||||
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
|
||||
|
||||
# Schedule service initialization
|
||||
app.on_startup.append(lambda app: cls._initialize_services())
|
||||
|
||||
# Add cleanup
|
||||
app.on_shutdown.append(cls._cleanup)
|
||||
app.on_shutdown.append(ApiRoutes.cleanup)
|
||||
|
||||
logger.info(f"LoRA Manager: Set up routes for {len(ModelServiceFactory.get_registered_types())} model types: {', '.join(ModelServiceFactory.get_registered_types())}")
|
||||
|
||||
@classmethod
|
||||
async def _initialize_services(cls):
|
||||
"""Initialize all services using the ServiceRegistry"""
|
||||
try:
|
||||
# Ensure aiohttp access logger is configured with reduced verbosity
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# Initialize CivitaiClient first to ensure it's ready for other services
|
||||
await ServiceRegistry.get_civitai_client()
|
||||
|
||||
@@ -151,19 +192,15 @@ class LoraManager:
|
||||
# Initialize scanners in background
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
# Initialize recipe scanner if needed
|
||||
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
|
||||
# Initialize metadata collector if not in standalone mode
|
||||
if not STANDALONE_MODE:
|
||||
from .metadata_collector import init as init_metadata
|
||||
init_metadata()
|
||||
logger.debug("Metadata collector initialized")
|
||||
|
||||
# Create low-priority initialization tasks
|
||||
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
|
||||
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
|
||||
asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init')
|
||||
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
|
||||
|
||||
await ExampleImagesMigration.check_and_run_migrations()
|
||||
|
||||
@@ -26,98 +26,191 @@ class MetadataHook:
|
||||
print("Could not locate ComfyUI execution module, metadata collection disabled")
|
||||
return
|
||||
|
||||
# Store the original _map_node_over_list function
|
||||
original_map_node_over_list = execution._map_node_over_list
|
||||
# Detect whether we're using the new async version of ComfyUI
|
||||
is_async = False
|
||||
map_node_func_name = '_map_node_over_list'
|
||||
|
||||
# Define the wrapped _map_node_over_list function
|
||||
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
# Only collect metadata when calling the main function of nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
prompt_id = registry.current_prompt_id
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Unique ID might be available through the obj if it has a unique_id field
|
||||
node_id = getattr(obj, 'unique_id', None)
|
||||
if node_id is None and pre_execute_cb:
|
||||
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if 'unique_id' in frame.f_locals:
|
||||
node_id = frame.f_locals['unique_id']
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
# Record inputs before execution
|
||||
if node_id is not None:
|
||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
|
||||
# Execute the original function
|
||||
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
|
||||
|
||||
# After execution, collect outputs for relevant nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
prompt_id = registry.current_prompt_id
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Unique ID might be available through the obj if it has a unique_id field
|
||||
node_id = getattr(obj, 'unique_id', None)
|
||||
if node_id is None and pre_execute_cb:
|
||||
# Try to extract node_id through reflection
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if 'unique_id' in frame.f_locals:
|
||||
node_id = frame.f_locals['unique_id']
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
# Record outputs after execution
|
||||
if node_id is not None:
|
||||
registry.update_node_execution(node_id, class_type, results)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
# Also hook the execute function to track the current prompt_id
|
||||
original_execute = execution.execute
|
||||
if hasattr(execution, '_async_map_node_over_list'):
|
||||
is_async = inspect.iscoroutinefunction(execution._async_map_node_over_list)
|
||||
map_node_func_name = '_async_map_node_over_list'
|
||||
elif hasattr(execution, '_map_node_over_list'):
|
||||
is_async = inspect.iscoroutinefunction(execution._map_node_over_list)
|
||||
|
||||
def execute_with_prompt_tracking(*args, **kwargs):
|
||||
if len(args) >= 7: # Check if we have enough arguments
|
||||
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||
registry = MetadataRegistry()
|
||||
|
||||
# Start collection if this is a new prompt
|
||||
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||
registry.start_collection(prompt_id)
|
||||
|
||||
# Store the dynprompt reference for node lookups
|
||||
if hasattr(prompt, 'original_prompt'):
|
||||
registry.set_current_prompt(prompt)
|
||||
|
||||
# Execute the original function
|
||||
return original_execute(*args, **kwargs)
|
||||
|
||||
# Replace the functions
|
||||
execution._map_node_over_list = map_node_over_list_with_metadata
|
||||
execution.execute = execute_with_prompt_tracking
|
||||
# Make map_node_over_list public to avoid it being hidden by hooks
|
||||
execution.map_node_over_list = original_map_node_over_list
|
||||
if is_async:
|
||||
print("Detected async ComfyUI execution, installing async metadata hooks")
|
||||
MetadataHook._install_async_hooks(execution, map_node_func_name)
|
||||
else:
|
||||
print("Detected sync ComfyUI execution, installing sync metadata hooks")
|
||||
MetadataHook._install_sync_hooks(execution)
|
||||
|
||||
print("Metadata collection hooks installed for runtime values")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error installing metadata hooks: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _install_sync_hooks(execution):
|
||||
"""Install hooks for synchronous execution model"""
|
||||
# Store the original _map_node_over_list function
|
||||
original_map_node_over_list = execution._map_node_over_list
|
||||
|
||||
# Define the wrapped _map_node_over_list function
|
||||
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
# Only collect metadata when calling the main function of nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
prompt_id = registry.current_prompt_id
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Unique ID might be available through the obj if it has a unique_id field
|
||||
node_id = getattr(obj, 'unique_id', None)
|
||||
if node_id is None and pre_execute_cb:
|
||||
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if 'unique_id' in frame.f_locals:
|
||||
node_id = frame.f_locals['unique_id']
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
# Record inputs before execution
|
||||
if node_id is not None:
|
||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
|
||||
# Execute the original function
|
||||
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
|
||||
|
||||
# After execution, collect outputs for relevant nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
prompt_id = registry.current_prompt_id
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Unique ID might be available through the obj if it has a unique_id field
|
||||
node_id = getattr(obj, 'unique_id', None)
|
||||
if node_id is None and pre_execute_cb:
|
||||
# Try to extract node_id through reflection
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if 'unique_id' in frame.f_locals:
|
||||
node_id = frame.f_locals['unique_id']
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
# Record outputs after execution
|
||||
if node_id is not None:
|
||||
registry.update_node_execution(node_id, class_type, results)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
# Also hook the execute function to track the current prompt_id
|
||||
original_execute = execution.execute
|
||||
|
||||
def execute_with_prompt_tracking(*args, **kwargs):
|
||||
if len(args) >= 7: # Check if we have enough arguments
|
||||
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||
registry = MetadataRegistry()
|
||||
|
||||
# Start collection if this is a new prompt
|
||||
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||
registry.start_collection(prompt_id)
|
||||
|
||||
# Store the dynprompt reference for node lookups
|
||||
if hasattr(prompt, 'original_prompt'):
|
||||
registry.set_current_prompt(prompt)
|
||||
|
||||
# Execute the original function
|
||||
return original_execute(*args, **kwargs)
|
||||
|
||||
# Replace the functions
|
||||
execution._map_node_over_list = map_node_over_list_with_metadata
|
||||
execution.execute = execute_with_prompt_tracking
|
||||
|
||||
@staticmethod
|
||||
def _install_async_hooks(execution, map_node_func_name='_async_map_node_over_list'):
|
||||
"""Install hooks for asynchronous execution model"""
|
||||
# Store the original _async_map_node_over_list function
|
||||
original_map_node_over_list = getattr(execution, map_node_func_name)
|
||||
|
||||
# Define the wrapped async function - NOTE: Updated signature with prompt_id and unique_id!
|
||||
async def async_map_node_over_list_with_metadata(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
# Only collect metadata when calling the main function of nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
# We now have prompt_id directly from the function parameters
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Use the passed unique_id parameter instead of trying to extract it
|
||||
node_id = unique_id
|
||||
|
||||
# Record inputs before execution
|
||||
if node_id is not None:
|
||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
|
||||
# Execute the original async function with ALL parameters in the correct order
|
||||
results = await original_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
|
||||
|
||||
# After execution, collect outputs for relevant nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Use the passed unique_id parameter
|
||||
node_id = unique_id
|
||||
|
||||
# Record outputs after execution
|
||||
if node_id is not None:
|
||||
registry.update_node_execution(node_id, class_type, results)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
# Also hook the execute function to track the current prompt_id
|
||||
original_execute = execution.execute
|
||||
|
||||
async def async_execute_with_prompt_tracking(*args, **kwargs):
|
||||
if len(args) >= 7: # Check if we have enough arguments
|
||||
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||
registry = MetadataRegistry()
|
||||
|
||||
# Start collection if this is a new prompt
|
||||
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||
registry.start_collection(prompt_id)
|
||||
|
||||
# Store the dynprompt reference for node lookups
|
||||
if hasattr(prompt, 'original_prompt'):
|
||||
registry.set_current_prompt(prompt)
|
||||
|
||||
# Execute the original function
|
||||
return await original_execute(*args, **kwargs)
|
||||
|
||||
# Replace the functions with async versions
|
||||
setattr(execution, map_node_func_name, async_map_node_over_list_with_metadata)
|
||||
execution.execute = async_execute_with_prompt_tracking
|
||||
|
||||
@@ -238,25 +238,45 @@ class MetadataProcessor:
|
||||
|
||||
pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning")
|
||||
neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning")
|
||||
|
||||
# Try to match conditioning objects with those stored by CLIPTextEncodeExtractor
|
||||
for prompt_node_id, prompt_data in metadata[PROMPTS].items():
|
||||
# For nodes with single conditioning output
|
||||
if "conditioning" in prompt_data:
|
||||
if pos_conditioning is not None and id(prompt_data["conditioning"]) == id(pos_conditioning):
|
||||
result["prompt"] = prompt_data.get("text", "")
|
||||
|
||||
# Helper function to recursively find prompt text for a conditioning object
|
||||
def find_prompt_text_for_conditioning(conditioning_obj, is_positive=True):
|
||||
if conditioning_obj is None:
|
||||
return ""
|
||||
|
||||
if neg_conditioning is not None and id(prompt_data["conditioning"]) == id(neg_conditioning):
|
||||
result["negative_prompt"] = prompt_data.get("text", "")
|
||||
# Try to match conditioning objects with those stored by extractors
|
||||
for prompt_node_id, prompt_data in metadata[PROMPTS].items():
|
||||
# For nodes with single conditioning output
|
||||
if "conditioning" in prompt_data:
|
||||
if id(prompt_data["conditioning"]) == id(conditioning_obj):
|
||||
return prompt_data.get("text", "")
|
||||
|
||||
# For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader)
|
||||
if is_positive and "positive_encoded" in prompt_data:
|
||||
if id(prompt_data["positive_encoded"]) == id(conditioning_obj):
|
||||
if "positive_text" in prompt_data:
|
||||
return prompt_data["positive_text"]
|
||||
else:
|
||||
orig_conditioning = prompt_data.get("orig_pos_cond", None)
|
||||
if orig_conditioning is not None:
|
||||
# Recursively find the prompt text for the original conditioning
|
||||
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=True)
|
||||
|
||||
if not is_positive and "negative_encoded" in prompt_data:
|
||||
if id(prompt_data["negative_encoded"]) == id(conditioning_obj):
|
||||
if "negative_text" in prompt_data:
|
||||
return prompt_data["negative_text"]
|
||||
else:
|
||||
orig_conditioning = prompt_data.get("orig_neg_cond", None)
|
||||
if orig_conditioning is not None:
|
||||
# Recursively find the prompt text for the original conditioning
|
||||
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=False)
|
||||
|
||||
# For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader)
|
||||
if "positive_encoded" in prompt_data:
|
||||
if pos_conditioning is not None and id(prompt_data["positive_encoded"]) == id(pos_conditioning):
|
||||
result["prompt"] = prompt_data.get("positive_text", "")
|
||||
|
||||
if "negative_encoded" in prompt_data:
|
||||
if neg_conditioning is not None and id(prompt_data["negative_encoded"]) == id(neg_conditioning):
|
||||
result["negative_prompt"] = prompt_data.get("negative_text", "")
|
||||
return ""
|
||||
|
||||
# Find prompt texts using the helper function
|
||||
result["prompt"] = find_prompt_text_for_conditioning(pos_conditioning, is_positive=True)
|
||||
result["negative_prompt"] = find_prompt_text_for_conditioning(neg_conditioning, is_positive=False)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -117,15 +117,15 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor):
|
||||
if isinstance(outputs[0], tuple) and len(outputs[0]) > 0:
|
||||
conditioning = outputs[0][0]
|
||||
metadata[PROMPTS][node_id]["conditioning"] = conditioning
|
||||
|
||||
class SamplerExtractor(NodeMetadataExtractor):
|
||||
|
||||
# Base Sampler Extractor to reduce code redundancy
|
||||
class BaseSamplerExtractor(NodeMetadataExtractor):
|
||||
"""Base extractor for sampler nodes with common functionality"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
def extract_sampling_params(node_id, inputs, metadata, param_keys):
|
||||
"""Extract sampling parameters from inputs"""
|
||||
sampling_params = {}
|
||||
for key in ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]:
|
||||
for key in param_keys:
|
||||
if key in inputs:
|
||||
sampling_params[key] = inputs[key]
|
||||
|
||||
@@ -134,7 +134,10 @@ class SamplerExtractor(NodeMetadataExtractor):
|
||||
"node_id": node_id,
|
||||
IS_SAMPLER: True # Add sampler flag
|
||||
}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def extract_conditioning(node_id, inputs, metadata):
|
||||
"""Extract conditioning objects from inputs"""
|
||||
# Store the conditioning objects directly in metadata for later matching
|
||||
pos_conditioning = inputs.get("positive", None)
|
||||
neg_conditioning = inputs.get("negative", None)
|
||||
@@ -146,7 +149,10 @@ class SamplerExtractor(NodeMetadataExtractor):
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
|
||||
|
||||
@staticmethod
|
||||
def extract_latent_dimensions(node_id, inputs, metadata):
|
||||
"""Extract dimensions from latent image"""
|
||||
# Extract latent image dimensions if available
|
||||
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
||||
latent = inputs["latent_image"]
|
||||
@@ -167,59 +173,106 @@ class SamplerExtractor(NodeMetadataExtractor):
|
||||
"height": height,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class KSamplerAdvancedExtractor(NodeMetadataExtractor):
|
||||
|
||||
class SamplerExtractor(BaseSamplerExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
sampling_params = {}
|
||||
for key in ["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]:
|
||||
if key in inputs:
|
||||
sampling_params[key] = inputs[key]
|
||||
|
||||
metadata[SAMPLING][node_id] = {
|
||||
"parameters": sampling_params,
|
||||
"node_id": node_id,
|
||||
IS_SAMPLER: True # Add sampler flag
|
||||
}
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]
|
||||
)
|
||||
|
||||
# Store the conditioning objects directly in metadata for later matching
|
||||
pos_conditioning = inputs.get("positive", None)
|
||||
neg_conditioning = inputs.get("negative", None)
|
||||
# Extract conditioning objects
|
||||
BaseSamplerExtractor.extract_conditioning(node_id, inputs, metadata)
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
# Save conditioning objects in metadata for later matching
|
||||
if pos_conditioning is not None or neg_conditioning is not None:
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
class KSamplerAdvancedExtractor(BaseSamplerExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]
|
||||
)
|
||||
|
||||
# Extract latent image dimensions if available
|
||||
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
||||
latent = inputs["latent_image"]
|
||||
if isinstance(latent, dict) and "samples" in latent:
|
||||
# Extract dimensions from latent tensor
|
||||
samples = latent["samples"]
|
||||
if hasattr(samples, "shape") and len(samples.shape) >= 3:
|
||||
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
|
||||
# Multiply by 8 to get actual pixel dimensions
|
||||
height = int(samples.shape[2] * 8)
|
||||
width = int(samples.shape[3] * 8)
|
||||
|
||||
if SIZE not in metadata:
|
||||
metadata[SIZE] = {}
|
||||
|
||||
metadata[SIZE][node_id] = {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"node_id": node_id
|
||||
}
|
||||
# Extract conditioning objects
|
||||
BaseSamplerExtractor.extract_conditioning(node_id, inputs, metadata)
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class KSamplerBasicPipeExtractor(BaseSamplerExtractor):
|
||||
"""Extractor for KSamplerBasicPipe and KSampler_inspire_pipe nodes"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]
|
||||
)
|
||||
|
||||
# Extract conditioning objects from basic_pipe
|
||||
if "basic_pipe" in inputs and inputs["basic_pipe"] is not None:
|
||||
basic_pipe = inputs["basic_pipe"]
|
||||
# Typically, basic_pipe structure is (model, clip, vae, positive, negative)
|
||||
if isinstance(basic_pipe, tuple) and len(basic_pipe) >= 5:
|
||||
pos_conditioning = basic_pipe[3] # positive is at index 3
|
||||
neg_conditioning = basic_pipe[4] # negative is at index 4
|
||||
|
||||
# Save conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class KSamplerAdvancedBasicPipeExtractor(BaseSamplerExtractor):
|
||||
"""Extractor for KSamplerAdvancedBasicPipe nodes"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]
|
||||
)
|
||||
|
||||
# Extract conditioning objects from basic_pipe
|
||||
if "basic_pipe" in inputs and inputs["basic_pipe"] is not None:
|
||||
basic_pipe = inputs["basic_pipe"]
|
||||
# Typically, basic_pipe structure is (model, clip, vae, positive, negative)
|
||||
if isinstance(basic_pipe, tuple) and len(basic_pipe) >= 5:
|
||||
pos_conditioning = basic_pipe[3] # positive is at index 3
|
||||
neg_conditioning = basic_pipe[4] # negative is at index 4
|
||||
|
||||
# Save conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class TSCSamplerBaseExtractor(NodeMetadataExtractor):
|
||||
"""Base extractor for handling TSC sampler node outputs"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
# Store vae_decode setting for later use in update
|
||||
@@ -273,7 +326,6 @@ class TSCSamplerBaseExtractor(NodeMetadataExtractor):
|
||||
metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id]
|
||||
|
||||
class TSCKSamplerExtractor(SamplerExtractor, TSCSamplerBaseExtractor):
|
||||
"""Extractor for TSC_KSampler nodes"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
# Call parent extract methods
|
||||
@@ -284,11 +336,10 @@ class TSCKSamplerExtractor(SamplerExtractor, TSCSamplerBaseExtractor):
|
||||
|
||||
|
||||
class TSCKSamplerAdvancedExtractor(KSamplerAdvancedExtractor, TSCSamplerBaseExtractor):
|
||||
"""Extractor for TSC_KSamplerAdvanced nodes"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
# Call parent extract methods
|
||||
SamplerExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
KSamplerAdvancedExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
TSCSamplerBaseExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
|
||||
# Update method is inherited from TSCSamplerBaseExtractor
|
||||
@@ -461,7 +512,7 @@ class BasicSchedulerExtractor(NodeMetadataExtractor):
|
||||
IS_SAMPLER: False # Mark as non-primary sampler
|
||||
}
|
||||
|
||||
class SamplerCustomAdvancedExtractor(NodeMetadataExtractor):
|
||||
class SamplerCustomAdvancedExtractor(BaseSamplerExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
@@ -480,26 +531,8 @@ class SamplerCustomAdvancedExtractor(NodeMetadataExtractor):
|
||||
IS_SAMPLER: True # Add sampler flag
|
||||
}
|
||||
|
||||
# Extract latent image dimensions if available
|
||||
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
||||
latent = inputs["latent_image"]
|
||||
if isinstance(latent, dict) and "samples" in latent:
|
||||
# Extract dimensions from latent tensor
|
||||
samples = latent["samples"]
|
||||
if hasattr(samples, "shape") and len(samples.shape) >= 3:
|
||||
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
|
||||
# Multiply by 8 to get actual pixel dimensions
|
||||
height = int(samples.shape[2] * 8)
|
||||
width = int(samples.shape[3] * 8)
|
||||
|
||||
if SIZE not in metadata:
|
||||
metadata[SIZE] = {}
|
||||
|
||||
metadata[SIZE][node_id] = {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"node_id": node_id
|
||||
}
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
import json
|
||||
|
||||
@@ -569,6 +602,40 @@ class CFGGuiderExtractor(NodeMetadataExtractor):
|
||||
|
||||
metadata[SAMPLING][node_id]["parameters"]["cfg"] = cfg_value
|
||||
|
||||
class CR_ApplyControlNetStackExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Save the original conditioning inputs
|
||||
base_positive = inputs.get("base_positive")
|
||||
base_negative = inputs.get("base_negative")
|
||||
|
||||
if base_positive is not None or base_negative is not None:
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["orig_pos_cond"] = base_positive
|
||||
metadata[PROMPTS][node_id]["orig_neg_cond"] = base_negative
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
# Extract transformed conditionings from outputs
|
||||
# outputs structure: [(base_positive, base_negative, show_help, )]
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, tuple) and len(first_output) >= 2:
|
||||
transformed_positive = first_output[0]
|
||||
transformed_negative = first_output[1]
|
||||
|
||||
# Save transformed conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["positive_encoded"] = transformed_positive
|
||||
metadata[PROMPTS][node_id]["negative_encoded"] = transformed_negative
|
||||
|
||||
# Registry of node-specific extractors
|
||||
# Keys are node class names
|
||||
NODE_EXTRACTORS = {
|
||||
@@ -578,6 +645,10 @@ NODE_EXTRACTORS = {
|
||||
"SamplerCustomAdvanced": SamplerCustomAdvancedExtractor,
|
||||
"TSC_KSampler": TSCKSamplerExtractor, # Efficient Nodes
|
||||
"TSC_KSamplerAdvanced": TSCKSamplerAdvancedExtractor, # Efficient Nodes
|
||||
"KSamplerBasicPipe": KSamplerBasicPipeExtractor, # comfyui-impact-pack
|
||||
"KSamplerAdvancedBasicPipe": KSamplerAdvancedBasicPipeExtractor, # comfyui-impact-pack
|
||||
"KSampler_inspire_pipe": KSamplerBasicPipeExtractor, # comfyui-inspire-pack
|
||||
"KSamplerAdvanced_inspire_pipe": KSamplerAdvancedBasicPipeExtractor, # comfyui-inspire-pack
|
||||
# Sampling Selectors
|
||||
"KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect
|
||||
"BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler
|
||||
@@ -594,6 +665,8 @@ NODE_EXTRACTORS = {
|
||||
"CLIPTextEncodeFlux": CLIPTextEncodeFluxExtractor, # Add CLIPTextEncodeFlux
|
||||
"WAS_Text_to_Conditioning": CLIPTextEncodeExtractor,
|
||||
"AdvancedCLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/BlenderNeko/ComfyUI_ADV_CLIP_emb
|
||||
"smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes
|
||||
"CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack
|
||||
# Latent
|
||||
"EmptyLatentImage": ImageSizeExtractor,
|
||||
# Flux
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import logging
|
||||
from nodes import LoraLoader
|
||||
from comfy.comfy_types import IO # type: ignore
|
||||
import asyncio
|
||||
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraManagerLoader:
|
||||
NAME = "Lora Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
@@ -37,19 +37,39 @@ class LoraManagerLoader:
|
||||
|
||||
clip = kwargs.get('clip', None)
|
||||
lora_stack = kwargs.get('lora_stack', None)
|
||||
|
||||
# Check if model is a Nunchaku Flux model - simplified approach
|
||||
is_nunchaku_model = False
|
||||
|
||||
try:
|
||||
model_wrapper = model.model.diffusion_model
|
||||
# Check if model is a Nunchaku Flux model using only class name
|
||||
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
||||
is_nunchaku_model = True
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
except (AttributeError, TypeError):
|
||||
# Not a model with the expected structure
|
||||
pass
|
||||
|
||||
# First process lora_stack if available
|
||||
if lora_stack:
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Apply the LoRA using the provided path and strengths
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# Use our custom function for Flux models
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged for Nunchaku models
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Extract lora name for trigger words lookup
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
_, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
all_trigger_words.extend(trigger_words)
|
||||
# Add clip strength to output if different from model strength
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
@@ -66,13 +86,19 @@ class LoraManagerLoader:
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Apply the LoRA using the resolved path with separate strengths
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# For Nunchaku models, use our custom function
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Include clip strength in output if different from model strength
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
# Include clip strength in output if different from model strength and not a Nunchaku model
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from comfy.comfy_types import IO # type: ignore
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
import asyncio
|
||||
import os
|
||||
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -43,7 +42,7 @@ class LoraStacker:
|
||||
# Get trigger words from existing stack entries
|
||||
for lora_path, _, _ in lora_stack:
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
_, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Process loras from kwargs with support for both old and new formats
|
||||
@@ -58,7 +57,7 @@ class LoraStacker:
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Add to stack without loading
|
||||
# replace '/' with os.sep to avoid different OS path format
|
||||
|
||||
@@ -4,8 +4,7 @@ import asyncio
|
||||
import re
|
||||
import numpy as np
|
||||
import folder_paths # type: ignore
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..services.checkpoint_scanner import CheckpointScanner
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||
from ..metadata_collector import get_metadata
|
||||
from PIL import Image, PngImagePlugin
|
||||
@@ -71,25 +70,20 @@ class SaveImage:
|
||||
FUNCTION = "process_image"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
async def get_lora_hash(self, lora_name):
|
||||
def get_lora_hash(self, lora_name):
|
||||
"""Get the lora hash from cache"""
|
||||
scanner = await LoraScanner.get_instance()
|
||||
scanner = ServiceRegistry.get_service_sync("lora_scanner")
|
||||
|
||||
# Use the new direct filename lookup method
|
||||
hash_value = scanner.get_hash_by_filename(lora_name)
|
||||
if hash_value:
|
||||
return hash_value
|
||||
|
||||
# Fallback to old method for compatibility
|
||||
cache = await scanner.get_cached_data()
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
return item.get('sha256')
|
||||
return None
|
||||
|
||||
async def get_checkpoint_hash(self, checkpoint_path):
|
||||
def get_checkpoint_hash(self, checkpoint_path):
|
||||
"""Get the checkpoint hash from cache"""
|
||||
scanner = await CheckpointScanner.get_instance()
|
||||
scanner = ServiceRegistry.get_service_sync("checkpoint_scanner")
|
||||
|
||||
if not checkpoint_path:
|
||||
return None
|
||||
@@ -102,18 +96,10 @@ class SaveImage:
|
||||
hash_value = scanner.get_hash_by_filename(checkpoint_name)
|
||||
if hash_value:
|
||||
return hash_value
|
||||
|
||||
# Fallback to old method for compatibility
|
||||
cache = await scanner.get_cached_data()
|
||||
normalized_path = checkpoint_path.replace('\\', '/')
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == checkpoint_name and item.get('file_path').endswith(normalized_path):
|
||||
return item.get('sha256')
|
||||
|
||||
return None
|
||||
|
||||
async def format_metadata(self, metadata_dict):
|
||||
def format_metadata(self, metadata_dict):
|
||||
"""Format metadata in the requested format similar to userComment example"""
|
||||
if not metadata_dict:
|
||||
return ""
|
||||
@@ -140,7 +126,7 @@ class SaveImage:
|
||||
|
||||
# Get hash for each lora
|
||||
for lora_name, strength in lora_matches:
|
||||
hash_value = await self.get_lora_hash(lora_name)
|
||||
hash_value = self.get_lora_hash(lora_name)
|
||||
if hash_value:
|
||||
lora_hashes[lora_name] = hash_value
|
||||
else:
|
||||
@@ -226,7 +212,7 @@ class SaveImage:
|
||||
checkpoint = metadata_dict.get('checkpoint')
|
||||
if checkpoint is not None:
|
||||
# Get model hash
|
||||
model_hash = await self.get_checkpoint_hash(checkpoint)
|
||||
model_hash = self.get_checkpoint_hash(checkpoint)
|
||||
|
||||
# Extract basename without path
|
||||
checkpoint_name = os.path.basename(checkpoint)
|
||||
@@ -329,8 +315,7 @@ class SaveImage:
|
||||
raw_metadata = get_metadata()
|
||||
metadata_dict = MetadataProcessor.to_dict(raw_metadata, id)
|
||||
|
||||
# Get or create metadata asynchronously
|
||||
metadata = asyncio.run(self.format_metadata(metadata_dict))
|
||||
metadata = self.format_metadata(metadata_dict)
|
||||
|
||||
# Process filename_prefix with pattern substitution
|
||||
filename_prefix = self.format_filename(filename_prefix, metadata_dict)
|
||||
|
||||
@@ -35,31 +35,11 @@ any_type = AnyType("*")
|
||||
# Common methods extracted from lora_loader.py and lora_stacker.py
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
import copy
|
||||
import folder_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def get_lora_info(lora_name):
|
||||
"""Get the lora path and trigger words from cache"""
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
file_path = item.get('file_path')
|
||||
if file_path:
|
||||
for root in config.loras_roots:
|
||||
root = root.replace(os.sep, '/')
|
||||
if file_path.startswith(root):
|
||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||
# Get trigger words from civitai metadata
|
||||
civitai = item.get('civitai', {})
|
||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||
return relative_path, trigger_words
|
||||
return lora_name, [] # Fallback if not found
|
||||
|
||||
def extract_lora_name(lora_path):
|
||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||
# Get the basename without extension
|
||||
@@ -81,4 +61,70 @@ def get_loras_list(kwargs):
|
||||
# Unexpected format
|
||||
else:
|
||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||
return []
|
||||
return []
|
||||
|
||||
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
|
||||
"""Simplified version of load_state_dict_in_safetensors that just loads from a local path"""
|
||||
import safetensors.torch
|
||||
|
||||
state_dict = {}
|
||||
with safetensors.torch.safe_open(path, framework="pt", device=device) as f:
|
||||
for k in f.keys():
|
||||
if filter_prefix and not k.startswith(filter_prefix):
|
||||
continue
|
||||
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
|
||||
return state_dict
|
||||
|
||||
def to_diffusers(input_lora):
|
||||
"""Simplified version of to_diffusers for Flux LoRA conversion"""
|
||||
import torch
|
||||
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
|
||||
from diffusers.loaders import FluxLoraLoaderMixin
|
||||
|
||||
if isinstance(input_lora, str):
|
||||
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
|
||||
else:
|
||||
tensors = {k: v for k, v in input_lora.items()}
|
||||
|
||||
# Convert FP8 tensors to BF16
|
||||
for k, v in tensors.items():
|
||||
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
|
||||
tensors[k] = v.to(torch.bfloat16)
|
||||
|
||||
new_tensors = FluxLoraLoaderMixin.lora_state_dict(tensors)
|
||||
new_tensors = convert_unet_state_dict_to_peft(new_tensors)
|
||||
|
||||
return new_tensors
|
||||
|
||||
def nunchaku_load_lora(model, lora_name, lora_strength):
|
||||
"""Load a Flux LoRA for Nunchaku model"""
|
||||
model_wrapper = model.model.diffusion_model
|
||||
transformer = model_wrapper.model
|
||||
|
||||
# Save the transformer temporarily
|
||||
model_wrapper.model = None
|
||||
ret_model = copy.deepcopy(model) # copy everything except the model
|
||||
ret_model_wrapper = ret_model.model.diffusion_model
|
||||
|
||||
# Restore the model and set it for the copy
|
||||
model_wrapper.model = transformer
|
||||
ret_model_wrapper.model = transformer
|
||||
|
||||
# Get full path to the LoRA file
|
||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||
ret_model_wrapper.loras.append((lora_path, lora_strength))
|
||||
|
||||
# Convert the LoRA to diffusers format
|
||||
sd = to_diffusers(lora_path)
|
||||
|
||||
# Handle embedding adjustment if needed
|
||||
if "transformer.x_embedder.lora_A.weight" in sd:
|
||||
new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1]
|
||||
assert new_in_channels % 4 == 0
|
||||
new_in_channels = new_in_channels // 4
|
||||
|
||||
old_in_channels = ret_model.model.model_config.unet_config["in_channels"]
|
||||
if old_in_channels < new_in_channels:
|
||||
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
|
||||
|
||||
return ret_model
|
||||
92
py/nodes/wanvideo_lora_select.py
Normal file
92
py/nodes/wanvideo_lora_select.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from comfy.comfy_types import IO # type: ignore
|
||||
import folder_paths # type: ignore
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class WanVideoLoraSelect:
|
||||
NAME = "WanVideo Lora Select (LoraManager)"
|
||||
CATEGORY = "Lora Manager/stackers"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load the LORA model with less VRAM usage, slower loading"}),
|
||||
"text": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"dynamicPrompts": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
"placeholder": "LoRA syntax input: <lora:name:strength>"
|
||||
}),
|
||||
},
|
||||
"optional": FlexibleOptionalInputType(any_type),
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("WANVIDLORA", IO.STRING, IO.STRING)
|
||||
RETURN_NAMES = ("lora", "trigger_words", "active_loras")
|
||||
FUNCTION = "process_loras"
|
||||
|
||||
def process_loras(self, text, low_mem_load=False, **kwargs):
|
||||
loras_list = []
|
||||
all_trigger_words = []
|
||||
active_loras = []
|
||||
|
||||
# Process existing prev_lora if available
|
||||
prev_lora = kwargs.get('prev_lora', None)
|
||||
if prev_lora is not None:
|
||||
loras_list.extend(prev_lora)
|
||||
|
||||
# Get blocks if available
|
||||
blocks = kwargs.get('blocks', {})
|
||||
selected_blocks = blocks.get("selected_blocks", {})
|
||||
layer_filter = blocks.get("layer_filter", "")
|
||||
|
||||
# Process loras from kwargs with support for both old and new formats
|
||||
loras_from_widget = get_loras_list(kwargs)
|
||||
for lora in loras_from_widget:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
|
||||
lora_name = lora['name']
|
||||
model_strength = float(lora['strength'])
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Create lora item for WanVideo format
|
||||
lora_item = {
|
||||
"path": folder_paths.get_full_path("loras", lora_path),
|
||||
"strength": model_strength,
|
||||
"name": lora_path.split(".")[0],
|
||||
"blocks": selected_blocks,
|
||||
"layer_filter": layer_filter,
|
||||
"low_mem_load": low_mem_load,
|
||||
}
|
||||
|
||||
# Add to list and collect active loras
|
||||
loras_list.append(lora_item)
|
||||
active_loras.append((lora_name, model_strength, clip_strength))
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Format trigger_words for output
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Format active_loras for output
|
||||
formatted_loras = []
|
||||
for name, model_strength, clip_strength in active_loras:
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
# Different model and clip strengths
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")
|
||||
|
||||
active_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (loras_list, trigger_words_text, active_loras_text)
|
||||
@@ -19,7 +19,7 @@ class AutomaticMetadataParser(RecipeMetadataParser):
|
||||
LORA_HASHES_REGEX = r', Lora hashes:\s*"([^"]+)"'
|
||||
CIVITAI_RESOURCES_REGEX = r', Civitai resources:\s*(\[\{.*?\}\])'
|
||||
CIVITAI_METADATA_REGEX = r', Civitai metadata:\s*(\{.*?\})'
|
||||
EXTRANETS_REGEX = r'<(lora|hypernet):([a-zA-Z0-9_\.\-]+):([0-9.]+)>'
|
||||
EXTRANETS_REGEX = r'<(lora|hypernet):([^:]+):(-?[0-9.]+)>'
|
||||
MODEL_HASH_PATTERN = r'Model hash: ([a-zA-Z0-9]+)'
|
||||
VAE_HASH_PATTERN = r'VAE hash: ([a-zA-Z0-9]+)'
|
||||
|
||||
|
||||
@@ -50,6 +50,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
'from_civitai_image': True
|
||||
}
|
||||
|
||||
# Track already added LoRAs to prevent duplicates
|
||||
added_loras = {} # key: model_version_id or hash, value: index in result["loras"]
|
||||
|
||||
# Extract prompt and negative prompt
|
||||
if "prompt" in metadata:
|
||||
result["gen_params"]["prompt"] = metadata["prompt"]
|
||||
@@ -96,11 +99,17 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
for resource in metadata["resources"]:
|
||||
# Modified to process resources without a type field as potential LoRAs
|
||||
if resource.get("type", "lora") == "lora":
|
||||
lora_hash = resource.get("hash", "")
|
||||
|
||||
# Skip if we've already added this LoRA by hash
|
||||
if lora_hash and lora_hash in added_loras:
|
||||
continue
|
||||
|
||||
lora_entry = {
|
||||
'name': resource.get("name", "Unknown LoRA"),
|
||||
'type': "lora",
|
||||
'weight': float(resource.get("weight", 1.0)),
|
||||
'hash': resource.get("hash", ""),
|
||||
'hash': lora_hash,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': resource.get("name", "Unknown"),
|
||||
@@ -114,7 +123,6 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
# Try to get info from Civitai if hash is available
|
||||
if lora_entry['hash'] and civitai_client:
|
||||
try:
|
||||
lora_hash = lora_entry['hash']
|
||||
civitai_info = await civitai_client.get_model_by_hash(lora_hash)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
@@ -129,43 +137,124 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
# If we have a version ID from Civitai, track it for deduplication
|
||||
if 'id' in lora_entry and lora_entry['id']:
|
||||
added_loras[str(lora_entry['id'])] = len(result["loras"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
|
||||
|
||||
# Track by hash if we have it
|
||||
if lora_hash:
|
||||
added_loras[lora_hash] = len(result["loras"])
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# Process civitaiResources array
|
||||
if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list):
|
||||
for resource in metadata["civitaiResources"]:
|
||||
# Modified to process resources without a type field as potential LoRAs
|
||||
if resource.get("type") in ["lora", "lycoris"] or "type" not in resource:
|
||||
# Initialize lora entry with the same structure as in automatic.py
|
||||
lora_entry = {
|
||||
'id': resource.get("modelVersionId", 0),
|
||||
'modelId': resource.get("modelId", 0),
|
||||
'name': resource.get("modelName", "Unknown LoRA"),
|
||||
'version': resource.get("modelVersionName", ""),
|
||||
'type': resource.get("type", "lora"),
|
||||
'weight': round(float(resource.get("weight", 1.0)), 2),
|
||||
'existsLocally': False,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
# Skip resources that aren't LoRAs or LyCORIS
|
||||
if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource:
|
||||
continue
|
||||
|
||||
# Try to get info from Civitai if modelVersionId is available
|
||||
if resource.get('modelVersionId') and civitai_client:
|
||||
try:
|
||||
version_id = str(resource.get('modelVersionId'))
|
||||
# Use get_model_version_info instead of get_model_version
|
||||
civitai_info, error = await civitai_client.get_model_version_info(version_id)
|
||||
|
||||
if error:
|
||||
logger.warning(f"Error getting model version info: {error}")
|
||||
continue
|
||||
# Get unique identifier for deduplication
|
||||
version_id = str(resource.get("modelVersionId", ""))
|
||||
|
||||
# Skip if we've already added this LoRA
|
||||
if version_id and version_id in added_loras:
|
||||
continue
|
||||
|
||||
# Initialize lora entry
|
||||
lora_entry = {
|
||||
'id': resource.get("modelVersionId", 0),
|
||||
'modelId': resource.get("modelId", 0),
|
||||
'name': resource.get("modelName", "Unknown LoRA"),
|
||||
'version': resource.get("modelVersionName", ""),
|
||||
'type': resource.get("type", "lora"),
|
||||
'weight': round(float(resource.get("weight", 1.0)), 2),
|
||||
'existsLocally': False,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Try to get info from Civitai if modelVersionId is available
|
||||
if version_id and civitai_client:
|
||||
try:
|
||||
# Use get_model_version_info instead of get_model_version
|
||||
civitai_info, error = await civitai_client.get_model_version_info(version_id)
|
||||
|
||||
if error:
|
||||
logger.warning(f"Error getting model version info: {error}")
|
||||
continue
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model version {version_id}: {e}")
|
||||
|
||||
# Track this LoRA in our deduplication dict
|
||||
if version_id:
|
||||
added_loras[version_id] = len(result["loras"])
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# Process additionalResources array
|
||||
if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list):
|
||||
for resource in metadata["additionalResources"]:
|
||||
# Skip resources that aren't LoRAs or LyCORIS
|
||||
if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource:
|
||||
continue
|
||||
|
||||
lora_type = resource.get("type", "lora")
|
||||
name = resource.get("name", "")
|
||||
|
||||
# Extract ID from URN format if available
|
||||
version_id = None
|
||||
if name and "civitai:" in name:
|
||||
parts = name.split("@")
|
||||
if len(parts) > 1:
|
||||
version_id = parts[1]
|
||||
|
||||
# Skip if we've already added this LoRA
|
||||
if version_id in added_loras:
|
||||
continue
|
||||
|
||||
lora_entry = {
|
||||
'name': name,
|
||||
'type': lora_type,
|
||||
'weight': float(resource.get("strength", 1.0)),
|
||||
'hash': "",
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# If we have a version ID and civitai client, try to get more info
|
||||
if version_id and civitai_client:
|
||||
try:
|
||||
# Use get_model_version_info with the version ID
|
||||
civitai_info, error = await civitai_client.get_model_version_info(version_id)
|
||||
|
||||
if error:
|
||||
logger.warning(f"Error getting model version info: {error}")
|
||||
else:
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
@@ -177,65 +266,14 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model version {resource.get('modelVersionId')}: {e}")
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# Process additionalResources array
|
||||
if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list):
|
||||
for resource in metadata["additionalResources"]:
|
||||
# Modified to process resources without a type field as potential LoRAs
|
||||
if resource.get("type") in ["lora", "lycoris"] or "type" not in resource:
|
||||
lora_type = resource.get("type", "lora")
|
||||
name = resource.get("name", "")
|
||||
|
||||
# Extract ID from URN format if available
|
||||
model_id = None
|
||||
if name and "civitai:" in name:
|
||||
parts = name.split("@")
|
||||
if len(parts) > 1:
|
||||
model_id = parts[1]
|
||||
|
||||
lora_entry = {
|
||||
'name': name,
|
||||
'type': lora_type,
|
||||
'weight': float(resource.get("strength", 1.0)),
|
||||
'hash': "",
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# If we have a model ID and civitai client, try to get more info
|
||||
if model_id and civitai_client:
|
||||
try:
|
||||
# Use get_model_version_info with the model ID
|
||||
civitai_info, error = await civitai_client.get_model_version_info(model_id)
|
||||
|
||||
if error:
|
||||
logger.warning(f"Error getting model version info: {error}")
|
||||
else:
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model ID {model_id}: {e}")
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
# Track this LoRA for deduplication
|
||||
if version_id:
|
||||
added_loras[version_id] = len(result["loras"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}")
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# If base model wasn't found earlier, use the most common one from LoRAs
|
||||
if not result["base_model"] and base_model_counts:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
619
py/routes/base_model_routes.py
Normal file
619
py/routes/base_model_routes.py
Normal file
@@ -0,0 +1,619 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from typing import Dict
|
||||
|
||||
import jinja2
|
||||
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from ..services.websocket_manager import ws_manager
|
||||
from ..services.settings_manager import settings
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseModelRoutes(ABC):
|
||||
"""Base route controller for all model types"""
|
||||
|
||||
def __init__(self, service):
|
||||
"""Initialize the route controller
|
||||
|
||||
Args:
|
||||
service: Model service instance (LoraService, CheckpointService, etc.)
|
||||
"""
|
||||
self.service = service
|
||||
self.model_type = service.model_type
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
|
||||
def setup_routes(self, app: web.Application, prefix: str):
|
||||
"""Setup common routes for the model type
|
||||
|
||||
Args:
|
||||
app: aiohttp application
|
||||
prefix: URL prefix (e.g., 'loras', 'checkpoints')
|
||||
"""
|
||||
# Common model management routes
|
||||
app.router.add_get(f'/api/{prefix}', self.get_models)
|
||||
app.router.add_post(f'/api/{prefix}/delete', self.delete_model)
|
||||
app.router.add_post(f'/api/{prefix}/exclude', self.exclude_model)
|
||||
app.router.add_post(f'/api/{prefix}/fetch-civitai', self.fetch_civitai)
|
||||
app.router.add_post(f'/api/{prefix}/relink-civitai', self.relink_civitai)
|
||||
app.router.add_post(f'/api/{prefix}/replace-preview', self.replace_preview)
|
||||
app.router.add_post(f'/api/{prefix}/save-metadata', self.save_metadata)
|
||||
app.router.add_post(f'/api/{prefix}/rename', self.rename_model)
|
||||
app.router.add_post(f'/api/{prefix}/bulk-delete', self.bulk_delete_models)
|
||||
app.router.add_post(f'/api/{prefix}/verify-duplicates', self.verify_duplicates)
|
||||
|
||||
# Common query routes
|
||||
app.router.add_get(f'/api/{prefix}/top-tags', self.get_top_tags)
|
||||
app.router.add_get(f'/api/{prefix}/base-models', self.get_base_models)
|
||||
app.router.add_get(f'/api/{prefix}/scan', self.scan_models)
|
||||
app.router.add_get(f'/api/{prefix}/roots', self.get_model_roots)
|
||||
app.router.add_get(f'/api/{prefix}/folders', self.get_folders)
|
||||
app.router.add_get(f'/api/{prefix}/find-duplicates', self.find_duplicate_models)
|
||||
app.router.add_get(f'/api/{prefix}/find-filename-conflicts', self.find_filename_conflicts)
|
||||
|
||||
# Common Download management
|
||||
app.router.add_post(f'/api/download-model', self.download_model)
|
||||
app.router.add_get(f'/api/download-model-get', self.download_model_get)
|
||||
app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get)
|
||||
app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress)
|
||||
|
||||
# CivitAI integration routes
|
||||
app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai)
|
||||
# app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions)
|
||||
|
||||
# Add generic page route
|
||||
app.router.add_get(f'/{prefix}', self.handle_models_page)
|
||||
|
||||
# Setup model-specific routes
|
||||
self.setup_specific_routes(app, prefix)
|
||||
|
||||
@abstractmethod
|
||||
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||
"""Setup model-specific routes - to be implemented by subclasses"""
|
||||
pass
|
||||
|
||||
async def handle_models_page(self, request: web.Request) -> web.Response:
|
||||
"""
|
||||
Generic handler for model pages (e.g., /loras, /checkpoints).
|
||||
Subclasses should set self.template_env and template_name.
|
||||
"""
|
||||
try:
|
||||
# Check if the scanner is initializing
|
||||
is_initializing = (
|
||||
self.service.scanner._cache is None or
|
||||
(hasattr(self.service.scanner, 'is_initializing') and callable(self.service.scanner.is_initializing) and self.service.scanner.is_initializing()) or
|
||||
(hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing)
|
||||
)
|
||||
|
||||
template_name = getattr(self, "template_name", None)
|
||||
if not self.template_env or not template_name:
|
||||
return web.Response(text="Template environment or template name not set", status=500)
|
||||
|
||||
if is_initializing:
|
||||
rendered = self.template_env.get_template(template_name).render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
else:
|
||||
try:
|
||||
cache = await self.service.scanner.get_cached_data(force_refresh=False)
|
||||
rendered = self.template_env.get_template(template_name).render(
|
||||
folders=getattr(cache, "folders", []),
|
||||
is_initializing=False,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading cache data: {cache_error}")
|
||||
rendered = self.template_env.get_template(template_name).render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling models page: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading models page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def get_models(self, request: web.Request) -> web.Response:
|
||||
"""Get paginated model data"""
|
||||
try:
|
||||
# Parse common query parameters
|
||||
params = self._parse_common_params(request)
|
||||
|
||||
# Get data from service
|
||||
result = await self.service.get_paginated_data(**params)
|
||||
|
||||
# Format response items
|
||||
formatted_result = {
|
||||
'items': [await self.service.format_response(item) for item in result['items']],
|
||||
'total': result['total'],
|
||||
'page': result['page'],
|
||||
'page_size': result['page_size'],
|
||||
'total_pages': result['total_pages']
|
||||
}
|
||||
|
||||
return web.json_response(formatted_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_{self.model_type}s: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
def _parse_common_params(self, request: web.Request) -> Dict:
|
||||
"""Parse common query parameters"""
|
||||
# Parse basic pagination and sorting
|
||||
page = int(request.query.get('page', '1'))
|
||||
page_size = min(int(request.query.get('page_size', '20')), 100)
|
||||
sort_by = request.query.get('sort_by', 'name')
|
||||
folder = request.query.get('folder', None)
|
||||
search = request.query.get('search', None)
|
||||
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
|
||||
|
||||
# Parse filter arrays
|
||||
base_models = request.query.getall('base_model', [])
|
||||
tags = request.query.getall('tag', [])
|
||||
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true'
|
||||
|
||||
# Parse search options
|
||||
search_options = {
|
||||
'filename': request.query.get('search_filename', 'true').lower() == 'true',
|
||||
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
|
||||
'tags': request.query.get('search_tags', 'false').lower() == 'true',
|
||||
'recursive': request.query.get('recursive', 'false').lower() == 'true',
|
||||
}
|
||||
|
||||
# Parse hash filters if provided
|
||||
hash_filters = {}
|
||||
if 'hash' in request.query:
|
||||
hash_filters['single_hash'] = request.query['hash']
|
||||
elif 'hashes' in request.query:
|
||||
try:
|
||||
hash_list = json.loads(request.query['hashes'])
|
||||
if isinstance(hash_list, list):
|
||||
hash_filters['multiple_hashes'] = hash_list
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
return {
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'sort_by': sort_by,
|
||||
'folder': folder,
|
||||
'search': search,
|
||||
'fuzzy_search': fuzzy_search,
|
||||
'base_models': base_models,
|
||||
'tags': tags,
|
||||
'search_options': search_options,
|
||||
'hash_filters': hash_filters,
|
||||
'favorites_only': favorites_only,
|
||||
# Add model-specific parameters
|
||||
**self._parse_specific_params(request)
|
||||
}
|
||||
|
||||
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||
"""Parse model-specific parameters - to be overridden by subclasses"""
|
||||
return {}
|
||||
|
||||
# Common route handlers
|
||||
async def delete_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model deletion request"""
|
||||
return await ModelRouteUtils.handle_delete_model(request, self.service.scanner)
|
||||
|
||||
async def exclude_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model exclusion request"""
|
||||
return await ModelRouteUtils.handle_exclude_model(request, self.service.scanner)
|
||||
|
||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata fetch request"""
|
||||
response = await ModelRouteUtils.handle_fetch_civitai(request, self.service.scanner)
|
||||
|
||||
# If successful, format the metadata before returning
|
||||
if response.status == 200:
|
||||
data = json.loads(response.body.decode('utf-8'))
|
||||
if data.get("success") and data.get("metadata"):
|
||||
formatted_metadata = await self.service.format_response(data["metadata"])
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"metadata": formatted_metadata
|
||||
})
|
||||
|
||||
return response
|
||||
|
||||
async def relink_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata re-linking request"""
|
||||
return await ModelRouteUtils.handle_relink_civitai(request, self.service.scanner)
|
||||
|
||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||
"""Handle preview image replacement"""
|
||||
return await ModelRouteUtils.handle_replace_preview(request, self.service.scanner)
|
||||
|
||||
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||
"""Handle saving metadata updates"""
|
||||
return await ModelRouteUtils.handle_save_metadata(request, self.service.scanner)
|
||||
|
||||
async def rename_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle renaming a model file and its associated files"""
|
||||
return await ModelRouteUtils.handle_rename_model(request, self.service.scanner)
|
||||
|
||||
async def bulk_delete_models(self, request: web.Request) -> web.Response:
|
||||
"""Handle bulk deletion of models"""
|
||||
return await ModelRouteUtils.handle_bulk_delete_models(request, self.service.scanner)
|
||||
|
||||
async def verify_duplicates(self, request: web.Request) -> web.Response:
|
||||
"""Handle verification of duplicate model hashes"""
|
||||
return await ModelRouteUtils.handle_verify_duplicates(request, self.service.scanner)
|
||||
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
"""Handle request for top tags sorted by frequency"""
|
||||
try:
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20
|
||||
|
||||
top_tags = await self.service.get_top_tags(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'tags': top_tags
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Internal server error'
|
||||
}, status=500)
|
||||
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
"""Get base models used in models"""
|
||||
try:
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20
|
||||
|
||||
base_models = await self.service.get_base_models(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'base_models': base_models
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving base models: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def scan_models(self, request: web.Request) -> web.Response:
|
||||
"""Force a rescan of model files"""
|
||||
try:
|
||||
full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true'
|
||||
|
||||
await self.service.scan_models(force_refresh=True, rebuild_cache=full_rebuild)
|
||||
return web.json_response({
|
||||
"status": "success",
|
||||
"message": f"{self.model_type.capitalize()} scan completed"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error in scan_{self.model_type}s: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_model_roots(self, request: web.Request) -> web.Response:
|
||||
"""Return the model root directories"""
|
||||
try:
|
||||
roots = self.service.get_model_roots()
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting {self.model_type} roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_folders(self, request: web.Request) -> web.Response:
|
||||
"""Get all folders in the cache"""
|
||||
try:
|
||||
cache = await self.service.scanner.get_cached_data()
|
||||
return web.json_response({
|
||||
'folders': cache.folders
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting folders: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def find_duplicate_models(self, request: web.Request) -> web.Response:
|
||||
"""Find models with duplicate SHA256 hashes"""
|
||||
try:
|
||||
# Get duplicate hashes from service
|
||||
duplicates = self.service.find_duplicate_hashes()
|
||||
|
||||
# Format the response
|
||||
result = []
|
||||
cache = await self.service.scanner.get_cached_data()
|
||||
|
||||
for sha256, paths in duplicates.items():
|
||||
group = {
|
||||
"hash": sha256,
|
||||
"models": []
|
||||
}
|
||||
# Find matching models for each path
|
||||
for path in paths:
|
||||
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
||||
if model:
|
||||
group["models"].append(await self.service.format_response(model))
|
||||
|
||||
# Add the primary model too
|
||||
primary_path = self.service.get_path_by_hash(sha256)
|
||||
if primary_path and primary_path not in paths:
|
||||
primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None)
|
||||
if primary_model:
|
||||
group["models"].insert(0, await self.service.format_response(primary_model))
|
||||
|
||||
if len(group["models"]) > 1: # Only include if we found multiple models
|
||||
result.append(group)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"duplicates": result,
|
||||
"count": len(result)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding duplicate {self.model_type}s: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def find_filename_conflicts(self, request: web.Request) -> web.Response:
|
||||
"""Find models with conflicting filenames"""
|
||||
try:
|
||||
# Get duplicate filenames from service
|
||||
duplicates = self.service.find_duplicate_filenames()
|
||||
|
||||
# Format the response
|
||||
result = []
|
||||
cache = await self.service.scanner.get_cached_data()
|
||||
|
||||
for filename, paths in duplicates.items():
|
||||
group = {
|
||||
"filename": filename,
|
||||
"models": []
|
||||
}
|
||||
# Find matching models for each path
|
||||
for path in paths:
|
||||
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
||||
if model:
|
||||
group["models"].append(await self.service.format_response(model))
|
||||
|
||||
# Find the model from the main index too
|
||||
hash_val = self.service.scanner._hash_index.get_hash_by_filename(filename)
|
||||
if hash_val:
|
||||
main_path = self.service.get_path_by_hash(hash_val)
|
||||
if main_path and main_path not in paths:
|
||||
main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None)
|
||||
if main_model:
|
||||
group["models"].insert(0, await self.service.format_response(main_model))
|
||||
|
||||
if group["models"]:
|
||||
result.append(group)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"conflicts": result,
|
||||
"count": len(result)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding filename conflicts for {self.model_type}s: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
# Download management methods
|
||||
async def download_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model download request"""
|
||||
return await ModelRouteUtils.handle_download_model(request)
|
||||
|
||||
async def download_model_get(self, request: web.Request) -> web.Response:
|
||||
"""Handle model download request via GET method"""
|
||||
try:
|
||||
# Extract query parameters
|
||||
model_id = request.query.get('model_id')
|
||||
if not model_id:
|
||||
return web.Response(
|
||||
status=400,
|
||||
text="Missing required parameter: Please provide 'model_id'"
|
||||
)
|
||||
|
||||
# Get optional parameters
|
||||
model_version_id = request.query.get('model_version_id')
|
||||
download_id = request.query.get('download_id')
|
||||
use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true'
|
||||
|
||||
# Create a data dictionary that mimics what would be received from a POST request
|
||||
data = {
|
||||
'model_id': model_id
|
||||
}
|
||||
|
||||
# Add optional parameters only if they are provided
|
||||
if model_version_id:
|
||||
data['model_version_id'] = model_version_id
|
||||
|
||||
if download_id:
|
||||
data['download_id'] = download_id
|
||||
|
||||
data['use_default_paths'] = use_default_paths
|
||||
|
||||
# Create a mock request object with the data
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
future.set_result(data)
|
||||
|
||||
mock_request = type('MockRequest', (), {
|
||||
'json': lambda self=None: future
|
||||
})()
|
||||
|
||||
# Call the existing download handler
|
||||
return await ModelRouteUtils.handle_download_model(mock_request)
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(f"Error downloading model via GET: {error_message}", exc_info=True)
|
||||
return web.Response(status=500, text=error_message)
|
||||
|
||||
async def cancel_download_get(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET request for cancelling a download by download_id"""
|
||||
try:
|
||||
download_id = request.query.get('download_id')
|
||||
if not download_id:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Download ID is required'
|
||||
}, status=400)
|
||||
|
||||
# Create a mock request with match_info for compatibility
|
||||
mock_request = type('MockRequest', (), {
|
||||
'match_info': {'download_id': download_id}
|
||||
})()
|
||||
return await ModelRouteUtils.handle_cancel_download(mock_request)
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling download via GET: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_download_progress(self, request: web.Request) -> web.Response:
|
||||
"""Handle request for download progress by download_id"""
|
||||
try:
|
||||
# Get download_id from URL path
|
||||
download_id = request.match_info.get('download_id')
|
||||
if not download_id:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Download ID is required'
|
||||
}, status=400)
|
||||
|
||||
progress_data = ws_manager.get_download_progress(download_id)
|
||||
|
||||
if progress_data is None:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Download ID not found'
|
||||
}, status=404)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'progress': progress_data.get('progress', 0)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting download progress: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Fetch CivitAI metadata for all models in the background"""
|
||||
try:
|
||||
cache = await self.service.scanner.get_cached_data()
|
||||
total = len(cache.raw_data)
|
||||
processed = 0
|
||||
success = 0
|
||||
needs_resort = False
|
||||
|
||||
# Prepare models to process
|
||||
to_process = [
|
||||
model for model in cache.raw_data
|
||||
if model.get('sha256') and (not model.get('civitai') or 'id' not in model.get('civitai')) and model.get('from_civitai', True)
|
||||
]
|
||||
total_to_process = len(to_process)
|
||||
|
||||
# Send initial progress
|
||||
await ws_manager.broadcast({
|
||||
'status': 'started',
|
||||
'total': total_to_process,
|
||||
'processed': 0,
|
||||
'success': 0
|
||||
})
|
||||
|
||||
# Process each model
|
||||
for model in to_process:
|
||||
try:
|
||||
original_name = model.get('model_name')
|
||||
if await ModelRouteUtils.fetch_and_update_model(
|
||||
sha256=model['sha256'],
|
||||
file_path=model['file_path'],
|
||||
model_data=model,
|
||||
update_cache_func=self.service.scanner.update_single_model_cache
|
||||
):
|
||||
success += 1
|
||||
if original_name != model.get('model_name'):
|
||||
needs_resort = True
|
||||
|
||||
processed += 1
|
||||
|
||||
# Send progress update
|
||||
await ws_manager.broadcast({
|
||||
'status': 'processing',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success,
|
||||
'current_name': model.get('model_name', 'Unknown')
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivitAI data for {model['file_path']}: {e}")
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort()
|
||||
|
||||
# Send completion message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'completed',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Successfully updated {success} of {processed} processed {self.model_type}s (total: {total})"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
# Send error message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'error',
|
||||
'error': str(e)
|
||||
})
|
||||
logger.error(f"Error in fetch_all_civitai for {self.model_type}s: {e}")
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_civitai_versions(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai model with local availability info"""
|
||||
# This will be implemented by subclasses as they need CivitAI client access
|
||||
return web.json_response({
|
||||
"error": "Not implemented in base class"
|
||||
}, status=501)
|
||||
105
py/routes/checkpoint_routes.py
Normal file
105
py/routes/checkpoint_routes.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import logging
|
||||
from aiohttp import web
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from ..services.checkpoint_service import CheckpointService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointRoutes(BaseModelRoutes):
|
||||
"""Checkpoint-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Checkpoint routes with Checkpoint service"""
|
||||
# Service will be initialized later via setup_routes
|
||||
self.service = None
|
||||
self.civitai_client = None
|
||||
self.template_name = "checkpoints.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.service = CheckpointService(checkpoint_scanner)
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Initialize parent with the service
|
||||
super().__init__(self.service)
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Setup Checkpoint routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Setup common routes with 'checkpoints' prefix (includes page route)
|
||||
super().setup_routes(app, 'checkpoints')
|
||||
|
||||
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||
"""Setup Checkpoint-specific routes"""
|
||||
# Checkpoint-specific CivitAI integration
|
||||
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint)
|
||||
|
||||
# Checkpoint info by name
|
||||
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info)
|
||||
|
||||
async def get_checkpoint_info(self, request: web.Request) -> web.Response:
|
||||
"""Get detailed information for a specific checkpoint by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
checkpoint_info = await self.service.get_model_info_by_name(name)
|
||||
|
||||
if checkpoint_info:
|
||||
return web.json_response(checkpoint_info)
|
||||
else:
|
||||
return web.json_response({"error": "Checkpoint not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_civitai_versions_checkpoint(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai checkpoint model with local availability info"""
|
||||
try:
|
||||
model_id = request.match_info['model_id']
|
||||
response = await self.civitai_client.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - should be Checkpoint
|
||||
if model_type.lower() != 'checkpoint':
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||
|
||||
# If no primary file found, try to find any model file
|
||||
if not model_file:
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model'), None)
|
||||
|
||||
if model_file:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.service.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.service.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
else:
|
||||
# No model file found in this version
|
||||
version['existsLocally'] = False
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching checkpoint model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
@@ -1,843 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from ..services.websocket_manager import ws_manager
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..utils.utils import fuzzy_match
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointsRoutes:
|
||||
"""API routes for checkpoint management"""
|
||||
|
||||
def __init__(self):
|
||||
self.scanner = None # Will be initialized in setup_routes
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
self.download_manager = None # Will be initialized in setup_routes
|
||||
self._download_lock = asyncio.Lock()
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
def setup_routes(self, app):
|
||||
"""Register routes with the aiohttp app"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
|
||||
app.router.add_get('/api/checkpoints', self.get_checkpoints)
|
||||
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
|
||||
app.router.add_get('/api/checkpoints/base-models', self.get_base_models)
|
||||
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
|
||||
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
|
||||
app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info)
|
||||
app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots)
|
||||
app.router.add_get('/api/checkpoints/civitai/versions/{model_id}', self.get_civitai_versions) # Add new route
|
||||
|
||||
# Add new routes for model management similar to LoRA routes
|
||||
app.router.add_post('/api/checkpoints/delete', self.delete_model)
|
||||
app.router.add_post('/api/checkpoints/exclude', self.exclude_model) # Add new exclude endpoint
|
||||
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
|
||||
app.router.add_post('/api/checkpoints/relink-civitai', self.relink_civitai) # Add new relink endpoint
|
||||
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
|
||||
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
|
||||
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
|
||||
app.router.add_post('/api/checkpoints/rename', self.rename_checkpoint) # Add new rename endpoint
|
||||
|
||||
# Add new WebSocket endpoint for checkpoint progress
|
||||
app.router.add_get('/ws/checkpoint-progress', ws_manager.handle_checkpoint_connection)
|
||||
|
||||
# Add new routes for finding duplicates and filename conflicts
|
||||
app.router.add_get('/api/checkpoints/find-duplicates', self.find_duplicate_checkpoints)
|
||||
app.router.add_get('/api/checkpoints/find-filename-conflicts', self.find_filename_conflicts)
|
||||
|
||||
# Add new endpoint for bulk deleting checkpoints
|
||||
app.router.add_post('/api/checkpoints/bulk-delete', self.bulk_delete_checkpoints)
|
||||
|
||||
# Add new endpoint for verifying duplicates
|
||||
app.router.add_post('/api/checkpoints/verify-duplicates', self.verify_duplicates)
|
||||
|
||||
async def get_checkpoints(self, request):
|
||||
"""Get paginated checkpoint data"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
page = int(request.query.get('page', '1'))
|
||||
page_size = min(int(request.query.get('page_size', '20')), 100)
|
||||
sort_by = request.query.get('sort_by', 'name')
|
||||
folder = request.query.get('folder', None)
|
||||
search = request.query.get('search', None)
|
||||
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
|
||||
base_models = request.query.getall('base_model', [])
|
||||
tags = request.query.getall('tag', [])
|
||||
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' # Add favorites_only parameter
|
||||
|
||||
# Process search options
|
||||
search_options = {
|
||||
'filename': request.query.get('search_filename', 'true').lower() == 'true',
|
||||
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
|
||||
'tags': request.query.get('search_tags', 'false').lower() == 'true',
|
||||
'recursive': request.query.get('recursive', 'false').lower() == 'true',
|
||||
}
|
||||
|
||||
# Process hash filters if provided
|
||||
hash_filters = {}
|
||||
if 'hash' in request.query:
|
||||
hash_filters['single_hash'] = request.query['hash']
|
||||
elif 'hashes' in request.query:
|
||||
try:
|
||||
hash_list = json.loads(request.query['hashes'])
|
||||
if isinstance(hash_list, list):
|
||||
hash_filters['multiple_hashes'] = hash_list
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Get data from scanner
|
||||
result = await self.get_paginated_data(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
sort_by=sort_by,
|
||||
folder=folder,
|
||||
search=search,
|
||||
fuzzy_search=fuzzy_search,
|
||||
base_models=base_models,
|
||||
tags=tags,
|
||||
search_options=search_options,
|
||||
hash_filters=hash_filters,
|
||||
favorites_only=favorites_only # Pass favorites_only parameter
|
||||
)
|
||||
|
||||
# Format response items
|
||||
formatted_result = {
|
||||
'items': [self._format_checkpoint_response(cp) for cp in result['items']],
|
||||
'total': result['total'],
|
||||
'page': result['page'],
|
||||
'page_size': result['page_size'],
|
||||
'total_pages': result['total_pages']
|
||||
}
|
||||
|
||||
# Return as JSON
|
||||
return web.json_response(formatted_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_paginated_data(self, page, page_size, sort_by='name',
|
||||
folder=None, search=None, fuzzy_search=False,
|
||||
base_models=None, tags=None,
|
||||
search_options=None, hash_filters=None,
|
||||
favorites_only=False): # Add favorites_only parameter with default False
|
||||
"""Get paginated and filtered checkpoint data"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Get default search options if not provided
|
||||
if search_options is None:
|
||||
search_options = {
|
||||
'filename': True,
|
||||
'modelname': True,
|
||||
'tags': False,
|
||||
'recursive': False,
|
||||
}
|
||||
|
||||
# Get the base data set
|
||||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
||||
|
||||
# Apply hash filtering if provided (highest priority)
|
||||
if hash_filters:
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower() # Ensure lowercase for matching
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('sha256', '').lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
# Jump to pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
# Apply SFW filtering if enabled in settings
|
||||
if settings.get('show_only_sfw', False):
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||
]
|
||||
|
||||
# Apply favorites filtering if enabled
|
||||
if favorites_only:
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('favorite', False) is True
|
||||
]
|
||||
|
||||
# Apply folder filtering
|
||||
if folder is not None:
|
||||
if search_options.get('recursive', False):
|
||||
# Recursive folder filtering - include all subfolders
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp['folder'].startswith(folder)
|
||||
]
|
||||
else:
|
||||
# Exact folder filtering
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp['folder'] == folder
|
||||
]
|
||||
|
||||
# Apply base model filtering
|
||||
if base_models and len(base_models) > 0:
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('base_model') in base_models
|
||||
]
|
||||
|
||||
# Apply tag filtering
|
||||
if tags and len(tags) > 0:
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if any(tag in cp.get('tags', []) for tag in tags)
|
||||
]
|
||||
|
||||
# Apply search filtering
|
||||
if search:
|
||||
search_results = []
|
||||
|
||||
for cp in filtered_data:
|
||||
# Search by file name
|
||||
if search_options.get('filename', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(cp.get('file_name', ''), search):
|
||||
search_results.append(cp)
|
||||
continue
|
||||
elif search.lower() in cp.get('file_name', '').lower():
|
||||
search_results.append(cp)
|
||||
continue
|
||||
|
||||
# Search by model name
|
||||
if search_options.get('modelname', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(cp.get('model_name', ''), search):
|
||||
search_results.append(cp)
|
||||
continue
|
||||
elif search.lower() in cp.get('model_name', '').lower():
|
||||
search_results.append(cp)
|
||||
continue
|
||||
|
||||
# Search by tags
|
||||
if search_options.get('tags', False) and 'tags' in cp:
|
||||
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) for tag in cp['tags']):
|
||||
search_results.append(cp)
|
||||
continue
|
||||
|
||||
filtered_data = search_results
|
||||
|
||||
# Calculate pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def _format_checkpoint_response(self, checkpoint):
|
||||
"""Format checkpoint data for API response"""
|
||||
return {
|
||||
"model_name": checkpoint["model_name"],
|
||||
"file_name": checkpoint["file_name"],
|
||||
"preview_url": config.get_preview_static_url(checkpoint.get("preview_url", "")),
|
||||
"preview_nsfw_level": checkpoint.get("preview_nsfw_level", 0),
|
||||
"base_model": checkpoint.get("base_model", ""),
|
||||
"folder": checkpoint["folder"],
|
||||
"sha256": checkpoint.get("sha256", ""),
|
||||
"file_path": checkpoint["file_path"].replace(os.sep, "/"),
|
||||
"file_size": checkpoint.get("size", 0),
|
||||
"modified": checkpoint.get("modified", ""),
|
||||
"tags": checkpoint.get("tags", []),
|
||||
"modelDescription": checkpoint.get("modelDescription", ""),
|
||||
"from_civitai": checkpoint.get("from_civitai", True),
|
||||
"notes": checkpoint.get("notes", ""),
|
||||
"model_type": checkpoint.get("model_type", "checkpoint"),
|
||||
"favorite": checkpoint.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {}))
|
||||
}
|
||||
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Fetch CivitAI metadata for all checkpoints in the background"""
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data()
|
||||
total = len(cache.raw_data)
|
||||
processed = 0
|
||||
success = 0
|
||||
needs_resort = False
|
||||
|
||||
# Prepare checkpoints to process
|
||||
to_process = [
|
||||
cp for cp in cache.raw_data
|
||||
if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True)
|
||||
]
|
||||
total_to_process = len(to_process)
|
||||
|
||||
# Send initial progress
|
||||
await ws_manager.broadcast({
|
||||
'status': 'started',
|
||||
'total': total_to_process,
|
||||
'processed': 0,
|
||||
'success': 0
|
||||
})
|
||||
|
||||
# Process each checkpoint
|
||||
for cp in to_process:
|
||||
try:
|
||||
original_name = cp.get('model_name')
|
||||
if await ModelRouteUtils.fetch_and_update_model(
|
||||
sha256=cp['sha256'],
|
||||
file_path=cp['file_path'],
|
||||
model_data=cp,
|
||||
update_cache_func=self.scanner.update_single_model_cache
|
||||
):
|
||||
success += 1
|
||||
if original_name != cp.get('model_name'):
|
||||
needs_resort = True
|
||||
|
||||
processed += 1
|
||||
|
||||
# Send progress update
|
||||
await ws_manager.broadcast({
|
||||
'status': 'processing',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success,
|
||||
'current_name': cp.get('model_name', 'Unknown')
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}")
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort(name_only=True)
|
||||
|
||||
# Send completion message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'completed',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
# Send error message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'error',
|
||||
'error': str(e)
|
||||
})
|
||||
logger.error(f"Error in fetch_all_civitai for checkpoints: {e}")
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
"""Handle request for top tags sorted by frequency"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
# Validate limit
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20 # Default to a reasonable limit
|
||||
|
||||
# Get top tags
|
||||
top_tags = await self.scanner.get_top_tags(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'tags': top_tags
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Internal server error'
|
||||
}, status=500)
|
||||
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
"""Get base models used in loras"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
# Validate limit
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20 # Default to a reasonable limit
|
||||
|
||||
# Get base models
|
||||
base_models = await self.scanner.get_base_models(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'base_models': base_models
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving base models: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def scan_checkpoints(self, request):
|
||||
"""Force a rescan of checkpoint files"""
|
||||
try:
|
||||
# Get the full_rebuild parameter and convert to bool, default to False
|
||||
full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true'
|
||||
|
||||
await self.scanner.get_cached_data(force_refresh=True, rebuild_cache=full_rebuild)
|
||||
return web.json_response({"status": "success", "message": "Checkpoint scan completed"})
|
||||
except Exception as e:
|
||||
logger.error(f"Error in scan_checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_checkpoint_info(self, request):
|
||||
"""Get detailed information for a specific checkpoint by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
checkpoint_info = await self.scanner.get_model_info_by_name(name)
|
||||
|
||||
if checkpoint_info:
|
||||
return web.json_response(checkpoint_info)
|
||||
else:
|
||||
return web.json_response({"error": "Checkpoint not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /checkpoints request"""
|
||||
try:
|
||||
# Check if the CheckpointScanner is initializing
|
||||
# It's initializing if the cache object doesn't exist yet,
|
||||
# OR if the scanner explicitly says it's initializing (background task running).
|
||||
is_initializing = (
|
||||
self.scanner._cache is None or
|
||||
(hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing)
|
||||
)
|
||||
|
||||
if is_initializing:
|
||||
# If still initializing, return loading page
|
||||
template = self.template_env.get_template('checkpoints.html')
|
||||
rendered = template.render(
|
||||
folders=[], # 空文件夹列表
|
||||
is_initializing=True, # 新增标志
|
||||
settings=settings, # Pass settings to template
|
||||
request=request # Pass the request object to the template
|
||||
)
|
||||
|
||||
logger.info("Checkpoints page is initializing, returning loading page")
|
||||
else:
|
||||
# 正常流程 - 获取已经初始化好的缓存数据
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('checkpoints.html')
|
||||
rendered = template.render(
|
||||
folders=cache.folders,
|
||||
is_initializing=False,
|
||||
settings=settings, # Pass settings to template
|
||||
request=request # Pass the request object to the template
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading checkpoints cache data: {cache_error}")
|
||||
# 如果获取缓存失败,也显示初始化页面
|
||||
template = self.template_env.get_template('checkpoints.html')
|
||||
rendered = template.render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Checkpoints cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading checkpoints page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def delete_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle checkpoint model deletion request"""
|
||||
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
|
||||
|
||||
async def exclude_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle checkpoint model exclusion request"""
|
||||
return await ModelRouteUtils.handle_exclude_model(request, self.scanner)
|
||||
|
||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata fetch request for checkpoints"""
|
||||
response = await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
|
||||
|
||||
# If successful, format the metadata before returning
|
||||
if response.status == 200:
|
||||
data = json.loads(response.body.decode('utf-8'))
|
||||
if data.get("success") and data.get("metadata"):
|
||||
formatted_metadata = self._format_checkpoint_response(data["metadata"])
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"metadata": formatted_metadata
|
||||
})
|
||||
|
||||
# Otherwise, return the original response
|
||||
return response
|
||||
|
||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||
"""Handle preview image replacement for checkpoints"""
|
||||
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
|
||||
|
||||
async def download_checkpoint(self, request: web.Request) -> web.Response:
|
||||
"""Handle checkpoint download request"""
|
||||
async with self._download_lock:
|
||||
# Get the download manager from service registry if not already initialized
|
||||
if self.download_manager is None:
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
try:
|
||||
data = await request.json()
|
||||
|
||||
# Create progress callback that uses checkpoint-specific WebSocket
|
||||
async def progress_callback(progress):
|
||||
await ws_manager.broadcast_checkpoint_progress({
|
||||
'status': 'progress',
|
||||
'progress': progress
|
||||
})
|
||||
|
||||
# Check which identifier is provided
|
||||
download_url = data.get('download_url')
|
||||
model_hash = data.get('model_hash')
|
||||
model_version_id = data.get('model_version_id')
|
||||
|
||||
# Validate that at least one identifier is provided
|
||||
if not any([download_url, model_hash, model_version_id]):
|
||||
return web.Response(
|
||||
status=400,
|
||||
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
|
||||
)
|
||||
|
||||
result = await self.download_manager.download_from_civitai(
|
||||
download_url=download_url,
|
||||
model_hash=model_hash,
|
||||
model_version_id=model_version_id,
|
||||
save_dir=data.get('checkpoint_root'),
|
||||
relative_path=data.get('relative_path', ''),
|
||||
progress_callback=progress_callback,
|
||||
model_type="checkpoint"
|
||||
)
|
||||
|
||||
if not result.get('success', False):
|
||||
error_message = result.get('error', 'Unknown error')
|
||||
|
||||
# Return 401 for early access errors
|
||||
if 'early access' in error_message.lower():
|
||||
logger.warning(f"Early access download failed: {error_message}")
|
||||
return web.Response(
|
||||
status=401,
|
||||
text=f"Early Access Restriction: {error_message}"
|
||||
)
|
||||
|
||||
return web.Response(status=500, text=error_message)
|
||||
|
||||
return web.json_response(result)
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
|
||||
# Check if this might be an early access error
|
||||
if '401' in error_message:
|
||||
logger.warning(f"Early access error (401): {error_message}")
|
||||
return web.Response(
|
||||
status=401,
|
||||
text="Early Access Restriction: This model requires purchase. Please ensure you have purchased early access and are logged in to Civitai."
|
||||
)
|
||||
|
||||
logger.error(f"Error downloading checkpoint: {error_message}")
|
||||
return web.Response(status=500, text=error_message)
|
||||
|
||||
async def get_checkpoint_roots(self, request):
|
||||
"""Return the checkpoint root directories"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
roots = self.scanner.get_model_roots()
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||
"""Handle saving metadata updates for checkpoints"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
if not file_path:
|
||||
return web.Response(text='File path is required', status=400)
|
||||
|
||||
# Remove file path from data to avoid saving it
|
||||
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
|
||||
|
||||
# Get metadata file path
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
|
||||
# Load existing metadata
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
|
||||
# Update metadata
|
||||
metadata.update(metadata_updates)
|
||||
|
||||
# Save updated metadata
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
# Update cache
|
||||
await self.scanner.update_single_model_cache(file_path, file_path, metadata)
|
||||
|
||||
# If model_name was updated, resort the cache
|
||||
if 'model_name' in metadata_updates:
|
||||
cache = await self.scanner.get_cached_data()
|
||||
await cache.resort(name_only=True)
|
||||
|
||||
return web.json_response({'success': True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint metadata: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_civitai_versions(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai checkpoint model with local availability info"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# Get the civitai client from service registry
|
||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
model_id = request.match_info['model_id']
|
||||
response = await civitai_client.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - should be Checkpoint
|
||||
if (model_type.lower() != 'checkpoint'):
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||
|
||||
# If no primary file found, try to find any model file
|
||||
if not model_file:
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model'), None)
|
||||
|
||||
if model_file:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.scanner.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.scanner.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
else:
|
||||
# No model file found in this version
|
||||
version['existsLocally'] = False
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching checkpoint model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
async def find_duplicate_checkpoints(self, request: web.Request) -> web.Response:
|
||||
"""Find checkpoints with duplicate SHA256 hashes"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# Get duplicate hashes from hash index
|
||||
duplicates = self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
# Format the response
|
||||
result = []
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for sha256, paths in duplicates.items():
|
||||
group = {
|
||||
"hash": sha256,
|
||||
"models": []
|
||||
}
|
||||
# Find matching models for each path
|
||||
for path in paths:
|
||||
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
||||
if model:
|
||||
group["models"].append(self._format_checkpoint_response(model))
|
||||
|
||||
# Add the primary model too
|
||||
primary_path = self.scanner._hash_index.get_path(sha256)
|
||||
if primary_path and primary_path not in paths:
|
||||
primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None)
|
||||
if primary_model:
|
||||
group["models"].insert(0, self._format_checkpoint_response(primary_model))
|
||||
|
||||
if len(group["models"]) > 1: # Only include if we found multiple models
|
||||
result.append(group)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"duplicates": result,
|
||||
"count": len(result)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding duplicate checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def find_filename_conflicts(self, request: web.Request) -> web.Response:
|
||||
"""Find checkpoints with conflicting filenames"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# Get duplicate filenames from hash index
|
||||
duplicates = self.scanner._hash_index.get_duplicate_filenames()
|
||||
|
||||
# Format the response
|
||||
result = []
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for filename, paths in duplicates.items():
|
||||
group = {
|
||||
"filename": filename,
|
||||
"models": []
|
||||
}
|
||||
# Find matching models for each path
|
||||
for path in paths:
|
||||
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
||||
if model:
|
||||
group["models"].append(self._format_checkpoint_response(model))
|
||||
|
||||
# Find the model from the main index too
|
||||
hash_val = self.scanner._hash_index.get_hash_by_filename(filename)
|
||||
if hash_val:
|
||||
main_path = self.scanner._hash_index.get_path(hash_val)
|
||||
if main_path and main_path not in paths:
|
||||
main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None)
|
||||
if main_model:
|
||||
group["models"].insert(0, self._format_checkpoint_response(main_model))
|
||||
|
||||
if group["models"]:
|
||||
result.append(group)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"conflicts": result,
|
||||
"count": len(result)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding filename conflicts: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def bulk_delete_checkpoints(self, request: web.Request) -> web.Response:
|
||||
"""Handle bulk deletion of checkpoint models"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
return await ModelRouteUtils.handle_bulk_delete_models(request, self.scanner)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in bulk delete checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def relink_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata re-linking request by model version ID for checkpoints"""
|
||||
return await ModelRouteUtils.handle_relink_civitai(request, self.scanner)
|
||||
|
||||
async def verify_duplicates(self, request: web.Request) -> web.Response:
|
||||
"""Handle verification of duplicate checkpoint hashes"""
|
||||
return await ModelRouteUtils.handle_verify_duplicates(request, self.scanner)
|
||||
|
||||
async def rename_checkpoint(self, request: web.Request) -> web.Response:
|
||||
"""Handle renaming a checkpoint file and its associated files"""
|
||||
return await ModelRouteUtils.handle_rename_model(request, self.scanner)
|
||||
105
py/routes/embedding_routes.py
Normal file
105
py/routes/embedding_routes.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import logging
|
||||
from aiohttp import web
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from ..services.embedding_service import EmbeddingService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EmbeddingRoutes(BaseModelRoutes):
|
||||
"""Embedding-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Embedding routes with Embedding service"""
|
||||
# Service will be initialized later via setup_routes
|
||||
self.service = None
|
||||
self.civitai_client = None
|
||||
self.template_name = "embeddings.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
self.service = EmbeddingService(embedding_scanner)
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Initialize parent with the service
|
||||
super().__init__(self.service)
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Setup Embedding routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Setup common routes with 'embeddings' prefix (includes page route)
|
||||
super().setup_routes(app, 'embeddings')
|
||||
|
||||
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||
"""Setup Embedding-specific routes"""
|
||||
# Embedding-specific CivitAI integration
|
||||
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_embedding)
|
||||
|
||||
# Embedding info by name
|
||||
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_embedding_info)
|
||||
|
||||
async def get_embedding_info(self, request: web.Request) -> web.Response:
|
||||
"""Get detailed information for a specific embedding by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
embedding_info = await self.service.get_model_info_by_name(name)
|
||||
|
||||
if embedding_info:
|
||||
return web.json_response(embedding_info)
|
||||
else:
|
||||
return web.json_response({"error": "Embedding not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_embedding_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_civitai_versions_embedding(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai embedding model with local availability info"""
|
||||
try:
|
||||
model_id = request.match_info['model_id']
|
||||
response = await self.civitai_client.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - should be TextualInversion (Embedding)
|
||||
if model_type.lower() not in ['textualinversion', 'embedding']:
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected TextualInversion/Embedding, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||
|
||||
# If no primary file found, try to find any model file
|
||||
if not model_file:
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model'), None)
|
||||
|
||||
if model_file:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.service.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.service.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
else:
|
||||
# No model file found in this version
|
||||
version['existsLocally'] = False
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching embedding model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
from ..utils.example_images_download_manager import DownloadManager
|
||||
from ..utils.example_images_processor import ExampleImagesProcessor
|
||||
from ..utils.example_images_metadata import MetadataUpdater
|
||||
from ..utils.example_images_file_manager import ExampleImagesFileManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,188 +1,483 @@
|
||||
import os
|
||||
from aiohttp import web
|
||||
import jinja2
|
||||
from typing import Dict
|
||||
import asyncio
|
||||
import logging
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
|
||||
from aiohttp import web
|
||||
from typing import Dict
|
||||
from server import PromptServer # type: ignore
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from ..services.lora_service import LoraService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from ..utils.utils import get_lora_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
|
||||
|
||||
class LoraRoutes:
|
||||
"""Route handlers for LoRA management endpoints"""
|
||||
class LoraRoutes(BaseModelRoutes):
|
||||
"""LoRA-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
# Initialize service references as None, will be set during async init
|
||||
self.scanner = None
|
||||
self.recipe_scanner = None
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
|
||||
async def init_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
"""Initialize LoRA routes with LoRA service"""
|
||||
# Service will be initialized later via setup_routes
|
||||
self.service = None
|
||||
self.civitai_client = None
|
||||
self.template_name = "loras.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.service = LoraService(lora_scanner)
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Initialize parent with the service
|
||||
super().__init__(self.service)
|
||||
|
||||
def format_lora_data(self, lora: Dict) -> Dict:
|
||||
"""Format LoRA data for template rendering"""
|
||||
return {
|
||||
"model_name": lora["model_name"],
|
||||
"file_name": lora["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora["preview_url"]),
|
||||
"preview_nsfw_level": lora.get("preview_nsfw_level", 0),
|
||||
"base_model": lora["base_model"],
|
||||
"folder": lora["folder"],
|
||||
"sha256": lora["sha256"],
|
||||
"file_path": lora["file_path"].replace(os.sep, "/"),
|
||||
"size": lora["size"],
|
||||
"tags": lora["tags"],
|
||||
"modelDescription": lora["modelDescription"],
|
||||
"usage_tips": lora["usage_tips"],
|
||||
"notes": lora["notes"],
|
||||
"modified": lora["modified"],
|
||||
"from_civitai": lora.get("from_civitai", True),
|
||||
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
|
||||
}
|
||||
|
||||
def _filter_civitai_data(self, data: Dict) -> Dict:
|
||||
"""Filter relevant fields from CivitAI data"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
fields = [
|
||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||
"publishedAt", "trainedWords", "baseModel", "description",
|
||||
"model", "images"
|
||||
]
|
||||
return {k: data[k] for k in fields if k in data}
|
||||
|
||||
async def handle_loras_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras request"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Check if the LoraScanner is initializing
|
||||
# It's initializing if the cache object doesn't exist yet,
|
||||
# OR if the scanner explicitly says it's initializing (background task running).
|
||||
is_initializing = (
|
||||
self.scanner._cache is None or self.scanner.is_initializing()
|
||||
)
|
||||
|
||||
if is_initializing:
|
||||
# If still initializing, return loading page
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
|
||||
logger.info("Loras page is initializing, returning loading page")
|
||||
else:
|
||||
# Normal flow - get data from initialized cache
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=cache.folders,
|
||||
is_initializing=False,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading cache data: {cache_error}")
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling loras request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading loras page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def handle_recipes_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras/recipes request"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Skip initialization check and directly try to get cached data
|
||||
try:
|
||||
# Recipe scanner will initialize cache if needed
|
||||
await self.recipe_scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('recipes.html')
|
||||
rendered = template.render(
|
||||
recipes=[], # Frontend will load recipes via API
|
||||
is_initializing=False,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading recipe cache data: {cache_error}")
|
||||
# Still keep error handling - show initializing page on error
|
||||
template = self.template_env.get_template('recipes.html')
|
||||
rendered = template.render(
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Recipe cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling recipes request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading recipes page",
|
||||
status=500
|
||||
)
|
||||
|
||||
def _format_recipe_file_url(self, file_path: str) -> str:
|
||||
"""Format file path for recipe image as a URL - same as in recipe_routes"""
|
||||
try:
|
||||
# Return the file URL directly for the first lora root's preview
|
||||
recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/')
|
||||
if file_path.replace(os.sep, '/').startswith(recipes_dir):
|
||||
relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/')
|
||||
return f"/loras_static/root1/preview/{relative_path}"
|
||||
|
||||
# If not in recipes dir, try to create a valid URL from the file path
|
||||
file_name = os.path.basename(file_path)
|
||||
return f"/loras_static/root1/preview/recipes/{file_name}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting recipe file URL: {e}", exc_info=True)
|
||||
return '/loras_static/images/no-preview.png' # Return default image on error
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Register routes with the application"""
|
||||
# Add an app startup handler to initialize services
|
||||
app.on_startup.append(self._on_startup)
|
||||
"""Setup LoRA routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Register routes
|
||||
app.router.add_get('/loras', self.handle_loras_page)
|
||||
app.router.add_get('/loras/recipes', self.handle_recipes_page)
|
||||
# Setup common routes with 'loras' prefix (includes page route)
|
||||
super().setup_routes(app, 'loras')
|
||||
|
||||
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||
"""Setup LoRA-specific routes"""
|
||||
# LoRA-specific query routes
|
||||
app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts)
|
||||
app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes)
|
||||
app.router.add_get(f'/api/{prefix}/get-trigger-words', self.get_lora_trigger_words)
|
||||
app.router.add_get(f'/api/{prefix}/preview-url', self.get_lora_preview_url)
|
||||
app.router.add_get(f'/api/{prefix}/civitai-url', self.get_lora_civitai_url)
|
||||
app.router.add_get(f'/api/{prefix}/model-description', self.get_lora_model_description)
|
||||
|
||||
async def _on_startup(self, app):
|
||||
"""Initialize services when the app starts"""
|
||||
await self.init_services()
|
||||
# LoRA-specific management routes
|
||||
app.router.add_post(f'/api/{prefix}/move_model', self.move_model)
|
||||
app.router.add_post(f'/api/{prefix}/move_models_bulk', self.move_models_bulk)
|
||||
|
||||
# CivitAI integration with LoRA-specific validation
|
||||
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_lora)
|
||||
app.router.add_get(f'/api/{prefix}/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version)
|
||||
app.router.add_get(f'/api/{prefix}/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash)
|
||||
|
||||
# ComfyUI integration
|
||||
app.router.add_post(f'/api/{prefix}/get_trigger_words', self.get_trigger_words)
|
||||
|
||||
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||
"""Parse LoRA-specific parameters"""
|
||||
params = {}
|
||||
|
||||
# LoRA-specific parameters
|
||||
if 'first_letter' in request.query:
|
||||
params['first_letter'] = request.query.get('first_letter')
|
||||
|
||||
# Handle fuzzy search parameter name variation
|
||||
if request.query.get('fuzzy') == 'true':
|
||||
params['fuzzy_search'] = True
|
||||
|
||||
# Handle additional filter parameters for LoRAs
|
||||
if 'lora_hash' in request.query:
|
||||
if not params.get('hash_filters'):
|
||||
params['hash_filters'] = {}
|
||||
params['hash_filters']['single_hash'] = request.query['lora_hash'].lower()
|
||||
elif 'lora_hashes' in request.query:
|
||||
if not params.get('hash_filters'):
|
||||
params['hash_filters'] = {}
|
||||
params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')]
|
||||
|
||||
return params
|
||||
|
||||
# LoRA-specific route handlers
|
||||
async def get_letter_counts(self, request: web.Request) -> web.Response:
|
||||
"""Get count of LoRAs for each letter of the alphabet"""
|
||||
try:
|
||||
letter_counts = await self.service.get_letter_counts()
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'letter_counts': letter_counts
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting letter counts: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_notes(self, request: web.Request) -> web.Response:
|
||||
"""Get notes for a specific LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
notes = await self.service.get_lora_notes(lora_name)
|
||||
if notes is not None:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'notes': notes
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'LoRA not found in cache'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora notes: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for a specific LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
trigger_words = await self.service.get_lora_trigger_words(lora_name)
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'trigger_words': trigger_words
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora trigger words: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the static preview URL for a LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
preview_url = await self.service.get_lora_preview_url(lora_name)
|
||||
if preview_url:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'preview_url': preview_url
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No preview URL found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the Civitai URL for a LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
result = await self.service.get_lora_civitai_url(lora_name)
|
||||
if result['civitai_url']:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
**result
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No Civitai data found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
# CivitAI integration methods
|
||||
async def get_civitai_versions_lora(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai LoRA model with local availability info"""
|
||||
try:
|
||||
model_id = request.match_info['model_id']
|
||||
response = await self.civitai_client.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - should be LORA, LoCon, or DORA
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
if model_type.lower() not in VALID_LORA_TYPES:
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the model file (type="Model") in the files list
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model'), None)
|
||||
|
||||
if model_file:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.service.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.service.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
else:
|
||||
# No model file found in this version
|
||||
version['existsLocally'] = False
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching LoRA model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
|
||||
"""Get CivitAI model details by model version ID"""
|
||||
try:
|
||||
model_version_id = request.match_info.get('modelVersionId')
|
||||
|
||||
# Get model details from Civitai API
|
||||
model, error_msg = await self.civitai_client.get_model_version_info(model_version_id)
|
||||
|
||||
if not model:
|
||||
# Log warning for failed model retrieval
|
||||
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
|
||||
|
||||
# Determine status code based on error message
|
||||
status_code = 404 if error_msg and "not found" in error_msg.lower() else 500
|
||||
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": error_msg or "Failed to fetch model information"
|
||||
}, status=status_code)
|
||||
|
||||
return web.json_response(model)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model details: {e}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
|
||||
"""Get CivitAI model details by hash"""
|
||||
try:
|
||||
hash = request.match_info.get('hash')
|
||||
model = await self.civitai_client.get_model_by_hash(hash)
|
||||
return web.json_response(model)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model details by hash: {e}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
# Model management methods
|
||||
async def move_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model move request"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path') # full path of the model file
|
||||
target_path = data.get('target_path') # folder path to move the model to
|
||||
|
||||
if not file_path or not target_path:
|
||||
return web.Response(text='File path and target path are required', status=400)
|
||||
|
||||
# Check if source and destination are the same
|
||||
import os
|
||||
source_dir = os.path.dirname(file_path)
|
||||
if os.path.normpath(source_dir) == os.path.normpath(target_path):
|
||||
logger.info(f"Source and target directories are the same: {source_dir}")
|
||||
return web.json_response({'success': True, 'message': 'Source and target directories are the same'})
|
||||
|
||||
# Check if target file already exists
|
||||
file_name = os.path.basename(file_path)
|
||||
target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')
|
||||
|
||||
if os.path.exists(target_file_path):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f"Target file already exists: {target_file_path}"
|
||||
}, status=409) # 409 Conflict
|
||||
|
||||
# Call scanner to handle the move operation
|
||||
success = await self.service.scanner.move_model(file_path, target_path)
|
||||
|
||||
if success:
|
||||
return web.json_response({'success': True, 'new_file_path': target_file_path})
|
||||
else:
|
||||
return web.Response(text='Failed to move model', status=500)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving model: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def move_models_bulk(self, request: web.Request) -> web.Response:
|
||||
"""Handle bulk model move request"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_paths = data.get('file_paths', []) # list of full paths of the model files
|
||||
target_path = data.get('target_path') # folder path to move the models to
|
||||
|
||||
if not file_paths or not target_path:
|
||||
return web.Response(text='File paths and target path are required', status=400)
|
||||
|
||||
results = []
|
||||
import os
|
||||
for file_path in file_paths:
|
||||
# Check if source and destination are the same
|
||||
source_dir = os.path.dirname(file_path)
|
||||
if os.path.normpath(source_dir) == os.path.normpath(target_path):
|
||||
results.append({
|
||||
"path": file_path,
|
||||
"success": True,
|
||||
"message": "Source and target directories are the same"
|
||||
})
|
||||
continue
|
||||
|
||||
# Check if target file already exists
|
||||
file_name = os.path.basename(file_path)
|
||||
target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')
|
||||
|
||||
if os.path.exists(target_file_path):
|
||||
results.append({
|
||||
"path": file_path,
|
||||
"success": False,
|
||||
"message": f"Target file already exists: {target_file_path}"
|
||||
})
|
||||
continue
|
||||
|
||||
# Try to move the model
|
||||
success = await self.service.scanner.move_model(file_path, target_path)
|
||||
results.append({
|
||||
"path": file_path,
|
||||
"success": success,
|
||||
"message": "Success" if success else "Failed to move model"
|
||||
})
|
||||
|
||||
# Count successes and failures
|
||||
success_count = sum(1 for r in results if r["success"])
|
||||
failure_count = len(results) - success_count
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f'Moved {success_count} of {len(file_paths)} models',
|
||||
'results': results,
|
||||
'success_count': success_count,
|
||||
'failure_count': failure_count
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving models in bulk: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_lora_model_description(self, request: web.Request) -> web.Response:
|
||||
"""Get model description for a Lora model"""
|
||||
try:
|
||||
# Get parameters
|
||||
model_id = request.query.get('model_id')
|
||||
file_path = request.query.get('file_path')
|
||||
|
||||
if not model_id:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Model ID is required'
|
||||
}, status=400)
|
||||
|
||||
# Check if we already have the description stored in metadata
|
||||
description = None
|
||||
tags = []
|
||||
creator = {}
|
||||
if file_path:
|
||||
import os
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
description = metadata.get('modelDescription')
|
||||
tags = metadata.get('tags', [])
|
||||
creator = metadata.get('creator', {})
|
||||
|
||||
# If description is not in metadata, fetch from CivitAI
|
||||
if not description:
|
||||
logger.info(f"Fetching model metadata for model ID: {model_id}")
|
||||
model_metadata, _ = await self.civitai_client.get_model_metadata(model_id)
|
||||
|
||||
if model_metadata:
|
||||
description = model_metadata.get('description')
|
||||
tags = model_metadata.get('tags', [])
|
||||
creator = model_metadata.get('creator', {})
|
||||
|
||||
# Save the metadata to file if we have a file path and got metadata
|
||||
if file_path:
|
||||
try:
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
|
||||
metadata['modelDescription'] = description
|
||||
metadata['tags'] = tags
|
||||
# Ensure the civitai dict exists
|
||||
if 'civitai' not in metadata:
|
||||
metadata['civitai'] = {}
|
||||
# Store creator in the civitai nested structure
|
||||
metadata['civitai']['creator'] = creator
|
||||
|
||||
await MetadataManager.save_metadata(file_path, metadata, True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving model metadata: {e}")
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'description': description or "<p>No model description available.</p>",
|
||||
'tags': tags,
|
||||
'creator': creator
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model metadata: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for specified LoRA models"""
|
||||
try:
|
||||
json_data = await request.json()
|
||||
lora_names = json_data.get("lora_names", [])
|
||||
node_ids = json_data.get("node_ids", [])
|
||||
|
||||
all_trigger_words = []
|
||||
for lora_name in lora_names:
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Format the trigger words
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Send update to all connected trigger word toggle nodes
|
||||
for node_id in node_ids:
|
||||
PromptServer.instance.send_sync("trigger_word_update", {
|
||||
"id": node_id,
|
||||
"message": trigger_words_text
|
||||
})
|
||||
|
||||
return web.json_response({"success": True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting trigger words: {e}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -10,6 +11,7 @@ from ..utils.usage_stats import UsageStats
|
||||
from ..utils.lora_metadata import extract_trained_words
|
||||
from ..config import config
|
||||
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -90,6 +92,8 @@ class MiscRoutes:
|
||||
# Add new route for clearing cache
|
||||
app.router.add_post('/api/clear-cache', MiscRoutes.clear_cache)
|
||||
|
||||
app.router.add_get('/api/health-check', lambda request: web.json_response({'status': 'ok'}))
|
||||
|
||||
# Usage stats routes
|
||||
app.router.add_post('/api/update-usage-stats', MiscRoutes.update_usage_stats)
|
||||
app.router.add_get('/api/get-usage-stats', MiscRoutes.get_usage_stats)
|
||||
@@ -106,6 +110,9 @@ class MiscRoutes:
|
||||
# Node registry endpoints
|
||||
app.router.add_post('/api/register-nodes', MiscRoutes.register_nodes)
|
||||
app.router.add_get('/api/get-registry', MiscRoutes.get_registry)
|
||||
|
||||
# Add new route for checking if a model exists in the library
|
||||
app.router.add_get('/api/check-model-exists', MiscRoutes.check_model_exists)
|
||||
|
||||
@staticmethod
|
||||
async def clear_cache(request):
|
||||
@@ -173,6 +180,16 @@ class MiscRoutes:
|
||||
if old_path != value:
|
||||
logger.info(f"Example images path changed to {value} - server restart required")
|
||||
|
||||
# Special handling for base_model_path_mappings - parse JSON string
|
||||
if key == 'base_model_path_mappings' and value:
|
||||
try:
|
||||
value = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f"Invalid JSON format for base_model_path_mappings: {value}"
|
||||
})
|
||||
|
||||
# Save to settings
|
||||
settings.set(key, value)
|
||||
|
||||
@@ -580,3 +597,111 @@ class MiscRoutes:
|
||||
'error': 'Internal Error',
|
||||
'message': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def check_model_exists(request):
|
||||
"""
|
||||
Check if a model with specified modelId and optionally modelVersionId exists in the library
|
||||
|
||||
Expects query parameters:
|
||||
- modelId: int - Civitai model ID (required)
|
||||
- modelVersionId: int - Civitai model version ID (optional)
|
||||
|
||||
Returns:
|
||||
- If modelVersionId is provided: JSON with a boolean 'exists' field
|
||||
- If modelVersionId is not provided: JSON with a list of modelVersionIds that exist in the library
|
||||
"""
|
||||
try:
|
||||
# Get the modelId and modelVersionId from query parameters
|
||||
model_id_str = request.query.get('modelId')
|
||||
model_version_id_str = request.query.get('modelVersionId')
|
||||
|
||||
# Validate modelId parameter (required)
|
||||
if not model_id_str:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing required parameter: modelId'
|
||||
}, status=400)
|
||||
|
||||
try:
|
||||
# Convert modelId to integer
|
||||
model_id = int(model_id_str)
|
||||
except ValueError:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Parameter modelId must be an integer'
|
||||
}, status=400)
|
||||
|
||||
# Get all scanners
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
# If modelVersionId is provided, check for specific version
|
||||
if model_version_id_str:
|
||||
try:
|
||||
model_version_id = int(model_version_id_str)
|
||||
except ValueError:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Parameter modelVersionId must be an integer'
|
||||
}, status=400)
|
||||
|
||||
# Check lora scanner first
|
||||
exists = False
|
||||
model_type = None
|
||||
|
||||
if await lora_scanner.check_model_version_exists(model_id, model_version_id):
|
||||
exists = True
|
||||
model_type = 'lora'
|
||||
elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_id, model_version_id):
|
||||
exists = True
|
||||
model_type = 'checkpoint'
|
||||
elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_id, model_version_id):
|
||||
exists = True
|
||||
model_type = 'embedding'
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'exists': exists,
|
||||
'modelType': model_type if exists else None
|
||||
})
|
||||
|
||||
# If modelVersionId is not provided, return all version IDs for the model
|
||||
else:
|
||||
lora_versions = await lora_scanner.get_model_versions_by_id(model_id)
|
||||
checkpoint_versions = []
|
||||
embedding_versions = []
|
||||
|
||||
# 优先lora,其次checkpoint,最后embedding
|
||||
if not lora_versions:
|
||||
checkpoint_versions = await checkpoint_scanner.get_model_versions_by_id(model_id)
|
||||
if not lora_versions and not checkpoint_versions:
|
||||
embedding_versions = await embedding_scanner.get_model_versions_by_id(model_id)
|
||||
|
||||
model_type = None
|
||||
versions = []
|
||||
|
||||
if lora_versions:
|
||||
model_type = 'lora'
|
||||
versions = lora_versions
|
||||
elif checkpoint_versions:
|
||||
model_type = 'checkpoint'
|
||||
versions = checkpoint_versions
|
||||
elif embedding_versions:
|
||||
model_type = 'embedding'
|
||||
versions = embedding_versions
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'modelId': model_id,
|
||||
'modelType': model_type,
|
||||
'versions': versions
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check model existence: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
import time
|
||||
import base64
|
||||
import jinja2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
import io
|
||||
import logging
|
||||
from aiohttp import web
|
||||
@@ -16,6 +16,7 @@ from ..utils.exif_utils import ExifUtils
|
||||
from ..recipes import RecipeParserFactory
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
|
||||
from ..services.settings_manager import settings
|
||||
from ..config import config
|
||||
|
||||
# Check if running in standalone mode
|
||||
@@ -40,7 +41,10 @@ class RecipeRoutes:
|
||||
# Initialize service references as None, will be set during async init
|
||||
self.recipe_scanner = None
|
||||
self.civitai_client = None
|
||||
# Remove WorkflowParser instance
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
|
||||
# Pre-warm the cache
|
||||
self._init_cache_task = None
|
||||
@@ -54,6 +58,8 @@ class RecipeRoutes:
|
||||
def setup_routes(cls, app: web.Application):
|
||||
"""Register API routes"""
|
||||
routes = cls()
|
||||
app.router.add_get('/loras/recipes', routes.handle_recipes_page)
|
||||
|
||||
app.router.add_get('/api/recipes', routes.get_recipes)
|
||||
app.router.add_get('/api/recipe/{recipe_id}', routes.get_recipe_detail)
|
||||
app.router.add_post('/api/recipes/analyze-image', routes.analyze_recipe_image)
|
||||
@@ -115,6 +121,46 @@ class RecipeRoutes:
|
||||
await self.recipe_scanner.get_cached_data(force_refresh=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error pre-warming recipe cache: {e}", exc_info=True)
|
||||
|
||||
async def handle_recipes_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras/recipes request"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Skip initialization check and directly try to get cached data
|
||||
try:
|
||||
# Recipe scanner will initialize cache if needed
|
||||
await self.recipe_scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('recipes.html')
|
||||
rendered = template.render(
|
||||
recipes=[], # Frontend will load recipes via API
|
||||
is_initializing=False,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading recipe cache data: {cache_error}")
|
||||
# Still keep error handling - show initializing page on error
|
||||
template = self.template_env.get_template('recipes.html')
|
||||
rendered = template.render(
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Recipe cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling recipes request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading recipes page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def get_recipes(self, request: web.Request) -> web.Response:
|
||||
"""API endpoint for getting paginated recipes"""
|
||||
@@ -1018,6 +1064,8 @@ class RecipeRoutes:
|
||||
shape_info = tensor_image.shape
|
||||
logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}")
|
||||
|
||||
import torch
|
||||
|
||||
# Convert tensor to numpy array
|
||||
if isinstance(tensor_image, torch.Tensor):
|
||||
image_np = tensor_image.cpu().numpy()
|
||||
@@ -1100,7 +1148,7 @@ class RecipeRoutes:
|
||||
for lora_name, lora_strength in lora_matches:
|
||||
try:
|
||||
# Get lora info from scanner
|
||||
lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora_name)
|
||||
lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora_name)
|
||||
|
||||
# Create lora entry
|
||||
lora_entry = {
|
||||
@@ -1119,7 +1167,7 @@ class RecipeRoutes:
|
||||
# Get base model from lora scanner for the available loras
|
||||
base_model_counts = {}
|
||||
for lora in loras_data:
|
||||
lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora.get("file_name", ""))
|
||||
lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora.get("file_name", ""))
|
||||
if lora_info and "base_model" in lora_info:
|
||||
base_model = lora_info["base_model"]
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
@@ -1209,7 +1257,7 @@ class RecipeRoutes:
|
||||
if lora.get("isDeleted", False):
|
||||
continue
|
||||
|
||||
if not self.recipe_scanner._lora_scanner.has_lora_hash(lora.get("hash", "")):
|
||||
if not self.recipe_scanner._lora_scanner.has_hash(lora.get("hash", "")):
|
||||
continue
|
||||
|
||||
# Get the strength
|
||||
@@ -1317,7 +1365,7 @@ class RecipeRoutes:
|
||||
return web.json_response({"error": "Recipe not found"}, status=404)
|
||||
|
||||
# Find target LoRA by name
|
||||
target_lora = await lora_scanner.get_lora_info_by_name(target_name)
|
||||
target_lora = await lora_scanner.get_model_info_by_name(target_name)
|
||||
if not target_lora:
|
||||
return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404)
|
||||
|
||||
@@ -1429,9 +1477,9 @@ class RecipeRoutes:
|
||||
if 'loras' in recipe:
|
||||
for lora in recipe['loras']:
|
||||
if 'hash' in lora and lora['hash']:
|
||||
lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_lora_hash(lora['hash'].lower())
|
||||
lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_hash(lora['hash'].lower())
|
||||
lora['preview_url'] = self.recipe_scanner._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
|
||||
lora['localPath'] = self.recipe_scanner._lora_scanner.get_lora_path_by_hash(lora['hash'].lower())
|
||||
lora['localPath'] = self.recipe_scanner._lora_scanner.get_path_by_hash(lora['hash'].lower())
|
||||
|
||||
# Ensure file_url is set (needed by frontend)
|
||||
if 'file_path' in recipe:
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import os
|
||||
import subprocess
|
||||
import aiohttp
|
||||
import logging
|
||||
import toml
|
||||
import subprocess
|
||||
import git
|
||||
from datetime import datetime
|
||||
from aiohttp import web
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,6 +19,7 @@ class UpdateRoutes:
|
||||
"""Register update check routes"""
|
||||
app.router.add_get('/api/check-updates', UpdateRoutes.check_updates)
|
||||
app.router.add_get('/api/version-info', UpdateRoutes.get_version_info)
|
||||
app.router.add_post('/api/perform-update', UpdateRoutes.perform_update)
|
||||
|
||||
@staticmethod
|
||||
async def check_updates(request):
|
||||
@@ -25,6 +28,8 @@ class UpdateRoutes:
|
||||
Returns update status and version information
|
||||
"""
|
||||
try:
|
||||
nightly = request.query.get('nightly', 'false').lower() == 'true'
|
||||
|
||||
# Read local version from pyproject.toml
|
||||
local_version = UpdateRoutes._get_local_version()
|
||||
|
||||
@@ -32,13 +37,21 @@ class UpdateRoutes:
|
||||
git_info = UpdateRoutes._get_git_info()
|
||||
|
||||
# Fetch remote version from GitHub
|
||||
remote_version, changelog = await UpdateRoutes._get_remote_version()
|
||||
if nightly:
|
||||
remote_version, changelog = await UpdateRoutes._get_nightly_version()
|
||||
else:
|
||||
remote_version, changelog = await UpdateRoutes._get_remote_version()
|
||||
|
||||
# Compare versions
|
||||
update_available = UpdateRoutes._compare_versions(
|
||||
local_version.replace('v', ''),
|
||||
remote_version.replace('v', '')
|
||||
)
|
||||
if nightly:
|
||||
# For nightly, compare commit hashes
|
||||
update_available = UpdateRoutes._compare_nightly_versions(git_info, remote_version)
|
||||
else:
|
||||
# For stable, compare semantic versions
|
||||
update_available = UpdateRoutes._compare_versions(
|
||||
local_version.replace('v', ''),
|
||||
remote_version.replace('v', '')
|
||||
)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
@@ -46,7 +59,8 @@ class UpdateRoutes:
|
||||
'latest_version': remote_version,
|
||||
'update_available': update_available,
|
||||
'changelog': changelog,
|
||||
'git_info': git_info
|
||||
'git_info': git_info,
|
||||
'nightly': nightly
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
@@ -55,7 +69,7 @@ class UpdateRoutes:
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def get_version_info(request):
|
||||
"""
|
||||
@@ -84,6 +98,168 @@ class UpdateRoutes:
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def perform_update(request):
|
||||
"""
|
||||
Perform Git-based update to latest release tag or main branch
|
||||
"""
|
||||
try:
|
||||
# Parse request body
|
||||
body = await request.json() if request.has_body else {}
|
||||
nightly = body.get('nightly', False)
|
||||
|
||||
# Get current plugin directory
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
plugin_root = os.path.dirname(os.path.dirname(current_dir))
|
||||
|
||||
# Backup settings.json if it exists
|
||||
settings_path = os.path.join(plugin_root, 'settings.json')
|
||||
settings_backup = None
|
||||
if os.path.exists(settings_path):
|
||||
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||
settings_backup = f.read()
|
||||
logger.info("Backed up settings.json")
|
||||
|
||||
# Perform Git update
|
||||
success, new_version = await UpdateRoutes._perform_git_update(plugin_root, nightly)
|
||||
|
||||
# Restore settings.json if we backed it up
|
||||
if settings_backup and success:
|
||||
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||
f.write(settings_backup)
|
||||
logger.info("Restored settings.json")
|
||||
|
||||
if success:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f'Successfully updated to {new_version}',
|
||||
'new_version': new_version
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Failed to complete Git update'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to perform update: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def _get_nightly_version() -> tuple[str, List[str]]:
|
||||
"""
|
||||
Fetch latest commit from main branch
|
||||
"""
|
||||
repo_owner = "willmiao"
|
||||
repo_name = "ComfyUI-Lora-Manager"
|
||||
|
||||
# Use GitHub API to fetch the latest commit from main branch
|
||||
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response:
|
||||
if response.status != 200:
|
||||
logger.warning(f"Failed to fetch GitHub commit: {response.status}")
|
||||
return "main", []
|
||||
|
||||
data = await response.json()
|
||||
commit_sha = data.get('sha', '')[:7] # Short hash
|
||||
commit_message = data.get('commit', {}).get('message', '')
|
||||
|
||||
# Format as "main-{short_hash}"
|
||||
version = f"main-{commit_sha}"
|
||||
|
||||
# Use commit message as changelog
|
||||
changelog = [commit_message] if commit_message else []
|
||||
|
||||
return version, changelog
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching nightly version: {e}", exc_info=True)
|
||||
return "main", []
|
||||
|
||||
@staticmethod
|
||||
def _compare_nightly_versions(local_git_info: Dict[str, str], remote_version: str) -> bool:
|
||||
"""
|
||||
Compare local commit hash with remote main branch
|
||||
"""
|
||||
try:
|
||||
local_hash = local_git_info.get('short_hash', 'unknown')
|
||||
if local_hash == 'unknown':
|
||||
return True # Assume update available if we can't get local hash
|
||||
|
||||
# Extract remote hash from version string (format: "main-{hash}")
|
||||
if '-' in remote_version:
|
||||
remote_hash = remote_version.split('-')[-1]
|
||||
return local_hash != remote_hash
|
||||
|
||||
return True # Default to update available
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error comparing nightly versions: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _perform_git_update(plugin_root: str, nightly: bool = False) -> tuple[bool, str]:
|
||||
"""
|
||||
Perform Git-based update using GitPython
|
||||
|
||||
Args:
|
||||
plugin_root: Path to the plugin root directory
|
||||
nightly: Whether to update to main branch or latest release
|
||||
|
||||
Returns:
|
||||
tuple: (success, new_version)
|
||||
"""
|
||||
try:
|
||||
# Open the Git repository
|
||||
repo = git.Repo(plugin_root)
|
||||
|
||||
# Fetch latest changes
|
||||
origin = repo.remotes.origin
|
||||
origin.fetch()
|
||||
|
||||
if nightly:
|
||||
# Switch to main branch and pull latest
|
||||
main_branch = 'main'
|
||||
if main_branch not in [branch.name for branch in repo.branches]:
|
||||
# Create local main branch if it doesn't exist
|
||||
repo.create_head(main_branch, origin.refs.main)
|
||||
|
||||
repo.heads[main_branch].checkout()
|
||||
origin.pull(main_branch)
|
||||
|
||||
# Get new commit hash
|
||||
new_version = f"main-{repo.head.commit.hexsha[:7]}"
|
||||
|
||||
else:
|
||||
# Get latest release tag
|
||||
tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime, reverse=True)
|
||||
if not tags:
|
||||
logger.error("No tags found in repository")
|
||||
return False, ""
|
||||
|
||||
latest_tag = tags[0]
|
||||
|
||||
# Checkout to latest tag
|
||||
repo.git.checkout(latest_tag.name)
|
||||
|
||||
new_version = latest_tag.name
|
||||
|
||||
logger.info(f"Successfully updated to {new_version}")
|
||||
return True, new_version
|
||||
|
||||
except git.exc.GitError as e:
|
||||
logger.error(f"Git error during update: {e}")
|
||||
return False, ""
|
||||
except Exception as e:
|
||||
logger.error(f"Error during Git update: {e}")
|
||||
return False, ""
|
||||
|
||||
@staticmethod
|
||||
def _get_local_version() -> str:
|
||||
"""Get local plugin version from pyproject.toml"""
|
||||
|
||||
259
py/services/base_model_service.py
Normal file
259
py/services/base_model_service.py
Normal file
@@ -0,0 +1,259 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Type
|
||||
import logging
|
||||
|
||||
from ..utils.models import BaseModelMetadata
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from .settings_manager import settings
|
||||
from ..utils.utils import fuzzy_match
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseModelService(ABC):
|
||||
"""Base service class for all model types"""
|
||||
|
||||
def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]):
|
||||
"""Initialize the service
|
||||
|
||||
Args:
|
||||
model_type: Type of model (lora, checkpoint, etc.)
|
||||
scanner: Model scanner instance
|
||||
metadata_class: Metadata class for this model type
|
||||
"""
|
||||
self.model_type = model_type
|
||||
self.scanner = scanner
|
||||
self.metadata_class = metadata_class
|
||||
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
||||
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
||||
base_models: list = None, tags: list = None,
|
||||
search_options: dict = None, hash_filters: dict = None,
|
||||
favorites_only: bool = False, **kwargs) -> Dict:
|
||||
"""Get paginated and filtered model data
|
||||
|
||||
Args:
|
||||
page: Page number (1-based)
|
||||
page_size: Number of items per page
|
||||
sort_by: Sort criteria, e.g. 'name', 'name:asc', 'name:desc', 'date', 'date:asc', 'date:desc'
|
||||
folder: Folder filter
|
||||
search: Search term
|
||||
fuzzy_search: Whether to use fuzzy search
|
||||
base_models: List of base models to filter by
|
||||
tags: List of tags to filter by
|
||||
search_options: Search options dict
|
||||
hash_filters: Hash filtering options
|
||||
favorites_only: Filter for favorites only
|
||||
**kwargs: Additional model-specific filters
|
||||
|
||||
Returns:
|
||||
Dict containing paginated results
|
||||
"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Parse sort_by into sort_key and order
|
||||
if ':' in sort_by:
|
||||
sort_key, order = sort_by.split(':', 1)
|
||||
sort_key = sort_key.strip()
|
||||
order = order.strip().lower()
|
||||
if order not in ('asc', 'desc'):
|
||||
order = 'asc'
|
||||
else:
|
||||
sort_key = sort_by.strip()
|
||||
order = 'asc'
|
||||
|
||||
# Get default search options if not provided
|
||||
if search_options is None:
|
||||
search_options = {
|
||||
'filename': True,
|
||||
'modelname': True,
|
||||
'tags': False,
|
||||
'recursive': False,
|
||||
}
|
||||
|
||||
# Get the base data set using new sort logic
|
||||
filtered_data = await cache.get_sorted_data(sort_key, order)
|
||||
|
||||
# Apply hash filtering if provided (highest priority)
|
||||
if hash_filters:
|
||||
filtered_data = await self._apply_hash_filters(filtered_data, hash_filters)
|
||||
|
||||
# Jump to pagination for hash filters
|
||||
return self._paginate(filtered_data, page, page_size)
|
||||
|
||||
# Apply common filters
|
||||
filtered_data = await self._apply_common_filters(
|
||||
filtered_data, folder, base_models, tags, favorites_only, search_options
|
||||
)
|
||||
|
||||
# Apply search filtering
|
||||
if search:
|
||||
filtered_data = await self._apply_search_filters(
|
||||
filtered_data, search, fuzzy_search, search_options
|
||||
)
|
||||
|
||||
# Apply model-specific filters
|
||||
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
||||
|
||||
return self._paginate(filtered_data, page, page_size)
|
||||
|
||||
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
|
||||
"""Apply hash-based filtering"""
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower()
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('sha256', '').lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes)
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
return data
|
||||
|
||||
async def _apply_common_filters(self, data: List[Dict], folder: str = None,
|
||||
base_models: list = None, tags: list = None,
|
||||
favorites_only: bool = False, search_options: dict = None) -> List[Dict]:
|
||||
"""Apply common filters that work across all model types"""
|
||||
# Apply SFW filtering if enabled in settings
|
||||
if settings.get('show_only_sfw', False):
|
||||
data = [
|
||||
item for item in data
|
||||
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||
]
|
||||
|
||||
# Apply favorites filtering if enabled
|
||||
if favorites_only:
|
||||
data = [
|
||||
item for item in data
|
||||
if item.get('favorite', False) is True
|
||||
]
|
||||
|
||||
# Apply folder filtering
|
||||
if folder is not None:
|
||||
if search_options and search_options.get('recursive', False):
|
||||
# Recursive folder filtering - include all subfolders
|
||||
data = [
|
||||
item for item in data
|
||||
if item['folder'].startswith(folder)
|
||||
]
|
||||
else:
|
||||
# Exact folder filtering
|
||||
data = [
|
||||
item for item in data
|
||||
if item['folder'] == folder
|
||||
]
|
||||
|
||||
# Apply base model filtering
|
||||
if base_models and len(base_models) > 0:
|
||||
data = [
|
||||
item for item in data
|
||||
if item.get('base_model') in base_models
|
||||
]
|
||||
|
||||
# Apply tag filtering
|
||||
if tags and len(tags) > 0:
|
||||
data = [
|
||||
item for item in data
|
||||
if any(tag in item.get('tags', []) for tag in tags)
|
||||
]
|
||||
|
||||
return data
|
||||
|
||||
async def _apply_search_filters(self, data: List[Dict], search: str,
|
||||
fuzzy_search: bool, search_options: dict) -> List[Dict]:
|
||||
"""Apply search filtering"""
|
||||
search_results = []
|
||||
|
||||
for item in data:
|
||||
# Search by file name
|
||||
if search_options.get('filename', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(item.get('file_name', ''), search):
|
||||
search_results.append(item)
|
||||
continue
|
||||
elif search.lower() in item.get('file_name', '').lower():
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
# Search by model name
|
||||
if search_options.get('modelname', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(item.get('model_name', ''), search):
|
||||
search_results.append(item)
|
||||
continue
|
||||
elif search.lower() in item.get('model_name', '').lower():
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
# Search by tags
|
||||
if search_options.get('tags', False) and 'tags' in item:
|
||||
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower())
|
||||
for tag in item['tags']):
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
return search_results
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
||||
return data
|
||||
|
||||
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
|
||||
"""Apply pagination to filtered data"""
|
||||
total_items = len(data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
return {
|
||||
'items': data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def format_response(self, model_data: Dict) -> Dict:
|
||||
"""Format model data for API response - must be implemented by subclasses"""
|
||||
pass
|
||||
|
||||
# Common service methods that delegate to scanner
|
||||
async def get_top_tags(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get top tags sorted by frequency"""
|
||||
return await self.scanner.get_top_tags(limit)
|
||||
|
||||
async def get_base_models(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get base models sorted by frequency"""
|
||||
return await self.scanner.get_base_models(limit)
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
"""Check if a model with given hash exists"""
|
||||
return self.scanner.has_hash(sha256)
|
||||
|
||||
def get_path_by_hash(self, sha256: str) -> Optional[str]:
|
||||
"""Get file path for a model by its hash"""
|
||||
return self.scanner.get_path_by_hash(sha256)
|
||||
|
||||
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a model by its file path"""
|
||||
return self.scanner.get_hash_by_path(file_path)
|
||||
|
||||
async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False):
|
||||
"""Trigger model scanning"""
|
||||
return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache)
|
||||
|
||||
async def get_model_info_by_name(self, name: str):
|
||||
"""Get model information by name"""
|
||||
return await self.scanner.get_model_info_by_name(name)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get model root directories"""
|
||||
return self.scanner.get_model_roots()
|
||||
@@ -1,131 +1,26 @@
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional, Set
|
||||
import folder_paths # type: ignore
|
||||
from typing import List
|
||||
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex
|
||||
from .service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointScanner(ModelScanner):
|
||||
"""Service for scanning and managing checkpoint files"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, '_initialized'):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
|
||||
super().__init__(
|
||||
model_type="checkpoint",
|
||||
model_class=CheckpointMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
self._checkpoint_roots = self._init_checkpoint_roots()
|
||||
self._initialized = True
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
|
||||
super().__init__(
|
||||
model_type="checkpoint",
|
||||
model_class=CheckpointMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def _init_checkpoint_roots(self) -> List[str]:
|
||||
"""Initialize checkpoint roots from ComfyUI settings"""
|
||||
# Get both checkpoint and diffusion_models paths
|
||||
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
diffusion_paths = folder_paths.get_folder_paths("diffusion_models")
|
||||
|
||||
# Combine, normalize and deduplicate paths
|
||||
all_paths = set()
|
||||
for path in checkpoint_paths + diffusion_paths:
|
||||
if os.path.exists(path):
|
||||
norm_path = path.replace(os.sep, "/")
|
||||
all_paths.add(norm_path)
|
||||
|
||||
# Sort for consistent order
|
||||
sorted_paths = sorted(all_paths, key=lambda p: p.lower())
|
||||
|
||||
return sorted_paths
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get checkpoint root directories"""
|
||||
return self._checkpoint_roots
|
||||
|
||||
async def scan_all_models(self) -> List[Dict]:
|
||||
"""Scan all checkpoint directories and return metadata"""
|
||||
all_checkpoints = []
|
||||
|
||||
# Create scan tasks for each directory
|
||||
scan_tasks = []
|
||||
for root in self._checkpoint_roots:
|
||||
task = asyncio.create_task(self._scan_directory(root))
|
||||
scan_tasks.append(task)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for task in scan_tasks:
|
||||
try:
|
||||
checkpoints = await task
|
||||
all_checkpoints.extend(checkpoints)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning checkpoint directory: {e}")
|
||||
|
||||
return all_checkpoints
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Scan a directory for checkpoint files"""
|
||||
checkpoints = []
|
||||
original_root = root_path
|
||||
|
||||
async def scan_recursive(path: str, visited_paths: set):
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
logger.debug(f"Skipping already visited path: {path}")
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True):
|
||||
# Check if file has supported extension
|
||||
ext = os.path.splitext(entry.name)[1].lower()
|
||||
if ext in self.file_extensions:
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
await self._process_single_file(file_path, original_root, checkpoints)
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
# For directories, continue scanning with original path
|
||||
await scan_recursive(entry.path, visited_paths)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
await scan_recursive(root_path, set())
|
||||
return checkpoints
|
||||
|
||||
async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list):
|
||||
"""Process a single checkpoint file and add to results"""
|
||||
try:
|
||||
result = await self._process_model_file(file_path, root_path)
|
||||
if result:
|
||||
checkpoints.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
return config.base_models_roots
|
||||
51
py/services/checkpoint_service.py
Normal file
51
py/services/checkpoint_service.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointService(BaseModelService):
|
||||
"""Checkpoint-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize Checkpoint service
|
||||
|
||||
Args:
|
||||
scanner: Checkpoint scanner instance
|
||||
"""
|
||||
super().__init__("checkpoint", scanner, CheckpointMetadata)
|
||||
|
||||
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
||||
"""Format Checkpoint data for API response"""
|
||||
return {
|
||||
"model_name": checkpoint_data["model_name"],
|
||||
"file_name": checkpoint_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
|
||||
"base_model": checkpoint_data.get("base_model", ""),
|
||||
"folder": checkpoint_data["folder"],
|
||||
"sha256": checkpoint_data.get("sha256", ""),
|
||||
"file_path": checkpoint_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": checkpoint_data.get("size", 0),
|
||||
"modified": checkpoint_data.get("modified", ""),
|
||||
"tags": checkpoint_data.get("tags", []),
|
||||
"modelDescription": checkpoint_data.get("modelDescription", ""),
|
||||
"from_civitai": checkpoint_data.get("from_civitai", True),
|
||||
"notes": checkpoint_data.get("notes", ""),
|
||||
"model_type": checkpoint_data.get("model_type", "checkpoint"),
|
||||
"favorite": checkpoint_data.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {}))
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find Checkpoints with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find Checkpoints with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
@@ -1,13 +1,11 @@
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from email.parser import Parser
|
||||
from typing import Optional, Dict, Tuple, List
|
||||
from urllib.parse import unquote
|
||||
from ..utils.models import LoraMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -45,14 +43,14 @@ class CivitaiClient:
|
||||
# Optimize TCP connection parameters
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=True,
|
||||
limit=3, # Further reduced from 5 to 3
|
||||
ttl_dns_cache=0, # Disabled DNS caching completely
|
||||
limit=8, # Increase from 3 to 8 for better parallelism
|
||||
ttl_dns_cache=300, # Enable DNS caching with reasonable timeout
|
||||
force_close=False, # Keep connections for reuse
|
||||
enable_cleanup_closed=True
|
||||
)
|
||||
trust_env = True # Allow using system environment proxy settings
|
||||
# Configure timeout parameters
|
||||
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
|
||||
# Configure timeout parameters - increase read timeout for large files
|
||||
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=120)
|
||||
self._session = aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
trust_env=trust_env,
|
||||
@@ -165,7 +163,7 @@ class CivitaiClient:
|
||||
now = datetime.now()
|
||||
time_diff = (now - last_progress_report_time).total_seconds()
|
||||
|
||||
if progress_callback and total_size and time_diff >= 0.5:
|
||||
if progress_callback and total_size and time_diff >= 1.0:
|
||||
progress = (current_size / total_size) * 100
|
||||
await progress_callback(progress)
|
||||
last_progress_report_time = now
|
||||
@@ -225,7 +223,7 @@ class CivitaiClient:
|
||||
logger.error(f"Error fetching model versions: {e}")
|
||||
return None
|
||||
|
||||
async def get_model_version(self, model_id: str, version_id: str = "") -> Optional[Dict]:
|
||||
async def get_model_version(self, model_id: int, version_id: int = None) -> Optional[Dict]:
|
||||
"""Get specific model version with additional metadata
|
||||
|
||||
Args:
|
||||
@@ -237,6 +235,8 @@ class CivitaiClient:
|
||||
"""
|
||||
try:
|
||||
session = await self._ensure_fresh_session()
|
||||
|
||||
# Step 1: Get model data to find version_id if not provided and get additional metadata
|
||||
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
@@ -244,45 +244,28 @@ class CivitaiClient:
|
||||
data = await response.json()
|
||||
model_versions = data.get('modelVersions', [])
|
||||
|
||||
# Find matching version
|
||||
matched_version = None
|
||||
|
||||
if version_id:
|
||||
# If version_id provided, find exact match
|
||||
for version in model_versions:
|
||||
if str(version.get('id')) == str(version_id):
|
||||
matched_version = version
|
||||
break
|
||||
else:
|
||||
# If no version_id then use the first version
|
||||
matched_version = model_versions[0] if model_versions else None
|
||||
|
||||
# If no match found, return None
|
||||
if not matched_version:
|
||||
# Step 2: Determine the version_id to use
|
||||
target_version_id = version_id
|
||||
if target_version_id is None:
|
||||
target_version_id = model_versions[0].get('id')
|
||||
|
||||
# Step 3: Get detailed version info using the version_id
|
||||
headers = self._get_request_headers()
|
||||
async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
|
||||
# Build result with modified fields
|
||||
result = matched_version.copy() # Copy to avoid modifying original
|
||||
|
||||
# Replace index with modelId
|
||||
if 'index' in result:
|
||||
del result['index']
|
||||
result['modelId'] = model_id
|
||||
version = await response.json()
|
||||
|
||||
# Add model field with metadata from top level
|
||||
result['model'] = {
|
||||
"name": data.get("name"),
|
||||
"type": data.get("type"),
|
||||
"nsfw": data.get("nsfw", False),
|
||||
"poi": data.get("poi", False),
|
||||
"description": data.get("description"),
|
||||
"tags": data.get("tags", [])
|
||||
}
|
||||
# Step 4: Enrich version_info with model data
|
||||
# Add description and tags from model data
|
||||
version['model']['description'] = data.get("description")
|
||||
version['model']['tags'] = data.get("tags", [])
|
||||
|
||||
# Add creator field from top level
|
||||
result['creator'] = data.get("creator")
|
||||
# Add creator from model data
|
||||
version['creator'] = data.get("creator")
|
||||
|
||||
return result
|
||||
return version
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model version: {e}")
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
import uuid
|
||||
from typing import Dict
|
||||
from ..utils.models import LoraMetadata, CheckpointMetadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .service_registry import ServiceRegistry
|
||||
from .settings_manager import settings
|
||||
|
||||
# Download to temporary file first
|
||||
import tempfile
|
||||
@@ -33,6 +35,10 @@ class DownloadManager:
|
||||
self._initialized = True
|
||||
|
||||
self._civitai_client = None # Will be lazily initialized
|
||||
# Add download management
|
||||
self._active_downloads = OrderedDict() # download_id -> download_info
|
||||
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
||||
self._download_tasks = {} # download_id -> asyncio.Task
|
||||
|
||||
async def _get_civitai_client(self):
|
||||
"""Lazily initialize CivitaiClient from registry"""
|
||||
@@ -47,55 +53,210 @@ class DownloadManager:
|
||||
async def _get_checkpoint_scanner(self):
|
||||
"""Get the checkpoint scanner from registry"""
|
||||
return await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
async def download_from_civitai(self, download_url: str = None, model_hash: str = None,
|
||||
model_version_id: str = None, save_dir: str = None,
|
||||
relative_path: str = '', progress_callback=None,
|
||||
model_type: str = "lora") -> Dict:
|
||||
"""Download model from Civitai
|
||||
|
||||
async def download_from_civitai(self, model_id: int, model_version_id: int,
|
||||
save_dir: str = None, relative_path: str = '',
|
||||
progress_callback=None, use_default_paths: bool = False,
|
||||
download_id: str = None) -> Dict:
|
||||
"""Download model from Civitai with task tracking and concurrency control
|
||||
|
||||
Args:
|
||||
download_url: Direct download URL for the model
|
||||
model_hash: SHA256 hash of the model
|
||||
model_id: Civitai model ID
|
||||
model_version_id: Civitai model version ID
|
||||
save_dir: Directory to save the model to
|
||||
save_dir: Directory to save the model
|
||||
relative_path: Relative path within save_dir
|
||||
progress_callback: Callback function for progress updates
|
||||
model_type: Type of model ('lora' or 'checkpoint')
|
||||
use_default_paths: Flag to use default paths
|
||||
download_id: Unique identifier for this download task
|
||||
|
||||
Returns:
|
||||
Dict with download result
|
||||
"""
|
||||
# Use provided download_id or generate new one
|
||||
task_id = download_id or str(uuid.uuid4())
|
||||
|
||||
# Register download task in tracking dict
|
||||
self._active_downloads[task_id] = {
|
||||
'model_id': model_id,
|
||||
'model_version_id': model_version_id,
|
||||
'progress': 0,
|
||||
'status': 'queued'
|
||||
}
|
||||
|
||||
# Create tracking task
|
||||
download_task = asyncio.create_task(
|
||||
self._download_with_semaphore(
|
||||
task_id, model_id, model_version_id, save_dir,
|
||||
relative_path, progress_callback, use_default_paths
|
||||
)
|
||||
)
|
||||
|
||||
# Store task for tracking and cancellation
|
||||
self._download_tasks[task_id] = download_task
|
||||
|
||||
try:
|
||||
# Update save directory with relative path if provided
|
||||
if relative_path:
|
||||
save_dir = os.path.join(save_dir, relative_path)
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
# Wait for download to complete
|
||||
result = await download_task
|
||||
result['download_id'] = task_id # Include download_id in result
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
return {'success': False, 'error': 'Download was cancelled', 'download_id': task_id}
|
||||
finally:
|
||||
# Clean up task reference
|
||||
if task_id in self._download_tasks:
|
||||
del self._download_tasks[task_id]
|
||||
|
||||
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
|
||||
save_dir: str, relative_path: str,
|
||||
progress_callback=None, use_default_paths: bool = False):
|
||||
"""Execute download with semaphore to limit concurrency"""
|
||||
# Update status to waiting
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'waiting'
|
||||
|
||||
# Wrap progress callback to track progress in active_downloads
|
||||
original_callback = progress_callback
|
||||
async def tracking_callback(progress):
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['progress'] = progress
|
||||
if original_callback:
|
||||
await original_callback(progress)
|
||||
|
||||
# Acquire semaphore to limit concurrent downloads
|
||||
try:
|
||||
async with self._download_semaphore:
|
||||
# Update status to downloading
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'downloading'
|
||||
|
||||
# Use original download implementation
|
||||
try:
|
||||
# Check for cancellation before starting
|
||||
if asyncio.current_task().cancelled():
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
result = await self._execute_original_download(
|
||||
model_id, model_version_id, save_dir,
|
||||
relative_path, tracking_callback, use_default_paths,
|
||||
task_id
|
||||
)
|
||||
|
||||
# Update status based on result
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed'
|
||||
if not result['success']:
|
||||
self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error')
|
||||
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
# Handle cancellation
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'cancelled'
|
||||
logger.info(f"Download cancelled for task {task_id}")
|
||||
raise
|
||||
except Exception as e:
|
||||
# Handle other errors
|
||||
logger.error(f"Download error for task {task_id}: {str(e)}", exc_info=True)
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'failed'
|
||||
self._active_downloads[task_id]['error'] = str(e)
|
||||
return {'success': False, 'error': str(e)}
|
||||
finally:
|
||||
# Schedule cleanup of download record after delay
|
||||
asyncio.create_task(self._cleanup_download_record(task_id))
|
||||
|
||||
async def _cleanup_download_record(self, task_id: str):
|
||||
"""Keep completed downloads in history for a short time"""
|
||||
await asyncio.sleep(600) # Keep for 10 minutes
|
||||
if task_id in self._active_downloads:
|
||||
del self._active_downloads[task_id]
|
||||
|
||||
async def _execute_original_download(self, model_id, model_version_id, save_dir,
|
||||
relative_path, progress_callback, use_default_paths,
|
||||
download_id=None):
|
||||
"""Wrapper for original download_from_civitai implementation"""
|
||||
try:
|
||||
# Check if model version already exists in library
|
||||
if model_version_id is not None:
|
||||
# Check both scanners
|
||||
lora_scanner = await self._get_lora_scanner()
|
||||
checkpoint_scanner = await self._get_checkpoint_scanner()
|
||||
|
||||
# Check lora scanner first
|
||||
if await lora_scanner.check_model_version_exists(model_id, model_version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in lora library'}
|
||||
|
||||
# Check checkpoint scanner
|
||||
if await checkpoint_scanner.check_model_version_exists(model_id, model_version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
||||
|
||||
# Get civitai client
|
||||
civitai_client = await self._get_civitai_client()
|
||||
|
||||
# Get version info based on the provided identifier
|
||||
version_info = None
|
||||
error_msg = None
|
||||
|
||||
if model_hash:
|
||||
# Get model by hash
|
||||
version_info = await civitai_client.get_model_by_hash(model_hash)
|
||||
elif model_version_id:
|
||||
# Use model version ID directly
|
||||
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
|
||||
elif download_url:
|
||||
# Extract version ID from download URL
|
||||
version_id = download_url.split('/')[-1]
|
||||
version_info, error_msg = await civitai_client.get_model_version_info(version_id)
|
||||
|
||||
version_info = await civitai_client.get_model_version(model_id, model_version_id)
|
||||
|
||||
if not version_info:
|
||||
if error_msg and "model not found" in error_msg.lower():
|
||||
return {'success': False, 'error': f'Model not found on Civitai: {error_msg}'}
|
||||
return {'success': False, 'error': error_msg or 'Failed to fetch model metadata'}
|
||||
return {'success': False, 'error': 'Failed to fetch model metadata'}
|
||||
|
||||
model_type_from_info = version_info.get('model', {}).get('type', '').lower()
|
||||
if model_type_from_info == 'checkpoint':
|
||||
model_type = 'checkpoint'
|
||||
elif model_type_from_info in VALID_LORA_TYPES:
|
||||
model_type = 'lora'
|
||||
elif model_type_from_info == 'textualinversion':
|
||||
model_type = 'embedding'
|
||||
else:
|
||||
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
|
||||
|
||||
# Case 2: model_version_id was None, check after getting version_info
|
||||
if model_version_id is None:
|
||||
version_model_id = version_info.get('modelId')
|
||||
version_id = version_info.get('id')
|
||||
|
||||
if model_type == 'lora':
|
||||
# Check lora scanner
|
||||
lora_scanner = await self._get_lora_scanner()
|
||||
if await lora_scanner.check_model_version_exists(version_model_id, version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in lora library'}
|
||||
elif model_type == 'checkpoint':
|
||||
# Check checkpoint scanner
|
||||
checkpoint_scanner = await self._get_checkpoint_scanner()
|
||||
if await checkpoint_scanner.check_model_version_exists(version_model_id, version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
||||
elif model_type == 'embedding':
|
||||
# Embeddings are not checked in scanners, but we can still check if it exists
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
if await embedding_scanner.check_model_version_exists(version_model_id, version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in embedding library'}
|
||||
|
||||
# Handle use_default_paths
|
||||
if use_default_paths:
|
||||
# Set save_dir based on model type
|
||||
if model_type == 'checkpoint':
|
||||
default_path = settings.get('default_checkpoint_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
|
||||
save_dir = default_path
|
||||
elif model_type == 'lora':
|
||||
default_path = settings.get('default_lora_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default lora root path not set in settings'}
|
||||
save_dir = default_path
|
||||
elif model_type == 'embedding':
|
||||
default_path = settings.get('default_embedding_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default embedding root path not set in settings'}
|
||||
save_dir = default_path
|
||||
|
||||
# Calculate relative path using template
|
||||
relative_path = self._calculate_relative_path(version_info)
|
||||
|
||||
# Update save directory with relative path if provided
|
||||
if relative_path:
|
||||
save_dir = os.path.join(save_dir, relative_path)
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Check if this is an early access model
|
||||
if version_info.get('earlyAccessEndsAt'):
|
||||
@@ -133,21 +294,12 @@ class DownloadManager:
|
||||
if model_type == "checkpoint":
|
||||
metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
logger.info(f"Creating CheckpointMetadata for {file_name}")
|
||||
else:
|
||||
elif model_type == "lora":
|
||||
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
logger.info(f"Creating LoraMetadata for {file_name}")
|
||||
|
||||
# 5.1 Get and update model tags, description and creator info
|
||||
model_id = version_info.get('modelId')
|
||||
if model_id:
|
||||
model_metadata, _ = await civitai_client.get_model_metadata(str(model_id))
|
||||
if model_metadata:
|
||||
if model_metadata.get("tags"):
|
||||
metadata.tags = model_metadata.get("tags", [])
|
||||
if model_metadata.get("description"):
|
||||
metadata.modelDescription = model_metadata.get("description", "")
|
||||
if model_metadata.get("creator"):
|
||||
metadata.civitai["creator"] = model_metadata.get("creator")
|
||||
elif model_type == "embedding":
|
||||
metadata = EmbeddingMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
logger.info(f"Creating EmbeddingMetadata for {file_name}")
|
||||
|
||||
# 6. Start download process
|
||||
result = await self._execute_download(
|
||||
@@ -157,7 +309,8 @@ class DownloadManager:
|
||||
version_info=version_info,
|
||||
relative_path=relative_path,
|
||||
progress_callback=progress_callback,
|
||||
model_type=model_type
|
||||
model_type=model_type,
|
||||
download_id=download_id
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -170,15 +323,63 @@ class DownloadManager:
|
||||
return {'success': False, 'error': f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai."}
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
def _calculate_relative_path(self, version_info: Dict) -> str:
|
||||
"""Calculate relative path using template from settings
|
||||
|
||||
Args:
|
||||
version_info: Version info from Civitai API
|
||||
|
||||
Returns:
|
||||
Relative path string
|
||||
"""
|
||||
# Get path template from settings, default to '{base_model}/{first_tag}'
|
||||
path_template = settings.get('download_path_template', '{base_model}/{first_tag}')
|
||||
|
||||
# If template is empty, return empty path (flat structure)
|
||||
if not path_template:
|
||||
return ''
|
||||
|
||||
# Get base model name
|
||||
base_model = version_info.get('baseModel', '')
|
||||
|
||||
# Apply mapping if available
|
||||
base_model_mappings = settings.get('base_model_path_mappings', {})
|
||||
mapped_base_model = base_model_mappings.get(base_model, base_model)
|
||||
|
||||
# Get model tags
|
||||
model_tags = version_info.get('model', {}).get('tags', [])
|
||||
|
||||
# Find the first Civitai model tag that exists in model_tags
|
||||
first_tag = ''
|
||||
for civitai_tag in CIVITAI_MODEL_TAGS:
|
||||
if civitai_tag in model_tags:
|
||||
first_tag = civitai_tag
|
||||
break
|
||||
|
||||
# If no Civitai model tag found, fallback to first tag
|
||||
if not first_tag and model_tags:
|
||||
first_tag = model_tags[0]
|
||||
|
||||
# Format the template with available data
|
||||
formatted_path = path_template
|
||||
formatted_path = formatted_path.replace('{base_model}', mapped_base_model)
|
||||
formatted_path = formatted_path.replace('{first_tag}', first_tag)
|
||||
|
||||
return formatted_path
|
||||
|
||||
async def _execute_download(self, download_url: str, save_dir: str,
|
||||
metadata, version_info: Dict,
|
||||
relative_path: str, progress_callback=None,
|
||||
model_type: str = "lora") -> Dict:
|
||||
model_type: str = "lora", download_id: str = None) -> Dict:
|
||||
"""Execute the actual download process including preview images and model files"""
|
||||
try:
|
||||
civitai_client = await self._get_civitai_client()
|
||||
save_path = metadata.file_path
|
||||
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
||||
|
||||
# Store file path in active_downloads for potential cleanup
|
||||
if download_id and download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]['file_path'] = save_path
|
||||
|
||||
# Download preview image if available
|
||||
images = version_info.get('images', [])
|
||||
@@ -261,9 +462,12 @@ class DownloadManager:
|
||||
if model_type == "checkpoint":
|
||||
scanner = await self._get_checkpoint_scanner()
|
||||
logger.info(f"Updating checkpoint cache for {save_path}")
|
||||
else:
|
||||
elif model_type == "lora":
|
||||
scanner = await self._get_lora_scanner()
|
||||
logger.info(f"Updating lora cache for {save_path}")
|
||||
elif model_type == "embedding":
|
||||
scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
logger.info(f"Updating embedding cache for {save_path}")
|
||||
|
||||
# Convert metadata to dictionary
|
||||
metadata_dict = metadata.to_dict()
|
||||
@@ -297,4 +501,86 @@ class DownloadManager:
|
||||
if progress_callback:
|
||||
# Scale file progress to 3-100 range (after preview download)
|
||||
overall_progress = 3 + (file_progress * 0.97) # 97% of progress for file download
|
||||
await progress_callback(round(overall_progress))
|
||||
await progress_callback(round(overall_progress))
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict:
|
||||
"""Cancel an active download by download_id
|
||||
|
||||
Args:
|
||||
download_id: The unique identifier of the download task
|
||||
|
||||
Returns:
|
||||
Dict: Status of the cancellation operation
|
||||
"""
|
||||
if download_id not in self._download_tasks:
|
||||
return {'success': False, 'error': 'Download task not found'}
|
||||
|
||||
try:
|
||||
# Get the task and cancel it
|
||||
task = self._download_tasks[download_id]
|
||||
task.cancel()
|
||||
|
||||
# Update status in active downloads
|
||||
if download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]['status'] = 'cancelling'
|
||||
|
||||
# Wait briefly for the task to acknowledge cancellation
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(task), timeout=2.0)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
pass
|
||||
|
||||
# Clean up partial downloads
|
||||
download_info = self._active_downloads.get(download_id)
|
||||
if download_info and 'file_path' in download_info:
|
||||
# Delete the partial file
|
||||
file_path = download_info['file_path']
|
||||
if os.path.exists(file_path):
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
logger.debug(f"Deleted partial download: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting partial file: {e}")
|
||||
|
||||
# Delete metadata file if exists
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
if os.path.exists(metadata_path):
|
||||
try:
|
||||
os.unlink(metadata_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting metadata file: {e}")
|
||||
|
||||
# Delete preview file if exists (.webp or .mp4)
|
||||
for preview_ext in ['.webp', '.mp4']:
|
||||
preview_path = os.path.splitext(file_path)[0] + preview_ext
|
||||
if os.path.exists(preview_path):
|
||||
try:
|
||||
os.unlink(preview_path)
|
||||
logger.debug(f"Deleted preview file: {preview_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting preview file: {e}")
|
||||
|
||||
return {'success': True, 'message': 'Download cancelled successfully'}
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling download: {e}", exc_info=True)
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
async def get_active_downloads(self) -> Dict:
|
||||
"""Get information about all active downloads
|
||||
|
||||
Returns:
|
||||
Dict: List of active downloads and their status
|
||||
"""
|
||||
return {
|
||||
'downloads': [
|
||||
{
|
||||
'download_id': task_id,
|
||||
'model_id': info.get('model_id'),
|
||||
'model_version_id': info.get('model_version_id'),
|
||||
'progress': info.get('progress', 0),
|
||||
'status': info.get('status', 'unknown'),
|
||||
'error': info.get('error', None)
|
||||
}
|
||||
for task_id, info in self._active_downloads.items()
|
||||
]
|
||||
}
|
||||
26
py/services/embedding_scanner.py
Normal file
26
py/services/embedding_scanner.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from ..utils.models import EmbeddingMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EmbeddingScanner(ModelScanner):
|
||||
"""Service for scanning and managing embedding files"""
|
||||
|
||||
def __init__(self):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
||||
super().__init__(
|
||||
model_type="embedding",
|
||||
model_class=EmbeddingMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get embedding root directories"""
|
||||
return config.embeddings_roots
|
||||
51
py/services/embedding_service.py
Normal file
51
py/services/embedding_service.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import EmbeddingMetadata
|
||||
from ..config import config
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EmbeddingService(BaseModelService):
|
||||
"""Embedding-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize Embedding service
|
||||
|
||||
Args:
|
||||
scanner: Embedding scanner instance
|
||||
"""
|
||||
super().__init__("embedding", scanner, EmbeddingMetadata)
|
||||
|
||||
async def format_response(self, embedding_data: Dict) -> Dict:
|
||||
"""Format Embedding data for API response"""
|
||||
return {
|
||||
"model_name": embedding_data["model_name"],
|
||||
"file_name": embedding_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(embedding_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": embedding_data.get("preview_nsfw_level", 0),
|
||||
"base_model": embedding_data.get("base_model", ""),
|
||||
"folder": embedding_data["folder"],
|
||||
"sha256": embedding_data.get("sha256", ""),
|
||||
"file_path": embedding_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": embedding_data.get("size", 0),
|
||||
"modified": embedding_data.get("modified", ""),
|
||||
"tags": embedding_data.get("tags", []),
|
||||
"modelDescription": embedding_data.get("modelDescription", ""),
|
||||
"from_civitai": embedding_data.get("from_civitai", True),
|
||||
"notes": embedding_data.get("notes", ""),
|
||||
"model_type": embedding_data.get("model_type", "embedding"),
|
||||
"favorite": embedding_data.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(embedding_data.get("civitai", {}))
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find Embeddings with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find Embeddings with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
@@ -1,65 +0,0 @@
|
||||
import asyncio
|
||||
from typing import List, Dict
|
||||
from dataclasses import dataclass
|
||||
from operator import itemgetter
|
||||
from natsort import natsorted
|
||||
|
||||
@dataclass
|
||||
class LoraCache:
|
||||
"""Cache structure for LoRA data"""
|
||||
raw_data: List[Dict]
|
||||
sorted_by_name: List[Dict]
|
||||
sorted_by_date: List[Dict]
|
||||
folders: List[str]
|
||||
|
||||
def __post_init__(self):
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def resort(self, name_only: bool = False):
|
||||
"""Resort all cached data views"""
|
||||
async with self._lock:
|
||||
self.sorted_by_name = natsorted(
|
||||
self.raw_data,
|
||||
key=lambda x: x['model_name'].lower() # Case-insensitive sort
|
||||
)
|
||||
if not name_only:
|
||||
self.sorted_by_date = sorted(
|
||||
self.raw_data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=True
|
||||
)
|
||||
# Update folder list
|
||||
all_folders = set(l['folder'] for l in self.raw_data)
|
||||
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
|
||||
"""Update preview_url for a specific lora in all cached data
|
||||
|
||||
Args:
|
||||
file_path: The file path of the lora to update
|
||||
preview_url: The new preview URL
|
||||
|
||||
Returns:
|
||||
bool: True if the update was successful, False if the lora wasn't found
|
||||
"""
|
||||
async with self._lock:
|
||||
# Update in raw_data
|
||||
for item in self.raw_data:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
else:
|
||||
return False # Lora not found
|
||||
|
||||
# Update in sorted lists (references to the same dict objects)
|
||||
for item in self.sorted_by_name:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
|
||||
for item in self.sorted_by_date:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
|
||||
return True
|
||||
@@ -1,20 +1,10 @@
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
import shutil
|
||||
import time
|
||||
import re
|
||||
from typing import List, Dict, Optional, Set
|
||||
from typing import List
|
||||
|
||||
from ..utils.models import LoraMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
|
||||
from .settings_manager import settings
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..utils.utils import fuzzy_match
|
||||
from .service_registry import ServiceRegistry
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,404 +12,21 @@ logger = logging.getLogger(__name__)
|
||||
class LoraScanner(ModelScanner):
|
||||
"""Service for scanning and managing LoRA files"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
# Ensure initialization happens only once
|
||||
if not hasattr(self, '_initialized'):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors'}
|
||||
|
||||
# Initialize parent class with ModelHashIndex
|
||||
super().__init__(
|
||||
model_type="lora",
|
||||
model_class=LoraMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
||||
)
|
||||
self._initialized = True
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors'}
|
||||
|
||||
# Initialize parent class with ModelHashIndex
|
||||
super().__init__(
|
||||
model_type="lora",
|
||||
model_class=LoraMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
||||
)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get lora root directories"""
|
||||
return config.loras_roots
|
||||
|
||||
async def scan_all_models(self) -> List[Dict]:
|
||||
"""Scan all LoRA directories and return metadata"""
|
||||
all_loras = []
|
||||
|
||||
# Create scan tasks for each directory
|
||||
scan_tasks = []
|
||||
for lora_root in self.get_model_roots():
|
||||
task = asyncio.create_task(self._scan_directory(lora_root))
|
||||
scan_tasks.append(task)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for task in scan_tasks:
|
||||
try:
|
||||
loras = await task
|
||||
all_loras.extend(loras)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning directory: {e}")
|
||||
|
||||
return all_loras
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Scan a single directory for LoRA files"""
|
||||
loras = []
|
||||
original_root = root_path # Save original root path
|
||||
|
||||
async def scan_recursive(path: str, visited_paths: set):
|
||||
"""Recursively scan directory, avoiding circular symlinks"""
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
logger.debug(f"Skipping already visited path: {path}")
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
|
||||
# Use original path instead of real path
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
await self._process_single_file(file_path, original_root, loras)
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
# For directories, continue scanning with original path
|
||||
await scan_recursive(entry.path, visited_paths)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
await scan_recursive(root_path, set())
|
||||
return loras
|
||||
|
||||
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
|
||||
"""Process a single file and add to results list"""
|
||||
try:
|
||||
result = await self._process_model_file(file_path, root_path)
|
||||
if result:
|
||||
loras.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
||||
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
||||
base_models: list = None, tags: list = None,
|
||||
search_options: dict = None, hash_filters: dict = None,
|
||||
favorites_only: bool = False, first_letter: str = None) -> Dict:
|
||||
"""Get paginated and filtered lora data
|
||||
|
||||
Args:
|
||||
page: Current page number (1-based)
|
||||
page_size: Number of items per page
|
||||
sort_by: Sort method ('name' or 'date')
|
||||
folder: Filter by folder path
|
||||
search: Search term
|
||||
fuzzy_search: Use fuzzy matching for search
|
||||
base_models: List of base models to filter by
|
||||
tags: List of tags to filter by
|
||||
search_options: Dictionary with search options (filename, modelname, tags, recursive)
|
||||
hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes)
|
||||
favorites_only: Filter for favorite models only
|
||||
first_letter: Filter by first letter of model name
|
||||
"""
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# Get default search options if not provided
|
||||
if search_options is None:
|
||||
search_options = {
|
||||
'filename': True,
|
||||
'modelname': True,
|
||||
'tags': False,
|
||||
'recursive': False,
|
||||
}
|
||||
|
||||
# Get the base data set
|
||||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
||||
|
||||
# Apply hash filtering if provided (highest priority)
|
||||
if hash_filters:
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower() # Ensure lowercase for matching
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora.get('sha256', '').lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
|
||||
# Jump to pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
# Apply SFW filtering if enabled
|
||||
if settings.get('show_only_sfw', False):
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||
]
|
||||
|
||||
# Apply favorites filtering if enabled
|
||||
if favorites_only:
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora.get('favorite', False) is True
|
||||
]
|
||||
|
||||
# Apply first letter filtering
|
||||
if first_letter:
|
||||
filtered_data = self._filter_by_first_letter(filtered_data, first_letter)
|
||||
|
||||
# Apply folder filtering
|
||||
if folder is not None:
|
||||
if search_options.get('recursive', False):
|
||||
# Recursive folder filtering - include all subfolders
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora['folder'].startswith(folder)
|
||||
]
|
||||
else:
|
||||
# Exact folder filtering
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora['folder'] == folder
|
||||
]
|
||||
|
||||
# Apply base model filtering
|
||||
if base_models and len(base_models) > 0:
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora.get('base_model') in base_models
|
||||
]
|
||||
|
||||
# Apply tag filtering
|
||||
if tags and len(tags) > 0:
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if any(tag in lora.get('tags', []) for tag in tags)
|
||||
]
|
||||
|
||||
# Apply search filtering
|
||||
if search:
|
||||
search_results = []
|
||||
search_opts = search_options or {}
|
||||
|
||||
for lora in filtered_data:
|
||||
# Search by file name
|
||||
if search_opts.get('filename', True):
|
||||
if fuzzy_match(lora.get('file_name', ''), search):
|
||||
search_results.append(lora)
|
||||
continue
|
||||
|
||||
# Search by model name
|
||||
if search_opts.get('modelname', True):
|
||||
if fuzzy_match(lora.get('model_name', ''), search):
|
||||
search_results.append(lora)
|
||||
continue
|
||||
|
||||
# Search by tags
|
||||
if search_opts.get('tags', False) and 'tags' in lora:
|
||||
if any(fuzzy_match(tag, search) for tag in lora['tags']):
|
||||
search_results.append(lora)
|
||||
continue
|
||||
|
||||
filtered_data = search_results
|
||||
|
||||
# Calculate pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def _filter_by_first_letter(self, data, letter):
|
||||
"""Filter data by first letter of model name
|
||||
|
||||
Special handling:
|
||||
- '#': Numbers (0-9)
|
||||
- '@': Special characters (not alphanumeric)
|
||||
- '漢': CJK characters
|
||||
"""
|
||||
filtered_data = []
|
||||
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if letter == '#' and first_char.isdigit():
|
||||
filtered_data.append(lora)
|
||||
elif letter == '@' and not first_char.isalnum():
|
||||
# Special characters (not alphanumeric)
|
||||
filtered_data.append(lora)
|
||||
elif letter == '漢' and self._is_cjk_character(first_char):
|
||||
# CJK characters
|
||||
filtered_data.append(lora)
|
||||
elif letter.upper() == first_char:
|
||||
# Regular alphabet matching
|
||||
filtered_data.append(lora)
|
||||
|
||||
return filtered_data
|
||||
|
||||
def _is_cjk_character(self, char):
|
||||
"""Check if character is a CJK character"""
|
||||
# Define Unicode ranges for CJK characters
|
||||
cjk_ranges = [
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
||||
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
||||
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
||||
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
||||
(0x3300, 0x33FF), # CJK Compatibility
|
||||
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
||||
(0x3100, 0x312F), # Bopomofo
|
||||
(0x31A0, 0x31BF), # Bopomofo Extended
|
||||
(0x3040, 0x309F), # Hiragana
|
||||
(0x30A0, 0x30FF), # Katakana
|
||||
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
||||
(0xAC00, 0xD7AF), # Hangul Syllables
|
||||
(0x1100, 0x11FF), # Hangul Jamo
|
||||
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
||||
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
||||
]
|
||||
|
||||
code_point = ord(char)
|
||||
return any(start <= code_point <= end for start, end in cjk_ranges)
|
||||
|
||||
async def get_letter_counts(self):
|
||||
"""Get count of models for each letter of the alphabet"""
|
||||
cache = await self.get_cached_data()
|
||||
data = cache.sorted_by_name
|
||||
|
||||
# Define letter categories
|
||||
letters = {
|
||||
'#': 0, # Numbers
|
||||
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
|
||||
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
|
||||
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
|
||||
'Y': 0, 'Z': 0,
|
||||
'@': 0, # Special characters
|
||||
'漢': 0 # CJK characters
|
||||
}
|
||||
|
||||
# Count models for each letter
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if first_char.isdigit():
|
||||
letters['#'] += 1
|
||||
elif first_char in letters:
|
||||
letters[first_char] += 1
|
||||
elif self._is_cjk_character(first_char):
|
||||
letters['漢'] += 1
|
||||
elif not first_char.isalnum():
|
||||
letters['@'] += 1
|
||||
|
||||
return letters
|
||||
|
||||
# Lora-specific hash index functionality
|
||||
def has_lora_hash(self, sha256: str) -> bool:
|
||||
"""Check if a LoRA with given hash exists"""
|
||||
return self.has_hash(sha256)
|
||||
|
||||
def get_lora_path_by_hash(self, sha256: str) -> Optional[str]:
|
||||
"""Get file path for a LoRA by its hash"""
|
||||
return self.get_path_by_hash(sha256)
|
||||
|
||||
def get_lora_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a LoRA by its file path"""
|
||||
return self.get_hash_by_path(file_path)
|
||||
|
||||
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||
"""Get top tags sorted by count"""
|
||||
# Make sure cache is initialized
|
||||
await self.get_cached_data()
|
||||
|
||||
# Sort tags by count in descending order
|
||||
sorted_tags = sorted(
|
||||
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
|
||||
key=lambda x: x['count'],
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# Return limited number
|
||||
return sorted_tags[:limit]
|
||||
|
||||
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||
"""Get base models used in loras sorted by frequency"""
|
||||
# Make sure cache is initialized
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# Count base model occurrences
|
||||
base_model_counts = {}
|
||||
for lora in cache.raw_data:
|
||||
if 'base_model' in lora and lora['base_model']:
|
||||
base_model = lora['base_model']
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
|
||||
# Sort base models by count
|
||||
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
|
||||
sorted_models.sort(key=lambda x: x['count'], reverse=True)
|
||||
|
||||
# Return limited number
|
||||
return sorted_models[:limit]
|
||||
|
||||
async def diagnose_hash_index(self):
|
||||
"""Diagnostic method to verify hash index functionality"""
|
||||
@@ -456,19 +63,3 @@ class LoraScanner(ModelScanner):
|
||||
test_hash_result = self._hash_index.get_hash(test_path)
|
||||
print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)
|
||||
|
||||
async def get_lora_info_by_name(self, name):
|
||||
"""Get LoRA information by name"""
|
||||
try:
|
||||
# Get cached data
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# Find the LoRA by name
|
||||
for lora in cache.raw_data:
|
||||
if lora.get("file_name") == name:
|
||||
return lora
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting LoRA info by name: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
212
py/services/lora_service.py
Normal file
212
py/services/lora_service.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import LoraMetadata
|
||||
from ..config import config
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraService(BaseModelService):
|
||||
"""LoRA-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize LoRA service
|
||||
|
||||
Args:
|
||||
scanner: LoRA scanner instance
|
||||
"""
|
||||
super().__init__("lora", scanner, LoraMetadata)
|
||||
|
||||
async def format_response(self, lora_data: Dict) -> Dict:
|
||||
"""Format LoRA data for API response"""
|
||||
return {
|
||||
"model_name": lora_data["model_name"],
|
||||
"file_name": lora_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
|
||||
"base_model": lora_data.get("base_model", ""),
|
||||
"folder": lora_data["folder"],
|
||||
"sha256": lora_data.get("sha256", ""),
|
||||
"file_path": lora_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": lora_data.get("size", 0),
|
||||
"modified": lora_data.get("modified", ""),
|
||||
"tags": lora_data.get("tags", []),
|
||||
"modelDescription": lora_data.get("modelDescription", ""),
|
||||
"from_civitai": lora_data.get("from_civitai", True),
|
||||
"usage_tips": lora_data.get("usage_tips", ""),
|
||||
"notes": lora_data.get("notes", ""),
|
||||
"favorite": lora_data.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {}))
|
||||
}
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
"""Apply LoRA-specific filters"""
|
||||
# Handle first_letter filter for LoRAs
|
||||
first_letter = kwargs.get('first_letter')
|
||||
if first_letter:
|
||||
data = self._filter_by_first_letter(data, first_letter)
|
||||
|
||||
return data
|
||||
|
||||
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
|
||||
"""Filter data by first letter of model name
|
||||
|
||||
Special handling:
|
||||
- '#': Numbers (0-9)
|
||||
- '@': Special characters (not alphanumeric)
|
||||
- '漢': CJK characters
|
||||
"""
|
||||
filtered_data = []
|
||||
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if letter == '#' and first_char.isdigit():
|
||||
filtered_data.append(lora)
|
||||
elif letter == '@' and not first_char.isalnum():
|
||||
# Special characters (not alphanumeric)
|
||||
filtered_data.append(lora)
|
||||
elif letter == '漢' and self._is_cjk_character(first_char):
|
||||
# CJK characters
|
||||
filtered_data.append(lora)
|
||||
elif letter.upper() == first_char:
|
||||
# Regular alphabet matching
|
||||
filtered_data.append(lora)
|
||||
|
||||
return filtered_data
|
||||
|
||||
def _is_cjk_character(self, char: str) -> bool:
|
||||
"""Check if character is a CJK character"""
|
||||
# Define Unicode ranges for CJK characters
|
||||
cjk_ranges = [
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
||||
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
||||
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
||||
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
||||
(0x3300, 0x33FF), # CJK Compatibility
|
||||
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
||||
(0x3100, 0x312F), # Bopomofo
|
||||
(0x31A0, 0x31BF), # Bopomofo Extended
|
||||
(0x3040, 0x309F), # Hiragana
|
||||
(0x30A0, 0x30FF), # Katakana
|
||||
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
||||
(0xAC00, 0xD7AF), # Hangul Syllables
|
||||
(0x1100, 0x11FF), # Hangul Jamo
|
||||
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
||||
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
||||
]
|
||||
|
||||
code_point = ord(char)
|
||||
return any(start <= code_point <= end for start, end in cjk_ranges)
|
||||
|
||||
# LoRA-specific methods
|
||||
async def get_letter_counts(self) -> Dict[str, int]:
|
||||
"""Get count of LoRAs for each letter of the alphabet"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
data = cache.raw_data
|
||||
|
||||
# Define letter categories
|
||||
letters = {
|
||||
'#': 0, # Numbers
|
||||
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
|
||||
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
|
||||
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
|
||||
'Y': 0, 'Z': 0,
|
||||
'@': 0, # Special characters
|
||||
'漢': 0 # CJK characters
|
||||
}
|
||||
|
||||
# Count models for each letter
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if first_char.isdigit():
|
||||
letters['#'] += 1
|
||||
elif first_char in letters:
|
||||
letters[first_char] += 1
|
||||
elif self._is_cjk_character(first_char):
|
||||
letters['漢'] += 1
|
||||
elif not first_char.isalnum():
|
||||
letters['@'] += 1
|
||||
|
||||
return letters
|
||||
|
||||
async def get_lora_notes(self, lora_name: str) -> Optional[str]:
|
||||
"""Get notes for a specific LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
return lora.get('notes', '')
|
||||
|
||||
return None
|
||||
|
||||
async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
|
||||
"""Get trigger words for a specific LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
civitai_data = lora.get('civitai', {})
|
||||
return civitai_data.get('trainedWords', [])
|
||||
|
||||
return []
|
||||
|
||||
async def get_lora_preview_url(self, lora_name: str) -> Optional[str]:
|
||||
"""Get the static preview URL for a LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
preview_url = lora.get('preview_url')
|
||||
if preview_url:
|
||||
return config.get_preview_static_url(preview_url)
|
||||
|
||||
return None
|
||||
|
||||
async def get_lora_civitai_url(self, lora_name: str) -> Dict[str, Optional[str]]:
|
||||
"""Get the Civitai URL for a LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
civitai_data = lora.get('civitai', {})
|
||||
model_id = civitai_data.get('modelId')
|
||||
version_id = civitai_data.get('id')
|
||||
|
||||
if model_id:
|
||||
civitai_url = f"https://civitai.com/models/{model_id}"
|
||||
if version_id:
|
||||
civitai_url += f"?modelVersionId={version_id}"
|
||||
|
||||
return {
|
||||
'civitai_url': civitai_url,
|
||||
'model_id': str(model_id),
|
||||
'version_id': str(version_id) if version_id else None
|
||||
}
|
||||
|
||||
return {'civitai_url': None, 'model_id': None, 'version_id': None}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find LoRAs with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find LoRAs with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
@@ -1,37 +1,85 @@
|
||||
import asyncio
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Tuple
|
||||
from dataclasses import dataclass
|
||||
from operator import itemgetter
|
||||
from natsort import natsorted
|
||||
|
||||
# Supported sort modes: (sort_key, order)
|
||||
# order: 'asc' for ascending, 'desc' for descending
|
||||
SUPPORTED_SORT_MODES = [
|
||||
('name', 'asc'),
|
||||
('name', 'desc'),
|
||||
('date', 'asc'),
|
||||
('date', 'desc'),
|
||||
('size', 'asc'),
|
||||
('size', 'desc'),
|
||||
]
|
||||
|
||||
@dataclass
|
||||
class ModelCache:
|
||||
"""Cache structure for model data"""
|
||||
"""Cache structure for model data with extensible sorting"""
|
||||
raw_data: List[Dict]
|
||||
sorted_by_name: List[Dict]
|
||||
sorted_by_date: List[Dict]
|
||||
folders: List[str]
|
||||
|
||||
def __post_init__(self):
|
||||
self._lock = asyncio.Lock()
|
||||
# Cache for last sort: (sort_key, order) -> sorted list
|
||||
self._last_sort: Tuple[str, str] = (None, None)
|
||||
self._last_sorted_data: List[Dict] = []
|
||||
# Default sort on init
|
||||
asyncio.create_task(self.resort())
|
||||
|
||||
async def resort(self, name_only: bool = False):
|
||||
"""Resort all cached data views"""
|
||||
async def resort(self):
|
||||
"""Resort cached data according to last sort mode if set"""
|
||||
async with self._lock:
|
||||
self.sorted_by_name = natsorted(
|
||||
self.raw_data,
|
||||
key=lambda x: x['model_name'].lower() # Case-insensitive sort
|
||||
)
|
||||
if not name_only:
|
||||
self.sorted_by_date = sorted(
|
||||
self.raw_data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=True
|
||||
)
|
||||
# Update folder list
|
||||
if self._last_sort != (None, None):
|
||||
sort_key, order = self._last_sort
|
||||
sorted_data = self._sort_data(self.raw_data, sort_key, order)
|
||||
self._last_sorted_data = sorted_data
|
||||
# Update folder list
|
||||
# else: do nothing
|
||||
|
||||
all_folders = set(l['folder'] for l in self.raw_data)
|
||||
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
|
||||
"""Sort data by sort_key and order"""
|
||||
reverse = (order == 'desc')
|
||||
if sort_key == 'name':
|
||||
# Natural sort by model_name, case-insensitive
|
||||
return natsorted(
|
||||
data,
|
||||
key=lambda x: x['model_name'].lower(),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'date':
|
||||
# Sort by modified timestamp
|
||||
return sorted(
|
||||
data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'size':
|
||||
# Sort by file size
|
||||
return sorted(
|
||||
data,
|
||||
key=itemgetter('size'),
|
||||
reverse=reverse
|
||||
)
|
||||
else:
|
||||
# Fallback: no sort
|
||||
return list(data)
|
||||
|
||||
async def get_sorted_data(self, sort_key: str = 'name', order: str = 'asc') -> List[Dict]:
|
||||
"""Get sorted data by sort_key and order, using cache if possible"""
|
||||
async with self._lock:
|
||||
if (sort_key, order) == self._last_sort:
|
||||
return self._last_sorted_data
|
||||
sorted_data = self._sort_data(self.raw_data, sort_key, order)
|
||||
self._last_sort = (sort_key, order)
|
||||
self._last_sorted_data = sorted_data
|
||||
return sorted_data
|
||||
|
||||
async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool:
|
||||
"""Update preview_url for a specific model in all cached data
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import asyncio
|
||||
import time
|
||||
import shutil
|
||||
from typing import List, Dict, Optional, Type, Set
|
||||
import msgpack # Add MessagePack import for efficient serialization
|
||||
|
||||
from ..utils.models import BaseModelMetadata
|
||||
from ..config import config
|
||||
@@ -19,17 +18,33 @@ from .websocket_manager import ws_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define cache version to handle future format changes
|
||||
# Version history:
|
||||
# 1 - Initial version
|
||||
# 2 - Added duplicate_filenames and duplicate_hashes tracking
|
||||
# 3 - Added _excluded_models list to cache
|
||||
CACHE_VERSION = 3
|
||||
|
||||
class ModelScanner:
|
||||
"""Base service for scanning and managing model files"""
|
||||
|
||||
_lock = asyncio.Lock()
|
||||
_instances = {} # Dictionary to store instances by class
|
||||
_locks = {} # Dictionary to store locks by class
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""Implement singleton pattern for each subclass"""
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super().__new__(cls)
|
||||
return cls._instances[cls]
|
||||
|
||||
@classmethod
|
||||
def _get_lock(cls):
|
||||
"""Get or create a lock for this class"""
|
||||
if cls not in cls._locks:
|
||||
cls._locks[cls] = asyncio.Lock()
|
||||
return cls._locks[cls]
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
lock = cls._get_lock()
|
||||
async with lock:
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = cls()
|
||||
return cls._instances[cls]
|
||||
|
||||
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
|
||||
"""Initialize the scanner
|
||||
@@ -40,6 +55,10 @@ class ModelScanner:
|
||||
file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
|
||||
hash_index: Hash index instance (optional)
|
||||
"""
|
||||
# Ensure initialization happens only once per instance
|
||||
if hasattr(self, '_initialized'):
|
||||
return
|
||||
|
||||
self.model_type = model_type
|
||||
self.model_class = model_class
|
||||
self.file_extensions = file_extensions
|
||||
@@ -48,202 +67,15 @@ class ModelScanner:
|
||||
self._tags_count = {} # Dictionary to store tag counts
|
||||
self._is_initializing = False # Flag to track initialization state
|
||||
self._excluded_models = [] # List to track excluded models
|
||||
self._dirs_last_modified = {} # Track directory modification times
|
||||
self._use_cache_files = False # Flag to control cache file usage, default to disabled
|
||||
|
||||
# Clear cache files if disabled
|
||||
if not self._use_cache_files:
|
||||
self._clear_cache_files()
|
||||
self._initialized = True
|
||||
|
||||
# Register this service
|
||||
asyncio.create_task(self._register_service())
|
||||
|
||||
def _clear_cache_files(self):
|
||||
"""Clear existing cache files if they exist"""
|
||||
try:
|
||||
cache_path = self._get_cache_file_path()
|
||||
if cache_path and os.path.exists(cache_path):
|
||||
os.remove(cache_path)
|
||||
logger.info(f"Cleared {self.model_type} cache file: {cache_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing {self.model_type} cache file: {e}")
|
||||
|
||||
async def _register_service(self):
|
||||
"""Register this instance with the ServiceRegistry"""
|
||||
service_name = f"{self.model_type}_scanner"
|
||||
await ServiceRegistry.register_service(service_name, self)
|
||||
|
||||
def _get_cache_file_path(self) -> Optional[str]:
|
||||
"""Get the path to the cache file"""
|
||||
# Get the directory where this module is located
|
||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
|
||||
|
||||
# Create a cache directory within the project if it doesn't exist
|
||||
cache_dir = os.path.join(current_dir, "cache")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# Create filename based on model type
|
||||
cache_filename = f"lm_{self.model_type}_cache.msgpack"
|
||||
return os.path.join(cache_dir, cache_filename)
|
||||
|
||||
def _prepare_for_msgpack(self, data):
|
||||
"""Preprocess data to accommodate MessagePack serialization limitations
|
||||
|
||||
Converts integers exceeding safe range to strings
|
||||
|
||||
Args:
|
||||
data: Any type of data structure
|
||||
|
||||
Returns:
|
||||
Preprocessed data structure with large integers converted to strings
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
return {k: self._prepare_for_msgpack(v) for k, v in data.items()}
|
||||
elif isinstance(data, list):
|
||||
return [self._prepare_for_msgpack(item) for item in data]
|
||||
elif isinstance(data, int) and (data > 9007199254740991 or data < -9007199254740991):
|
||||
# Convert integers exceeding JavaScript's safe integer range (2^53-1) to strings
|
||||
return str(data)
|
||||
else:
|
||||
return data
|
||||
|
||||
async def _save_cache_to_disk(self) -> bool:
|
||||
"""Save cache data to disk using MessagePack"""
|
||||
if not self._use_cache_files:
|
||||
logger.debug(f"Cache files disabled for {self.model_type}, skipping save")
|
||||
return False
|
||||
|
||||
if self._cache is None or not self._cache.raw_data:
|
||||
logger.debug(f"No {self.model_type} cache data to save")
|
||||
return False
|
||||
|
||||
cache_path = self._get_cache_file_path()
|
||||
if not cache_path:
|
||||
logger.warning(f"Cannot determine {self.model_type} cache file location")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Create cache data structure
|
||||
cache_data = {
|
||||
"version": CACHE_VERSION,
|
||||
"timestamp": time.time(),
|
||||
"model_type": self.model_type,
|
||||
"raw_data": self._cache.raw_data,
|
||||
"hash_index": {
|
||||
"hash_to_path": self._hash_index._hash_to_path,
|
||||
"filename_to_hash": self._hash_index._filename_to_hash, # Fix: changed from path_to_hash to filename_to_hash
|
||||
"duplicate_hashes": self._hash_index._duplicate_hashes,
|
||||
"duplicate_filenames": self._hash_index._duplicate_filenames
|
||||
},
|
||||
"tags_count": self._tags_count,
|
||||
"dirs_last_modified": self._get_dirs_last_modified(),
|
||||
"excluded_models": self._excluded_models # Add excluded_models to cache data
|
||||
}
|
||||
|
||||
# Preprocess data to handle large integers
|
||||
processed_cache_data = self._prepare_for_msgpack(cache_data)
|
||||
|
||||
# Write to temporary file first (atomic operation)
|
||||
temp_path = f"{cache_path}.tmp"
|
||||
with open(temp_path, 'wb') as f:
|
||||
msgpack.pack(processed_cache_data, f)
|
||||
|
||||
# Replace the old file with the new one
|
||||
if os.path.exists(cache_path):
|
||||
os.replace(temp_path, cache_path)
|
||||
else:
|
||||
os.rename(temp_path, cache_path)
|
||||
|
||||
logger.info(f"Saved {self.model_type} cache with {len(self._cache.raw_data)} models to {cache_path}")
|
||||
logger.debug(f"Hash index stats - hash_to_path: {len(self._hash_index._hash_to_path)}, filename_to_hash: {len(self._hash_index._filename_to_hash)}, duplicate_hashes: {len(self._hash_index._duplicate_hashes)}, duplicate_filenames: {len(self._hash_index._duplicate_filenames)}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving {self.model_type} cache to disk: {e}")
|
||||
# Try to clean up temp file if it exists
|
||||
if 'temp_path' in locals() and os.path.exists(temp_path):
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _get_dirs_last_modified(self) -> Dict[str, float]:
|
||||
"""Get last modified time for all model directories"""
|
||||
dirs_info = {}
|
||||
for root in self.get_model_roots():
|
||||
if os.path.exists(root):
|
||||
dirs_info[root] = os.path.getmtime(root)
|
||||
# Also check immediate subdirectories for changes
|
||||
try:
|
||||
with os.scandir(root) as it:
|
||||
for entry in it:
|
||||
if entry.is_dir(follow_symlinks=True):
|
||||
dirs_info[entry.path] = entry.stat().st_mtime
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting directory info for {root}: {e}")
|
||||
return dirs_info
|
||||
|
||||
def _is_cache_valid(self, cache_data: Dict) -> bool:
|
||||
"""Validate if the loaded cache is still valid"""
|
||||
if not cache_data or cache_data.get("version") != CACHE_VERSION:
|
||||
logger.info(f"Cache invalid - version mismatch. Got: {cache_data.get('version')}, Expected: {CACHE_VERSION}")
|
||||
return False
|
||||
|
||||
if cache_data.get("model_type") != self.model_type:
|
||||
logger.info(f"Cache invalid - model type mismatch. Got: {cache_data.get('model_type')}, Expected: {self.model_type}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _load_cache_from_disk(self) -> bool:
|
||||
"""Load cache data from disk using MessagePack"""
|
||||
if not self._use_cache_files:
|
||||
logger.info(f"Cache files disabled for {self.model_type}, skipping load")
|
||||
return False
|
||||
|
||||
start_time = time.time()
|
||||
cache_path = self._get_cache_file_path()
|
||||
if not cache_path or not os.path.exists(cache_path):
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(cache_path, 'rb') as f:
|
||||
cache_data = msgpack.unpack(f)
|
||||
|
||||
# Validate cache data
|
||||
if not self._is_cache_valid(cache_data):
|
||||
logger.info(f"{self.model_type.capitalize()} cache file found but invalid or outdated")
|
||||
return False
|
||||
|
||||
# Load data into memory
|
||||
self._cache = ModelCache(
|
||||
raw_data=cache_data["raw_data"],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
# Load hash index
|
||||
hash_index_data = cache_data.get("hash_index", {})
|
||||
self._hash_index._hash_to_path = hash_index_data.get("hash_to_path", {})
|
||||
self._hash_index._filename_to_hash = hash_index_data.get("filename_to_hash", {}) # Fix: changed from path_to_hash to filename_to_hash
|
||||
self._hash_index._duplicate_hashes = hash_index_data.get("duplicate_hashes", {})
|
||||
self._hash_index._duplicate_filenames = hash_index_data.get("duplicate_filenames", {})
|
||||
|
||||
# Load tags count
|
||||
self._tags_count = cache_data.get("tags_count", {})
|
||||
|
||||
# Load excluded models
|
||||
self._excluded_models = cache_data.get("excluded_models", [])
|
||||
|
||||
# Resort the cache
|
||||
await self._cache.resort()
|
||||
|
||||
logger.info(f"Loaded {self.model_type} cache from disk with {len(self._cache.raw_data)} models in {time.time() - start_time:.2f} seconds")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading {self.model_type} cache from disk: {e}")
|
||||
return False
|
||||
|
||||
async def initialize_in_background(self) -> None:
|
||||
"""Initialize cache in background using thread pool"""
|
||||
@@ -252,8 +84,6 @@ class ModelScanner:
|
||||
if self._cache is None:
|
||||
self._cache = ModelCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
@@ -271,21 +101,6 @@ class ModelScanner:
|
||||
'scanner_type': self.model_type,
|
||||
'pageType': page_type
|
||||
})
|
||||
|
||||
cache_loaded = await self._load_cache_from_disk()
|
||||
|
||||
if cache_loaded:
|
||||
# Cache loaded successfully, broadcast complete message
|
||||
await ws_manager.broadcast_init_progress({
|
||||
'stage': 'finalizing',
|
||||
'progress': 100,
|
||||
'status': 'complete',
|
||||
'details': f"Loaded {len(self._cache.raw_data)} {self.model_type} files from cache.",
|
||||
'scanner_type': self.model_type,
|
||||
'pageType': page_type
|
||||
})
|
||||
self._is_initializing = False
|
||||
return
|
||||
|
||||
# If cache loading failed, proceed with full scan
|
||||
await ws_manager.broadcast_init_progress({
|
||||
@@ -332,9 +147,6 @@ class ModelScanner:
|
||||
|
||||
logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. Found {len(self._cache.raw_data)} models")
|
||||
|
||||
# Save the cache to disk after initialization
|
||||
await self._save_cache_to_disk()
|
||||
|
||||
# Send completion message
|
||||
await asyncio.sleep(0.5) # Small delay to ensure final progress message is sent
|
||||
await ws_manager.broadcast_init_progress({
|
||||
@@ -509,40 +321,21 @@ class ModelScanner:
|
||||
|
||||
Args:
|
||||
force_refresh: Whether to refresh the cache
|
||||
rebuild_cache: Whether to completely rebuild the cache by reloading from disk first
|
||||
rebuild_cache: Whether to completely rebuild the cache
|
||||
"""
|
||||
# If cache is not initialized, return an empty cache
|
||||
# Actual initialization should be done via initialize_in_background
|
||||
if self._cache is None and not force_refresh:
|
||||
return ModelCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
# If force refresh is requested, initialize the cache directly
|
||||
if force_refresh:
|
||||
# If rebuild_cache is True, try to reload from disk before reconciliation
|
||||
if rebuild_cache:
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Attempting to rebuild cache from disk...")
|
||||
cache_loaded = await self._load_cache_from_disk()
|
||||
if cache_loaded:
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Successfully reloaded cache from disk")
|
||||
else:
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Could not reload cache from disk, proceeding with complete rebuild")
|
||||
# If loading from disk failed, do a complete rebuild and save to disk
|
||||
await self._initialize_cache()
|
||||
await self._save_cache_to_disk()
|
||||
return self._cache
|
||||
|
||||
if self._cache is None:
|
||||
# For initial creation, do a full initialization
|
||||
await self._initialize_cache()
|
||||
# Save the newly built cache
|
||||
await self._save_cache_to_disk()
|
||||
else:
|
||||
# For subsequent refreshes, use fast reconciliation
|
||||
await self._reconcile_cache()
|
||||
|
||||
return self._cache
|
||||
@@ -577,8 +370,6 @@ class ModelScanner:
|
||||
# Update cache
|
||||
self._cache = ModelCache(
|
||||
raw_data=raw_data,
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
@@ -592,8 +383,6 @@ class ModelScanner:
|
||||
if self._cache is None:
|
||||
self._cache = ModelCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
finally:
|
||||
@@ -735,19 +524,74 @@ class ModelScanner:
|
||||
# Resort cache
|
||||
await self._cache.resort()
|
||||
|
||||
# Save updated cache to disk
|
||||
await self._save_cache_to_disk()
|
||||
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True)
|
||||
finally:
|
||||
self._is_initializing = False # Unset flag
|
||||
|
||||
# These methods should be implemented in child classes
|
||||
async def scan_all_models(self) -> List[Dict]:
|
||||
"""Scan all model directories and return metadata"""
|
||||
raise NotImplementedError("Subclasses must implement scan_all_models")
|
||||
all_models = []
|
||||
|
||||
# Create scan tasks for each directory
|
||||
scan_tasks = []
|
||||
for model_root in self.get_model_roots():
|
||||
task = asyncio.create_task(self._scan_directory(model_root))
|
||||
scan_tasks.append(task)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for task in scan_tasks:
|
||||
try:
|
||||
models = await task
|
||||
all_models.extend(models)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning directory: {e}")
|
||||
|
||||
return all_models
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Scan a single directory for model files"""
|
||||
models = []
|
||||
original_root = root_path # Save original root path
|
||||
|
||||
async def scan_recursive(path: str, visited_paths: set):
|
||||
"""Recursively scan directory, avoiding circular symlinks"""
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
logger.debug(f"Skipping already visited path: {path}")
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
|
||||
# Use original path instead of real path
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
await self._process_single_file(file_path, original_root, models)
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
# For directories, continue scanning with original path
|
||||
await scan_recursive(entry.path, visited_paths)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
await scan_recursive(root_path, set())
|
||||
return models
|
||||
|
||||
async def _process_single_file(self, file_path: str, root_path: str, models: list):
|
||||
"""Process a single file and add to results list"""
|
||||
try:
|
||||
result = await self._process_model_file(file_path, root_path)
|
||||
if result:
|
||||
models.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
|
||||
def is_initializing(self) -> bool:
|
||||
"""Check if the scanner is currently initializing"""
|
||||
@@ -931,7 +775,7 @@ class ModelScanner:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
|
||||
async def add_model_to_cache(self, metadata_dict: Dict, folder: str = '') -> bool:
|
||||
"""Add a model to the cache and save to disk
|
||||
"""Add a model to the cache
|
||||
|
||||
Args:
|
||||
metadata_dict: The model metadata dictionary
|
||||
@@ -960,9 +804,6 @@ class ModelScanner:
|
||||
|
||||
# Update the hash index
|
||||
self._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
||||
|
||||
# Save to disk
|
||||
await self._save_cache_to_disk()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding model to cache: {e}")
|
||||
@@ -1102,9 +943,6 @@ class ModelScanner:
|
||||
|
||||
await cache.resort()
|
||||
|
||||
# Save the updated cache
|
||||
await self._save_cache_to_disk()
|
||||
|
||||
return True
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
@@ -1198,11 +1036,7 @@ class ModelScanner:
|
||||
if self._cache is None:
|
||||
return False
|
||||
|
||||
updated = await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
|
||||
if updated:
|
||||
# Save updated cache to disk
|
||||
await self._save_cache_to_disk()
|
||||
return updated
|
||||
return await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
|
||||
|
||||
async def bulk_delete_models(self, file_paths: List[str]) -> Dict:
|
||||
"""Delete multiple models and update cache in a batch operation
|
||||
@@ -1334,9 +1168,6 @@ class ModelScanner:
|
||||
# Resort cache
|
||||
await self._cache.resort()
|
||||
|
||||
# Save updated cache to disk
|
||||
await self._save_cache_to_disk()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -1362,3 +1193,59 @@ class ModelScanner:
|
||||
if file_name in self._hash_index._duplicate_filenames:
|
||||
if len(self._hash_index._duplicate_filenames[file_name]) <= 1:
|
||||
del self._hash_index._duplicate_filenames[file_name]
|
||||
|
||||
async def check_model_version_exists(self, model_id: int, model_version_id: int) -> bool:
|
||||
"""Check if a specific model version exists in the cache
|
||||
|
||||
Args:
|
||||
model_id: Civitai model ID
|
||||
model_version_id: Civitai model version ID
|
||||
|
||||
Returns:
|
||||
bool: True if the model version exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
cache = await self.get_cached_data()
|
||||
if not cache or not cache.raw_data:
|
||||
return False
|
||||
|
||||
for item in cache.raw_data:
|
||||
if (item.get('civitai') and
|
||||
item['civitai'].get('modelId') == model_id and
|
||||
item['civitai'].get('id') == model_version_id):
|
||||
return True
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking model version existence: {e}")
|
||||
return False
|
||||
|
||||
async def get_model_versions_by_id(self, model_id: int) -> List[Dict]:
|
||||
"""Get all versions of a model by its ID
|
||||
|
||||
Args:
|
||||
model_id: Civitai model ID
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of version information dictionaries
|
||||
"""
|
||||
try:
|
||||
cache = await self.get_cached_data()
|
||||
if not cache or not cache.raw_data:
|
||||
return []
|
||||
|
||||
versions = []
|
||||
for item in cache.raw_data:
|
||||
if (item.get('civitai') and
|
||||
item['civitai'].get('modelId') == model_id and
|
||||
item['civitai'].get('id')):
|
||||
versions.append({
|
||||
'versionId': item['civitai'].get('id'),
|
||||
'name': item['civitai'].get('name'),
|
||||
'fileName': item.get('file_name', '')
|
||||
})
|
||||
|
||||
return versions
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model versions: {e}")
|
||||
return []
|
||||
|
||||
142
py/services/model_service_factory.py
Normal file
142
py/services/model_service_factory.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from typing import Dict, Type, Any
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelServiceFactory:
|
||||
"""Factory for managing model services and routes"""
|
||||
|
||||
_services: Dict[str, Type] = {}
|
||||
_routes: Dict[str, Type] = {}
|
||||
_initialized_services: Dict[str, Any] = {}
|
||||
_initialized_routes: Dict[str, Any] = {}
|
||||
|
||||
@classmethod
|
||||
def register_model_type(cls, model_type: str, service_class: Type, route_class: Type):
|
||||
"""Register a new model type with its service and route classes
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier (e.g., 'lora', 'checkpoint')
|
||||
service_class: The service class for this model type
|
||||
route_class: The route class for this model type
|
||||
"""
|
||||
cls._services[model_type] = service_class
|
||||
cls._routes[model_type] = route_class
|
||||
logger.info(f"Registered model type '{model_type}' with service {service_class.__name__} and routes {route_class.__name__}")
|
||||
|
||||
@classmethod
|
||||
def get_service_class(cls, model_type: str) -> Type:
|
||||
"""Get service class for a model type
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
The service class for the model type
|
||||
|
||||
Raises:
|
||||
ValueError: If model type is not registered
|
||||
"""
|
||||
if model_type not in cls._services:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
return cls._services[model_type]
|
||||
|
||||
@classmethod
|
||||
def get_route_class(cls, model_type: str) -> Type:
|
||||
"""Get route class for a model type
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
The route class for the model type
|
||||
|
||||
Raises:
|
||||
ValueError: If model type is not registered
|
||||
"""
|
||||
if model_type not in cls._routes:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
return cls._routes[model_type]
|
||||
|
||||
@classmethod
|
||||
def get_route_instance(cls, model_type: str):
|
||||
"""Get or create route instance for a model type
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
The route instance for the model type
|
||||
"""
|
||||
if model_type not in cls._initialized_routes:
|
||||
route_class = cls.get_route_class(model_type)
|
||||
cls._initialized_routes[model_type] = route_class()
|
||||
return cls._initialized_routes[model_type]
|
||||
|
||||
@classmethod
|
||||
def setup_all_routes(cls, app):
|
||||
"""Setup routes for all registered model types
|
||||
|
||||
Args:
|
||||
app: The aiohttp application instance
|
||||
"""
|
||||
logger.info(f"Setting up routes for {len(cls._services)} registered model types")
|
||||
|
||||
for model_type in cls._services.keys():
|
||||
try:
|
||||
routes_instance = cls.get_route_instance(model_type)
|
||||
routes_instance.setup_routes(app)
|
||||
logger.info(f"Successfully set up routes for {model_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup routes for {model_type}: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
def get_registered_types(cls) -> list:
|
||||
"""Get list of all registered model types
|
||||
|
||||
Returns:
|
||||
List of registered model type identifiers
|
||||
"""
|
||||
return list(cls._services.keys())
|
||||
|
||||
@classmethod
|
||||
def is_registered(cls, model_type: str) -> bool:
|
||||
"""Check if a model type is registered
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
True if the model type is registered, False otherwise
|
||||
"""
|
||||
return model_type in cls._services
|
||||
|
||||
@classmethod
|
||||
def clear_registrations(cls):
|
||||
"""Clear all registrations - mainly for testing purposes"""
|
||||
cls._services.clear()
|
||||
cls._routes.clear()
|
||||
cls._initialized_services.clear()
|
||||
cls._initialized_routes.clear()
|
||||
logger.info("Cleared all model type registrations")
|
||||
|
||||
|
||||
def register_default_model_types():
|
||||
"""Register the default model types (LoRA, Checkpoint, and Embedding)"""
|
||||
from ..services.lora_service import LoraService
|
||||
from ..services.checkpoint_service import CheckpointService
|
||||
from ..services.embedding_service import EmbeddingService
|
||||
from ..routes.lora_routes import LoraRoutes
|
||||
from ..routes.checkpoint_routes import CheckpointRoutes
|
||||
from ..routes.embedding_routes import EmbeddingRoutes
|
||||
|
||||
# Register LoRA model type
|
||||
ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes)
|
||||
|
||||
# Register Checkpoint model type
|
||||
ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes)
|
||||
|
||||
# Register Embedding model type
|
||||
ModelServiceFactory.register_model_type('embedding', EmbeddingService, EmbeddingRoutes)
|
||||
|
||||
logger.info("Registered default model types: lora, checkpoint, embedding")
|
||||
@@ -393,8 +393,8 @@ class RecipeScanner:
|
||||
if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']):
|
||||
hash_value = lora['hash']
|
||||
|
||||
if self._lora_scanner.has_lora_hash(hash_value):
|
||||
lora_path = self._lora_scanner.get_lora_path_by_hash(hash_value)
|
||||
if self._lora_scanner.has_hash(hash_value):
|
||||
lora_path = self._lora_scanner.get_path_by_hash(hash_value)
|
||||
if lora_path:
|
||||
file_name = os.path.splitext(os.path.basename(lora_path))[0]
|
||||
lora['file_name'] = file_name
|
||||
@@ -465,7 +465,7 @@ class RecipeScanner:
|
||||
# Count occurrences of each base model
|
||||
for lora in loras:
|
||||
if 'hash' in lora:
|
||||
lora_path = self._lora_scanner.get_lora_path_by_hash(lora['hash'])
|
||||
lora_path = self._lora_scanner.get_path_by_hash(lora['hash'])
|
||||
if lora_path:
|
||||
base_model = await self._get_base_model_for_lora(lora_path)
|
||||
if base_model:
|
||||
@@ -603,9 +603,9 @@ class RecipeScanner:
|
||||
if 'loras' in item:
|
||||
for lora in item['loras']:
|
||||
if 'hash' in lora and lora['hash']:
|
||||
lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora['hash'].lower())
|
||||
lora['inLibrary'] = self._lora_scanner.has_hash(lora['hash'].lower())
|
||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
|
||||
lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora['hash'].lower())
|
||||
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora['hash'].lower())
|
||||
|
||||
result = {
|
||||
'items': paginated_items,
|
||||
@@ -655,9 +655,9 @@ class RecipeScanner:
|
||||
for lora in formatted_recipe['loras']:
|
||||
if 'hash' in lora and lora['hash']:
|
||||
lora_hash = lora['hash'].lower()
|
||||
lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora_hash)
|
||||
lora['inLibrary'] = self._lora_scanner.has_hash(lora_hash)
|
||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash)
|
||||
lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora_hash)
|
||||
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora_hash)
|
||||
|
||||
return formatted_recipe
|
||||
|
||||
|
||||
@@ -7,97 +7,209 @@ logger = logging.getLogger(__name__)
|
||||
T = TypeVar('T') # Define a type variable for service types
|
||||
|
||||
class ServiceRegistry:
|
||||
"""Centralized registry for service singletons"""
|
||||
"""Central registry for managing singleton services"""
|
||||
|
||||
_instance = None
|
||||
_services: Dict[str, Any] = {}
|
||||
_lock = asyncio.Lock()
|
||||
_locks: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
"""Get singleton instance of the registry"""
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
async def register_service(cls, name: str, service: Any) -> None:
|
||||
"""Register a service instance with the registry
|
||||
|
||||
Args:
|
||||
name: Service name identifier
|
||||
service: Service instance to register
|
||||
"""
|
||||
cls._services[name] = service
|
||||
logger.debug(f"Registered service: {name}")
|
||||
|
||||
@classmethod
|
||||
async def register_service(cls, service_name: str, service_instance: Any) -> None:
|
||||
"""Register a service instance with the registry"""
|
||||
registry = cls.get_instance()
|
||||
async with cls._lock:
|
||||
registry._services[service_name] = service_instance
|
||||
logger.debug(f"Registered service: {service_name}")
|
||||
async def get_service(cls, name: str) -> Optional[Any]:
|
||||
"""Get a service instance by name
|
||||
|
||||
Args:
|
||||
name: Service name identifier
|
||||
|
||||
Returns:
|
||||
Service instance or None if not found
|
||||
"""
|
||||
return cls._services.get(name)
|
||||
|
||||
@classmethod
|
||||
async def get_service(cls, service_name: str) -> Any:
|
||||
"""Get a service instance by name"""
|
||||
registry = cls.get_instance()
|
||||
async with cls._lock:
|
||||
if service_name not in registry._services:
|
||||
logger.debug(f"Service {service_name} not found in registry")
|
||||
return None
|
||||
return registry._services[service_name]
|
||||
def get_service_sync(cls, name: str) -> Optional[Any]:
|
||||
"""Synchronously get a service instance by name
|
||||
|
||||
Args:
|
||||
name: Service name identifier
|
||||
|
||||
Returns:
|
||||
Service instance or None if not found
|
||||
"""
|
||||
return cls._services.get(name)
|
||||
|
||||
@classmethod
|
||||
def _get_lock(cls, name: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a service
|
||||
|
||||
Args:
|
||||
name: Service name identifier
|
||||
|
||||
Returns:
|
||||
AsyncIO lock for the service
|
||||
"""
|
||||
if name not in cls._locks:
|
||||
cls._locks[name] = asyncio.Lock()
|
||||
return cls._locks[name]
|
||||
|
||||
# Convenience methods for common services
|
||||
@classmethod
|
||||
async def get_lora_scanner(cls):
|
||||
"""Get the LoraScanner instance"""
|
||||
from .lora_scanner import LoraScanner
|
||||
scanner = await cls.get_service("lora_scanner")
|
||||
if scanner is None:
|
||||
scanner = await LoraScanner.get_instance()
|
||||
await cls.register_service("lora_scanner", scanner)
|
||||
return scanner
|
||||
"""Get or create LoRA scanner instance"""
|
||||
service_name = "lora_scanner"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .lora_scanner import LoraScanner
|
||||
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cls._services[service_name] = scanner
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_checkpoint_scanner(cls):
|
||||
"""Get the CheckpointScanner instance"""
|
||||
from .checkpoint_scanner import CheckpointScanner
|
||||
scanner = await cls.get_service("checkpoint_scanner")
|
||||
if scanner is None:
|
||||
"""Get or create Checkpoint scanner instance"""
|
||||
service_name = "checkpoint_scanner"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .checkpoint_scanner import CheckpointScanner
|
||||
|
||||
scanner = await CheckpointScanner.get_instance()
|
||||
await cls.register_service("checkpoint_scanner", scanner)
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_civitai_client(cls):
|
||||
"""Get the CivitaiClient instance"""
|
||||
from .civitai_client import CivitaiClient
|
||||
client = await cls.get_service("civitai_client")
|
||||
if client is None:
|
||||
client = await CivitaiClient.get_instance()
|
||||
await cls.register_service("civitai_client", client)
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
async def get_download_manager(cls):
|
||||
"""Get the DownloadManager instance"""
|
||||
from .download_manager import DownloadManager
|
||||
manager = await cls.get_service("download_manager")
|
||||
if manager is None:
|
||||
manager = await DownloadManager.get_instance()
|
||||
await cls.register_service("download_manager", manager)
|
||||
return manager
|
||||
|
||||
cls._services[service_name] = scanner
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_recipe_scanner(cls):
|
||||
"""Get the RecipeScanner instance"""
|
||||
from .recipe_scanner import RecipeScanner
|
||||
scanner = await cls.get_service("recipe_scanner")
|
||||
if scanner is None:
|
||||
lora_scanner = await cls.get_lora_scanner()
|
||||
scanner = RecipeScanner(lora_scanner)
|
||||
await cls.register_service("recipe_scanner", scanner)
|
||||
return scanner
|
||||
|
||||
"""Get or create Recipe scanner instance"""
|
||||
service_name = "recipe_scanner"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .recipe_scanner import RecipeScanner
|
||||
|
||||
scanner = await RecipeScanner.get_instance()
|
||||
cls._services[service_name] = scanner
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_civitai_client(cls):
|
||||
"""Get or create CivitAI client instance"""
|
||||
service_name = "civitai_client"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .civitai_client import CivitaiClient
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
cls._services[service_name] = client
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
async def get_download_manager(cls):
|
||||
"""Get or create Download manager instance"""
|
||||
service_name = "download_manager"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .download_manager import DownloadManager
|
||||
|
||||
manager = DownloadManager()
|
||||
cls._services[service_name] = manager
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return manager
|
||||
|
||||
@classmethod
|
||||
async def get_websocket_manager(cls):
|
||||
"""Get the WebSocketManager instance"""
|
||||
from .websocket_manager import ws_manager
|
||||
manager = await cls.get_service("websocket_manager")
|
||||
if manager is None:
|
||||
# ws_manager is already a global instance in websocket_manager.py
|
||||
"""Get or create WebSocket manager instance"""
|
||||
service_name = "websocket_manager"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .websocket_manager import ws_manager
|
||||
await cls.register_service("websocket_manager", ws_manager)
|
||||
manager = ws_manager
|
||||
return manager
|
||||
|
||||
cls._services[service_name] = ws_manager
|
||||
logger.debug(f"Registered {service_name}")
|
||||
return ws_manager
|
||||
|
||||
@classmethod
|
||||
async def get_embedding_scanner(cls):
|
||||
"""Get or create Embedding scanner instance"""
|
||||
service_name = "embedding_scanner"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .embedding_scanner import EmbeddingScanner
|
||||
|
||||
scanner = await EmbeddingScanner.get_instance()
|
||||
cls._services[service_name] = scanner
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
def clear_services(cls):
|
||||
"""Clear all registered services - mainly for testing"""
|
||||
cls._services.clear()
|
||||
cls._locks.clear()
|
||||
logger.info("Cleared all registered services")
|
||||
@@ -1,6 +1,9 @@
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from typing import Set, Dict, Optional
|
||||
from uuid import uuid4
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -10,7 +13,9 @@ class WebSocketManager:
|
||||
def __init__(self):
|
||||
self._websockets: Set[web.WebSocketResponse] = set()
|
||||
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
|
||||
self._checkpoint_websockets: Set[web.WebSocketResponse] = set() # New set for checkpoint download progress
|
||||
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
|
||||
# Add progress tracking dictionary
|
||||
self._download_progress: Dict[str, Dict] = {}
|
||||
|
||||
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""Handle new WebSocket connection"""
|
||||
@@ -39,21 +44,48 @@ class WebSocketManager:
|
||||
finally:
|
||||
self._init_websockets.discard(ws)
|
||||
return ws
|
||||
|
||||
async def handle_checkpoint_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""Handle new WebSocket connection for checkpoint download progress"""
|
||||
|
||||
async def handle_download_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""Handle new WebSocket connection for download progress"""
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
self._checkpoint_websockets.add(ws)
|
||||
|
||||
# Get download_id from query parameters
|
||||
download_id = request.query.get('id')
|
||||
|
||||
if not download_id:
|
||||
# Generate a new download ID if not provided
|
||||
download_id = str(uuid4())
|
||||
|
||||
# Store the websocket with its download ID
|
||||
self._download_websockets[download_id] = ws
|
||||
|
||||
try:
|
||||
# Send the download ID back to the client
|
||||
await ws.send_json({
|
||||
'type': 'download_id',
|
||||
'download_id': download_id
|
||||
})
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == web.WSMsgType.ERROR:
|
||||
logger.error(f'Checkpoint WebSocket error: {ws.exception()}')
|
||||
logger.error(f'Download WebSocket error: {ws.exception()}')
|
||||
finally:
|
||||
self._checkpoint_websockets.discard(ws)
|
||||
if download_id in self._download_websockets:
|
||||
del self._download_websockets[download_id]
|
||||
|
||||
# Schedule cleanup of completed downloads after WebSocket disconnection
|
||||
asyncio.create_task(self._delayed_cleanup(download_id))
|
||||
return ws
|
||||
|
||||
|
||||
async def _delayed_cleanup(self, download_id: str, delay_seconds: int = 300):
|
||||
"""Clean up download progress after a delay (5 minutes by default)"""
|
||||
await asyncio.sleep(delay_seconds)
|
||||
progress_data = self._download_progress.get(download_id)
|
||||
if progress_data and progress_data.get('progress', 0) >= 100:
|
||||
self.cleanup_download_progress(download_id)
|
||||
logger.debug(f"Delayed cleanup completed for download {download_id}")
|
||||
|
||||
async def broadcast(self, data: Dict):
|
||||
"""Broadcast message to all connected clients"""
|
||||
if not self._websockets:
|
||||
@@ -84,17 +116,45 @@ class WebSocketManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending initialization progress: {e}")
|
||||
|
||||
async def broadcast_checkpoint_progress(self, data: Dict):
|
||||
"""Broadcast checkpoint download progress to connected clients"""
|
||||
if not self._checkpoint_websockets:
|
||||
async def broadcast_download_progress(self, download_id: str, data: Dict):
|
||||
"""Send progress update to specific download client"""
|
||||
# Store simplified progress data in memory (only progress percentage)
|
||||
self._download_progress[download_id] = {
|
||||
'progress': data.get('progress', 0),
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
if download_id not in self._download_websockets:
|
||||
logger.debug(f"No WebSocket found for download ID: {download_id}")
|
||||
return
|
||||
|
||||
for ws in self._checkpoint_websockets:
|
||||
try:
|
||||
await ws.send_json(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending checkpoint progress: {e}")
|
||||
|
||||
ws = self._download_websockets[download_id]
|
||||
try:
|
||||
await ws.send_json(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending download progress: {e}")
|
||||
|
||||
def get_download_progress(self, download_id: str) -> Optional[Dict]:
|
||||
"""Get progress information for a specific download"""
|
||||
return self._download_progress.get(download_id)
|
||||
|
||||
def cleanup_download_progress(self, download_id: str):
|
||||
"""Remove progress info for a specific download"""
|
||||
self._download_progress.pop(download_id, None)
|
||||
|
||||
def cleanup_old_downloads(self, max_age_hours: int = 24):
|
||||
"""Clean up old download progress entries"""
|
||||
cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
|
||||
to_remove = []
|
||||
|
||||
for download_id, progress_data in self._download_progress.items():
|
||||
if progress_data.get('timestamp', datetime.now()) < cutoff_time:
|
||||
to_remove.append(download_id)
|
||||
|
||||
for download_id in to_remove:
|
||||
self._download_progress.pop(download_id, None)
|
||||
logger.debug(f"Cleaned up old download progress for {download_id}")
|
||||
|
||||
def get_connected_clients_count(self) -> int:
|
||||
"""Get number of connected clients"""
|
||||
return len(self._websockets)
|
||||
@@ -102,10 +162,14 @@ class WebSocketManager:
|
||||
def get_init_clients_count(self) -> int:
|
||||
"""Get number of initialization progress clients"""
|
||||
return len(self._init_websockets)
|
||||
|
||||
def get_checkpoint_clients_count(self) -> int:
|
||||
"""Get number of checkpoint progress clients"""
|
||||
return len(self._checkpoint_websockets)
|
||||
|
||||
def get_download_clients_count(self) -> int:
|
||||
"""Get number of download progress clients"""
|
||||
return len(self._download_websockets)
|
||||
|
||||
def generate_download_id(self) -> str:
|
||||
"""Generate a unique download ID"""
|
||||
return str(uuid4())
|
||||
|
||||
# Global instance
|
||||
ws_manager = WebSocketManager()
|
||||
@@ -10,7 +10,8 @@ NSFW_LEVELS = {
|
||||
# Node type constants
|
||||
NODE_TYPES = {
|
||||
"Lora Loader (LoraManager)": 1,
|
||||
"Lora Stacker (LoraManager)": 2
|
||||
"Lora Stacker (LoraManager)": 2,
|
||||
"WanVideo Lora Select (LoraManager)": 3
|
||||
}
|
||||
|
||||
# Default ComfyUI node color when bgcolor is null
|
||||
@@ -45,4 +46,11 @@ SUPPORTED_MEDIA_EXTENSIONS = {
|
||||
}
|
||||
|
||||
# Valid Lora types
|
||||
VALID_LORA_TYPES = ['lora', 'locon', 'dora']
|
||||
VALID_LORA_TYPES = ['lora', 'locon', 'dora']
|
||||
|
||||
# Civitai model tags in priority order for subfolder organization
|
||||
CIVITAI_MODEL_TAGS = [
|
||||
'character', 'style', 'concept', 'clothing', 'base model',
|
||||
'poses', 'background', 'tool', 'vehicle', 'buildings',
|
||||
'objects', 'assets', 'animal', 'action'
|
||||
]
|
||||
@@ -214,6 +214,10 @@ class DownloadManager:
|
||||
if 'checkpoint' in model_types:
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
scanners.append(('checkpoint', checkpoint_scanner))
|
||||
|
||||
if 'embedding' in model_types:
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
scanners.append(('embedding', embedding_scanner))
|
||||
|
||||
# Get all models
|
||||
all_models = []
|
||||
|
||||
@@ -251,12 +251,13 @@ class ExampleImagesProcessor:
|
||||
# Find the model and get current metadata
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
model_data = None
|
||||
scanner = None
|
||||
|
||||
# Check both scanners to find the model
|
||||
for scan_obj in [lora_scanner, checkpoint_scanner]:
|
||||
for scan_obj in [lora_scanner, checkpoint_scanner, embedding_scanner]:
|
||||
cache = await scan_obj.get_cached_data()
|
||||
for item in cache.raw_data:
|
||||
if item.get('sha256') == model_hash:
|
||||
@@ -384,12 +385,13 @@ class ExampleImagesProcessor:
|
||||
# Find the model and get current metadata
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
model_data = None
|
||||
scanner = None
|
||||
|
||||
# Check both scanners to find the model
|
||||
for scan_obj in [lora_scanner, checkpoint_scanner]:
|
||||
for scan_obj in [lora_scanner, checkpoint_scanner, embedding_scanner]:
|
||||
if scan_obj.has_hash(model_hash):
|
||||
cache = await scan_obj.get_cached_data()
|
||||
for item in cache.raw_data:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
@@ -196,7 +197,7 @@ class MetadataManager:
|
||||
model_name=base_name,
|
||||
file_path=normalize_path(file_path),
|
||||
size=os.path.getsize(real_path),
|
||||
modified=os.path.getmtime(real_path),
|
||||
modified=datetime.now().timestamp(),
|
||||
sha256=sha256,
|
||||
base_model="Unknown",
|
||||
preview_url=normalize_path(preview_url),
|
||||
@@ -205,13 +206,27 @@ class MetadataManager:
|
||||
model_type="checkpoint",
|
||||
from_civitai=True
|
||||
)
|
||||
elif model_class.__name__ == "EmbeddingMetadata":
|
||||
metadata = model_class(
|
||||
file_name=base_name,
|
||||
model_name=base_name,
|
||||
file_path=normalize_path(file_path),
|
||||
size=os.path.getsize(real_path),
|
||||
modified=datetime.now().timestamp(),
|
||||
sha256=sha256,
|
||||
base_model="Unknown",
|
||||
preview_url=normalize_path(preview_url),
|
||||
tags=[],
|
||||
modelDescription="",
|
||||
from_civitai=True
|
||||
)
|
||||
else: # Default to LoraMetadata
|
||||
metadata = model_class(
|
||||
file_name=base_name,
|
||||
model_name=base_name,
|
||||
file_path=normalize_path(file_path),
|
||||
size=os.path.getsize(real_path),
|
||||
modified=os.path.getmtime(real_path),
|
||||
modified=datetime.now().timestamp(),
|
||||
sha256=sha256,
|
||||
base_model="Unknown",
|
||||
preview_url=normalize_path(preview_url),
|
||||
@@ -222,7 +237,7 @@ class MetadataManager:
|
||||
)
|
||||
|
||||
# Try to extract model-specific metadata
|
||||
await MetadataManager._enrich_metadata(metadata, real_path)
|
||||
# await MetadataManager._enrich_metadata(metadata, real_path)
|
||||
|
||||
# Save the created metadata
|
||||
await MetadataManager.save_metadata(file_path, metadata, create_backup=False)
|
||||
@@ -266,6 +281,12 @@ class MetadataManager:
|
||||
"""
|
||||
need_update = False
|
||||
|
||||
# Check if file_name matches the actual file name
|
||||
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
if metadata.file_name != base_name:
|
||||
metadata.file_name = base_name
|
||||
need_update = True
|
||||
|
||||
# Check if file path is different from what's in metadata
|
||||
if normalize_path(file_path) != metadata.file_path:
|
||||
metadata.file_path = normalize_path(file_path)
|
||||
|
||||
@@ -11,7 +11,7 @@ class BaseModelMetadata:
|
||||
model_name: str # The model's name defined by the creator
|
||||
file_path: str # Full path to the model file
|
||||
size: int # File size in bytes
|
||||
modified: float # Last modified timestamp
|
||||
modified: float # Timestamp when the model was added to the management system
|
||||
sha256: str # SHA256 hash of the file
|
||||
base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.)
|
||||
preview_url: str # Preview image URL
|
||||
@@ -73,11 +73,6 @@ class BaseModelMetadata:
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def modified_datetime(self) -> datetime:
|
||||
"""Convert modified timestamp to datetime object"""
|
||||
return datetime.fromtimestamp(self.modified)
|
||||
|
||||
def update_civitai_info(self, civitai_data: Dict) -> None:
|
||||
"""Update Civitai information"""
|
||||
self.civitai = civitai_data
|
||||
@@ -128,7 +123,7 @@ class LoraMetadata(BaseModelMetadata):
|
||||
@dataclass
|
||||
class CheckpointMetadata(BaseModelMetadata):
|
||||
"""Represents the metadata structure for a Checkpoint model"""
|
||||
model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.)
|
||||
model_type: str = "checkpoint" # Model type (checkpoint, diffusion_model, etc.)
|
||||
|
||||
@classmethod
|
||||
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata':
|
||||
@@ -163,3 +158,41 @@ class CheckpointMetadata(BaseModelMetadata):
|
||||
modelDescription=description
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class EmbeddingMetadata(BaseModelMetadata):
|
||||
"""Represents the metadata structure for an Embedding model"""
|
||||
model_type: str = "embedding" # Model type (embedding, textual_inversion, etc.)
|
||||
|
||||
@classmethod
|
||||
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'EmbeddingMetadata':
|
||||
"""Create EmbeddingMetadata instance from Civitai version info"""
|
||||
file_name = file_info['name']
|
||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||
model_type = version_info.get('type', 'embedding')
|
||||
|
||||
# Extract tags and description if available
|
||||
tags = []
|
||||
description = ""
|
||||
if 'model' in version_info:
|
||||
if 'tags' in version_info['model']:
|
||||
tags = version_info['model']['tags']
|
||||
if 'description' in version_info['model']:
|
||||
description = version_info['model']['description']
|
||||
|
||||
return cls(
|
||||
file_name=os.path.splitext(file_name)[0],
|
||||
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||
file_path=save_path.replace(os.sep, '/'),
|
||||
size=file_info.get('sizeKB', 0) * 1024,
|
||||
modified=datetime.now().timestamp(),
|
||||
sha256=file_info['hashes'].get('SHA256', '').lower(),
|
||||
base_model=base_model,
|
||||
preview_url=None, # Will be updated after preview download
|
||||
preview_nsfw_level=0,
|
||||
from_civitai=True,
|
||||
civitai=version_info,
|
||||
model_type=model_type,
|
||||
tags=tags,
|
||||
modelDescription=description
|
||||
)
|
||||
|
||||
|
||||
@@ -8,9 +8,11 @@ from .model_utils import determine_base_model
|
||||
from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
|
||||
from ..config import config
|
||||
from ..services.civitai_client import CivitaiClient
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from ..services.download_manager import DownloadManager
|
||||
from ..services.websocket_manager import ws_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -327,8 +329,6 @@ class ModelRouteUtils:
|
||||
# Update hash index if available
|
||||
if hasattr(scanner, '_hash_index') and scanner._hash_index:
|
||||
scanner._hash_index.remove_by_path(file_path)
|
||||
|
||||
await scanner._save_cache_to_disk()
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
@@ -551,8 +551,6 @@ class ModelRouteUtils:
|
||||
|
||||
# Add to excluded models list
|
||||
scanner._excluded_models.append(file_path)
|
||||
|
||||
await scanner._save_cache_to_disk()
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
@@ -564,66 +562,83 @@ class ModelRouteUtils:
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response:
|
||||
"""Handle model download request
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
download_manager: Instance of DownloadManager
|
||||
model_type: Type of model ('lora' or 'checkpoint')
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response
|
||||
"""
|
||||
async def handle_download_model(request: web.Request) -> web.Response:
|
||||
"""Handle model download request"""
|
||||
try:
|
||||
download_manager = await ServiceRegistry.get_download_manager()
|
||||
data = await request.json()
|
||||
|
||||
# Create progress callback
|
||||
# Get or generate a download ID
|
||||
download_id = data.get('download_id', ws_manager.generate_download_id())
|
||||
|
||||
# Create progress callback with download ID
|
||||
async def progress_callback(progress):
|
||||
from ..services.websocket_manager import ws_manager
|
||||
await ws_manager.broadcast({
|
||||
await ws_manager.broadcast_download_progress(download_id, {
|
||||
'status': 'progress',
|
||||
'progress': progress
|
||||
'progress': progress,
|
||||
'download_id': download_id
|
||||
})
|
||||
|
||||
# Check which identifier is provided
|
||||
download_url = data.get('download_url')
|
||||
model_hash = data.get('model_hash')
|
||||
model_version_id = data.get('model_version_id')
|
||||
# Check which identifier is provided and convert to int
|
||||
try:
|
||||
model_id = int(data.get('model_id'))
|
||||
except (TypeError, ValueError):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': "Invalid model_id: Must be an integer"
|
||||
}, status=400)
|
||||
|
||||
# Convert model_version_id to int if provided
|
||||
model_version_id = None
|
||||
if data.get('model_version_id'):
|
||||
try:
|
||||
model_version_id = int(data.get('model_version_id'))
|
||||
except (TypeError, ValueError):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': "Invalid model_version_id: Must be an integer"
|
||||
}, status=400)
|
||||
|
||||
# Validate that at least one identifier is provided
|
||||
if not any([download_url, model_hash, model_version_id]):
|
||||
return web.Response(
|
||||
status=400,
|
||||
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
|
||||
)
|
||||
# Only model_id is required, model_version_id is optional
|
||||
if not model_id:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': "Missing required parameter: Please provide 'model_id'"
|
||||
}, status=400)
|
||||
|
||||
# Use the correct root directory based on model type
|
||||
root_key = 'checkpoint_root' if model_type == 'checkpoint' else 'lora_root'
|
||||
save_dir = data.get(root_key)
|
||||
use_default_paths = data.get('use_default_paths', False)
|
||||
|
||||
# Pass the download_id to download_from_civitai
|
||||
result = await download_manager.download_from_civitai(
|
||||
download_url=download_url,
|
||||
model_hash=model_hash,
|
||||
model_id=model_id,
|
||||
model_version_id=model_version_id,
|
||||
save_dir=save_dir,
|
||||
save_dir=data.get('model_root'),
|
||||
relative_path=data.get('relative_path', ''),
|
||||
use_default_paths=use_default_paths,
|
||||
progress_callback=progress_callback,
|
||||
model_type=model_type
|
||||
download_id=download_id # Pass download_id explicitly
|
||||
)
|
||||
|
||||
# Include download_id in the response
|
||||
result['download_id'] = download_id
|
||||
|
||||
if not result.get('success', False):
|
||||
error_message = result.get('error', 'Unknown error')
|
||||
|
||||
# Return 401 for early access errors
|
||||
if 'early access' in error_message.lower():
|
||||
logger.warning(f"Early access download failed: {error_message}")
|
||||
return web.Response(
|
||||
status=401, # Use 401 status code to match Civitai's response
|
||||
text=f"Early Access Restriction: {error_message}"
|
||||
)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f"Early Access Restriction: {error_message}",
|
||||
'download_id': download_id
|
||||
}, status=401)
|
||||
|
||||
return web.Response(status=500, text=error_message)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': error_message,
|
||||
'download_id': download_id
|
||||
}, status=500)
|
||||
|
||||
return web.json_response(result)
|
||||
|
||||
@@ -633,13 +648,75 @@ class ModelRouteUtils:
|
||||
# Check if this might be an early access error
|
||||
if '401' in error_message:
|
||||
logger.warning(f"Early access error (401): {error_message}")
|
||||
return web.Response(
|
||||
status=401,
|
||||
text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
|
||||
)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
|
||||
}, status=401)
|
||||
|
||||
logger.error(f"Error downloading {model_type}: {error_message}")
|
||||
return web.Response(status=500, text=error_message)
|
||||
logger.error(f"Error downloading model: {error_message}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': error_message
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_cancel_download(request: web.Request) -> web.Response:
|
||||
"""Handle cancellation of a download task
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response
|
||||
"""
|
||||
try:
|
||||
download_manager = await ServiceRegistry.get_download_manager()
|
||||
download_id = request.match_info.get('download_id')
|
||||
if not download_id:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Download ID is required'
|
||||
}, status=400)
|
||||
|
||||
result = await download_manager.cancel_download(download_id)
|
||||
|
||||
# Notify clients about cancellation via WebSocket
|
||||
await ws_manager.broadcast_download_progress(download_id, {
|
||||
'status': 'cancelled',
|
||||
'progress': 0,
|
||||
'download_id': download_id,
|
||||
'message': 'Download cancelled by user'
|
||||
})
|
||||
|
||||
return web.json_response(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling download: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_list_downloads(request: web.Request) -> web.Response:
|
||||
"""Get list of active downloads
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response with list of downloads
|
||||
"""
|
||||
try:
|
||||
download_manager = await ServiceRegistry.get_download_manager()
|
||||
result = await download_manager.get_active_downloads()
|
||||
return web.json_response(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing downloads: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_bulk_delete_models(request: web.Request, scanner) -> web.Response:
|
||||
@@ -693,8 +770,10 @@ class ModelRouteUtils:
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
model_id = data.get('model_id')
|
||||
model_version_id = data.get('model_version_id')
|
||||
model_id = int(data.get('model_id'))
|
||||
model_version_id = None
|
||||
if data.get('model_version_id'):
|
||||
model_version_id = int(data.get('model_version_id'))
|
||||
|
||||
if not file_path or not model_id:
|
||||
return web.json_response({"success": False, "error": "Both file_path and model_id are required"}, status=400)
|
||||
@@ -955,6 +1034,7 @@ class ModelRouteUtils:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'new_file_path': new_file_path,
|
||||
'new_preview_path': config.get_preview_static_url(new_preview),
|
||||
'renamed_files': renamed_files,
|
||||
'reload_required': False
|
||||
})
|
||||
@@ -965,3 +1045,56 @@ class ModelRouteUtils:
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_save_metadata(request: web.Request, scanner) -> web.Response:
|
||||
"""Handle saving metadata updates
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
scanner: The model scanner instance
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response
|
||||
"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
if not file_path:
|
||||
return web.Response(text='File path is required', status=400)
|
||||
|
||||
# Remove file path from data to avoid saving it
|
||||
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
|
||||
|
||||
# Get metadata file path
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
|
||||
# Load existing metadata
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
|
||||
# Handle nested updates (for civitai.trainedWords)
|
||||
for key, value in metadata_updates.items():
|
||||
if isinstance(value, dict) and key in metadata and isinstance(metadata[key], dict):
|
||||
# Deep update for nested dictionaries
|
||||
for nested_key, nested_value in value.items():
|
||||
metadata[key][nested_key] = nested_value
|
||||
else:
|
||||
# Regular update for top-level keys
|
||||
metadata[key] = value
|
||||
|
||||
# Save updated metadata
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
# Update cache
|
||||
await scanner.update_single_model_cache(file_path, file_path, metadata)
|
||||
|
||||
# If model_name was updated, resort the cache
|
||||
if 'model_name' in metadata_updates:
|
||||
cache = await scanner.get_cached_data()
|
||||
await cache.resort()
|
||||
|
||||
return web.json_response({'success': True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving metadata: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
@@ -1,8 +1,54 @@
|
||||
from difflib import SequenceMatcher
|
||||
import requests
|
||||
import tempfile
|
||||
import re
|
||||
import os
|
||||
from bs4 import BeautifulSoup
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..config import config
|
||||
import asyncio
|
||||
|
||||
def get_lora_info(lora_name):
|
||||
"""Get the lora path and trigger words from cache"""
|
||||
async def _get_lora_info_async():
|
||||
scanner = await ServiceRegistry.get_lora_scanner()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
file_path = item.get('file_path')
|
||||
if file_path:
|
||||
for root in config.loras_roots:
|
||||
root = root.replace(os.sep, '/')
|
||||
if file_path.startswith(root):
|
||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||
# Get trigger words from civitai metadata
|
||||
civitai = item.get('civitai', {})
|
||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||
return relative_path, trigger_words
|
||||
return lora_name, []
|
||||
|
||||
try:
|
||||
# Check if we're already in an event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're in a running loop, we need to use a different approach
|
||||
# Create a new thread to run the async code
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(_get_lora_info_async())
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
return future.result()
|
||||
|
||||
except RuntimeError:
|
||||
# No event loop is running, we can use asyncio.run()
|
||||
return asyncio.run(_get_lora_info_async())
|
||||
|
||||
def download_twitter_image(url):
|
||||
"""Download image from a URL containing twitter:image meta tag
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
[project]
|
||||
name = "comfyui-lora-manager"
|
||||
description = "LoRA Manager for ComfyUI - Access it at http://localhost:8188/loras for managing LoRA models with previews and metadata integration."
|
||||
version = "0.8.19"
|
||||
description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
|
||||
version = "0.8.21"
|
||||
license = {file = "LICENSE"}
|
||||
dependencies = [
|
||||
"aiohttp",
|
||||
"jinja2",
|
||||
"safetensors",
|
||||
"watchdog",
|
||||
"beautifulsoup4",
|
||||
"piexif",
|
||||
"Pillow",
|
||||
@@ -15,7 +14,7 @@ dependencies = [
|
||||
"requests",
|
||||
"toml",
|
||||
"natsort",
|
||||
"msgpack"
|
||||
"GitPython"
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
aiohttp
|
||||
jinja2
|
||||
safetensors
|
||||
watchdog
|
||||
beautifulsoup4
|
||||
piexif
|
||||
Pillow
|
||||
@@ -9,6 +8,6 @@ olefile
|
||||
requests
|
||||
toml
|
||||
numpy
|
||||
torch
|
||||
natsort
|
||||
msgpack
|
||||
pyyaml
|
||||
GitPython
|
||||
|
||||
114
standalone.py
114
standalone.py
@@ -3,6 +3,26 @@ import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
# Create mock modules for py/nodes directory - add this before any other imports
|
||||
def mock_nodes_directory():
|
||||
"""Create mock modules for all Python files in the py/nodes directory"""
|
||||
nodes_dir = os.path.join(os.path.dirname(__file__), 'py', 'nodes')
|
||||
if os.path.exists(nodes_dir):
|
||||
# Create a mock module for the nodes package itself
|
||||
sys.modules['py.nodes'] = type('MockNodesModule', (), {})
|
||||
|
||||
# Create mock modules for all Python files in the nodes directory
|
||||
for file in os.listdir(nodes_dir):
|
||||
if file.endswith('.py') and file != '__init__.py':
|
||||
module_name = file[:-3] # Remove .py extension
|
||||
full_module_name = f'py.nodes.{module_name}'
|
||||
# Create empty module object
|
||||
sys.modules[full_module_name] = type(f'Mock{module_name.capitalize()}Module', (), {})
|
||||
print(f"Created mock module for: {full_module_name}")
|
||||
|
||||
# Run the mocking function before any other imports
|
||||
mock_nodes_directory()
|
||||
|
||||
# Create mock folder_paths module BEFORE any other imports
|
||||
class MockFolderPaths:
|
||||
@staticmethod
|
||||
@@ -86,6 +106,22 @@ logger = logging.getLogger("lora-manager-standalone")
|
||||
# Configure aiohttp access logger to be less verbose
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# Add specific suppression for connection reset errors
|
||||
class ConnectionResetFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
# Filter out connection reset errors that are not critical
|
||||
if "ConnectionResetError" in str(record.getMessage()):
|
||||
return False
|
||||
if "_call_connection_lost" in str(record.getMessage()):
|
||||
return False
|
||||
if "WinError 10054" in str(record.getMessage()):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Apply the filter to asyncio logger
|
||||
asyncio_logger = logging.getLogger("asyncio")
|
||||
asyncio_logger.addFilter(ConnectionResetFilter())
|
||||
|
||||
# Now we can import the global config from our local modules
|
||||
from py.config import config
|
||||
|
||||
@@ -98,17 +134,6 @@ class StandaloneServer:
|
||||
|
||||
# Ensure the app's access logger is configured to reduce verbosity
|
||||
self.app._subapps = [] # Ensure this exists to avoid AttributeError
|
||||
|
||||
# Configure access logging for the app
|
||||
self.app.on_startup.append(self._configure_access_logger)
|
||||
|
||||
async def _configure_access_logger(self, app):
|
||||
"""Configure access logger to reduce verbosity"""
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# If using aiohttp>=3.8.0, configure access logger through app directly
|
||||
if hasattr(app, 'access_logger'):
|
||||
app.access_logger.setLevel(logging.WARNING)
|
||||
|
||||
async def setup(self):
|
||||
"""Set up the standalone server"""
|
||||
@@ -198,9 +223,6 @@ class StandaloneLoraManager(LoraManager):
|
||||
|
||||
# Store app in a global-like location for compatibility
|
||||
sys.modules['server'].PromptServer.instance = server_instance
|
||||
|
||||
# Configure aiohttp access logger to be less verbose
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
added_targets = set() # Track already added target paths
|
||||
|
||||
@@ -232,7 +254,7 @@ class StandaloneLoraManager(LoraManager):
|
||||
added_targets.add(os.path.normpath(real_root))
|
||||
|
||||
# Add static routes for each checkpoint root
|
||||
for idx, root in enumerate(config.checkpoints_roots, start=1):
|
||||
for idx, root in enumerate(config.base_models_roots, start=1):
|
||||
if not os.path.exists(root):
|
||||
logger.warning(f"Checkpoint root path does not exist: {root}")
|
||||
continue
|
||||
@@ -257,23 +279,50 @@ class StandaloneLoraManager(LoraManager):
|
||||
# Record route mapping
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(os.path.normpath(real_root))
|
||||
|
||||
# Add static routes for each embedding root
|
||||
for idx, root in enumerate(getattr(config, "embeddings_roots", []), start=1):
|
||||
if not os.path.exists(root):
|
||||
logger.warning(f"Embedding root path does not exist: {root}")
|
||||
continue
|
||||
|
||||
preview_path = f'/embeddings_static/root{idx}/preview'
|
||||
|
||||
real_root = root
|
||||
for target, link in config._path_mappings.items():
|
||||
if os.path.normpath(link) == os.path.normpath(root):
|
||||
real_root = target
|
||||
break
|
||||
|
||||
display_root = real_root.replace('\\', '/')
|
||||
app.router.add_static(preview_path, real_root)
|
||||
logger.info(f"Added static route {preview_path} -> {display_root}")
|
||||
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(os.path.normpath(real_root))
|
||||
|
||||
# Add static routes for symlink target paths that aren't already covered
|
||||
link_idx = {
|
||||
'lora': 1,
|
||||
'checkpoint': 1
|
||||
'checkpoint': 1,
|
||||
'embedding': 1
|
||||
}
|
||||
|
||||
for target_path, link_path in config._path_mappings.items():
|
||||
norm_target = os.path.normpath(target_path)
|
||||
if norm_target not in added_targets:
|
||||
# Determine if this is a checkpoint or lora link based on path
|
||||
is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.checkpoints_roots)
|
||||
is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.checkpoints_roots)
|
||||
|
||||
# Determine if this is a checkpoint, lora, or embedding link based on path
|
||||
is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.base_models_roots)
|
||||
is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.base_models_roots)
|
||||
is_embedding = any(os.path.normpath(emb_root) in os.path.normpath(link_path) for emb_root in getattr(config, "embeddings_roots", []))
|
||||
is_embedding = is_embedding or any(os.path.normpath(emb_root) in norm_target for emb_root in getattr(config, "embeddings_roots", []))
|
||||
|
||||
if is_checkpoint:
|
||||
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
|
||||
link_idx["checkpoint"] += 1
|
||||
elif is_embedding:
|
||||
route_path = f'/embeddings_static/link_{link_idx["embedding"]}/preview'
|
||||
link_idx["embedding"] += 1
|
||||
else:
|
||||
route_path = f'/loras_static/link_{link_idx["lora"]}/preview'
|
||||
link_idx["lora"] += 1
|
||||
@@ -294,35 +343,39 @@ class StandaloneLoraManager(LoraManager):
|
||||
app.router.add_static('/loras_static', config.static_path)
|
||||
|
||||
# Setup feature routes
|
||||
from py.routes.lora_routes import LoraRoutes
|
||||
from py.routes.api_routes import ApiRoutes
|
||||
from py.services.model_service_factory import ModelServiceFactory, register_default_model_types
|
||||
from py.routes.recipe_routes import RecipeRoutes
|
||||
from py.routes.checkpoints_routes import CheckpointsRoutes
|
||||
from py.routes.update_routes import UpdateRoutes
|
||||
from py.routes.misc_routes import MiscRoutes
|
||||
from py.routes.example_images_routes import ExampleImagesRoutes
|
||||
from py.routes.stats_routes import StatsRoutes
|
||||
from py.services.websocket_manager import ws_manager
|
||||
|
||||
lora_routes = LoraRoutes()
|
||||
checkpoints_routes = CheckpointsRoutes()
|
||||
|
||||
register_default_model_types()
|
||||
|
||||
# Setup all model routes using the factory
|
||||
ModelServiceFactory.setup_all_routes(app)
|
||||
|
||||
stats_routes = StatsRoutes()
|
||||
|
||||
# Initialize routes
|
||||
lora_routes.setup_routes(app)
|
||||
checkpoints_routes.setup_routes(app)
|
||||
stats_routes.setup_routes(app)
|
||||
ApiRoutes.setup_routes(app)
|
||||
RecipeRoutes.setup_routes(app)
|
||||
UpdateRoutes.setup_routes(app)
|
||||
MiscRoutes.setup_routes(app)
|
||||
ExampleImagesRoutes.setup_routes(app)
|
||||
|
||||
# Setup WebSocket routes that are shared across all model types
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
|
||||
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
|
||||
|
||||
# Schedule service initialization
|
||||
app.on_startup.append(lambda app: cls._initialize_services())
|
||||
|
||||
# Add cleanup
|
||||
app.on_shutdown.append(cls._cleanup)
|
||||
app.on_shutdown.append(ApiRoutes.cleanup)
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments"""
|
||||
@@ -347,9 +400,6 @@ async def main():
|
||||
# Set log level
|
||||
logging.getLogger().setLevel(getattr(logging, args.log_level))
|
||||
|
||||
# Explicitly configure aiohttp access logger regardless of selected log level
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# Create the server instance
|
||||
server = StandaloneServer()
|
||||
|
||||
|
||||
@@ -50,8 +50,8 @@ html, body {
|
||||
--lora-border: oklch(90% 0.02 256 / 0.15);
|
||||
--lora-text: oklch(95% 0.02 256);
|
||||
--lora-error: oklch(75% 0.32 29);
|
||||
--lora-warning: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h)); /* Modified to be used with oklch() */
|
||||
--lora-success: oklch(var(--lora-success-l) var(--lora-success-c) var(--lora-success-h)); /* New green success color */
|
||||
--lora-warning: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
|
||||
--lora-success: oklch(var(--lora-success-l) var(--lora-success-c) var(--lora-success-h));
|
||||
|
||||
/* Spacing Scale */
|
||||
--space-1: calc(8px * 1);
|
||||
@@ -70,6 +70,11 @@ html, body {
|
||||
--border-radius-xs: 4px;
|
||||
|
||||
--scrollbar-width: 8px; /* 添加滚动条宽度变量 */
|
||||
|
||||
/* Shortcut styles */
|
||||
--shortcut-bg: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.12);
|
||||
--shortcut-border: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.25);
|
||||
--shortcut-text: var(--text-color);
|
||||
}
|
||||
|
||||
html[data-theme="dark"] {
|
||||
|
||||
@@ -73,12 +73,12 @@
|
||||
}
|
||||
|
||||
/* Style for selected cards */
|
||||
.lora-card.selected {
|
||||
.model-card.selected {
|
||||
box-shadow: 0 0 0 2px var(--lora-accent);
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.lora-card.selected::after {
|
||||
.model-card.selected::after {
|
||||
content: "✓";
|
||||
position: absolute;
|
||||
top: 10px;
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
box-sizing: border-box; /* Include padding in width calculation */
|
||||
}
|
||||
|
||||
.lora-card {
|
||||
.model-card {
|
||||
background: var(--lora-surface);
|
||||
border: 1px solid var(--lora-border);
|
||||
border-radius: var(--border-radius-base);
|
||||
@@ -30,12 +30,12 @@
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.lora-card:hover {
|
||||
.model-card:hover {
|
||||
transform: translateY(-2px);
|
||||
background: oklch(100% 0 0 / 0.6);
|
||||
}
|
||||
|
||||
.lora-card:focus-visible {
|
||||
.model-card:focus-visible {
|
||||
outline: 2px solid var(--lora-accent);
|
||||
outline-offset: 2px;
|
||||
}
|
||||
@@ -47,7 +47,7 @@
|
||||
grid-template-columns: repeat(auto-fill, minmax(270px, 1fr));
|
||||
}
|
||||
|
||||
.lora-card {
|
||||
.model-card {
|
||||
max-width: 270px;
|
||||
}
|
||||
}
|
||||
@@ -59,7 +59,7 @@
|
||||
grid-template-columns: repeat(auto-fill, minmax(280px, 1fr));
|
||||
}
|
||||
|
||||
.lora-card {
|
||||
.model-card {
|
||||
max-width: 280px;
|
||||
}
|
||||
}
|
||||
@@ -70,7 +70,7 @@
|
||||
grid-template-columns: repeat(auto-fill, minmax(240px, 1fr));
|
||||
}
|
||||
|
||||
.lora-card {
|
||||
.model-card {
|
||||
max-width: 240px;
|
||||
}
|
||||
}
|
||||
@@ -259,8 +259,8 @@
|
||||
transition: opacity 0.2s ease;
|
||||
}
|
||||
|
||||
.hover-reveal .lora-card:hover .card-header,
|
||||
.hover-reveal .lora-card:hover .card-footer {
|
||||
.hover-reveal .model-card:hover .card-header,
|
||||
.hover-reveal .model-card:hover .card-footer {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
@@ -345,7 +345,7 @@
|
||||
grid-template-columns: minmax(260px, 1fr); /* Adjusted minimum size for mobile */
|
||||
}
|
||||
|
||||
.lora-card {
|
||||
.model-card {
|
||||
max-width: 100%; /* Allow cards to fill available space on mobile */
|
||||
}
|
||||
}
|
||||
@@ -425,8 +425,8 @@
|
||||
}
|
||||
|
||||
/* Prevent text selection on cards and interactive elements */
|
||||
.lora-card,
|
||||
.lora-card *,
|
||||
.model-card,
|
||||
.model-card *,
|
||||
.card-actions,
|
||||
.card-actions i,
|
||||
.toggle-blur-btn,
|
||||
@@ -510,7 +510,7 @@
|
||||
}
|
||||
}
|
||||
|
||||
/* Add after the existing .lora-card:hover styles */
|
||||
/* Add after the existing .model-card:hover styles */
|
||||
|
||||
@keyframes update-pulse {
|
||||
0% { box-shadow: 0 0 0 0 var(--lora-accent-transparent); }
|
||||
@@ -523,7 +523,7 @@
|
||||
--lora-accent-transparent: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.6);
|
||||
}
|
||||
|
||||
.lora-card.updated {
|
||||
.model-card.updated {
|
||||
animation: update-pulse 1.2s ease-out;
|
||||
}
|
||||
|
||||
|
||||
@@ -195,7 +195,7 @@
|
||||
}
|
||||
|
||||
/* Make cards in duplicate groups have consistent width */
|
||||
.card-group-container .lora-card {
|
||||
.card-group-container .model-card {
|
||||
flex: 0 0 auto;
|
||||
width: 240px;
|
||||
margin: 0;
|
||||
@@ -241,26 +241,26 @@
|
||||
}
|
||||
|
||||
/* Duplicate card styling */
|
||||
.lora-card.duplicate {
|
||||
.model-card.duplicate {
|
||||
position: relative;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.lora-card.duplicate:hover {
|
||||
.model-card.duplicate:hover {
|
||||
border-color: var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h);
|
||||
}
|
||||
|
||||
.lora-card.duplicate.latest {
|
||||
.model-card.duplicate.latest {
|
||||
border-style: solid;
|
||||
border-color: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
|
||||
}
|
||||
|
||||
.lora-card.duplicate-selected {
|
||||
.model-card.duplicate-selected {
|
||||
border: 2px solid oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h));
|
||||
box-shadow: 0 0 8px rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
|
||||
.lora-card .selector-checkbox {
|
||||
.model-card .selector-checkbox {
|
||||
position: absolute;
|
||||
top: 10px;
|
||||
right: 10px;
|
||||
@@ -271,7 +271,7 @@
|
||||
}
|
||||
|
||||
/* Latest indicator */
|
||||
.lora-card.duplicate.latest::after {
|
||||
.model-card.duplicate.latest::after {
|
||||
content: "Latest";
|
||||
position: absolute;
|
||||
top: 10px;
|
||||
@@ -365,13 +365,13 @@
|
||||
}
|
||||
|
||||
/* Hash Mismatch Styling */
|
||||
.lora-card.duplicate.hash-mismatch {
|
||||
.model-card.duplicate.hash-mismatch {
|
||||
border: 2px dashed oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
|
||||
opacity: 0.85;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.lora-card.duplicate.hash-mismatch::before {
|
||||
.model-card.duplicate.hash-mismatch::before {
|
||||
content: "";
|
||||
position: absolute;
|
||||
top: 0;
|
||||
@@ -389,7 +389,7 @@
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.lora-card.duplicate.hash-mismatch .card-preview {
|
||||
.model-card.duplicate.hash-mismatch .card-preview {
|
||||
filter: grayscale(20%);
|
||||
}
|
||||
|
||||
@@ -407,7 +407,7 @@
|
||||
}
|
||||
|
||||
/* Disabled checkbox style */
|
||||
.lora-card.duplicate.hash-mismatch .selector-checkbox {
|
||||
.model-card.duplicate.hash-mismatch .selector-checkbox {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@
|
||||
align-items: center;
|
||||
text-decoration: none;
|
||||
color: var(--text-color);
|
||||
gap: 8px;
|
||||
gap: 2px;
|
||||
}
|
||||
|
||||
.app-logo {
|
||||
@@ -223,11 +223,6 @@
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.update-badge.hidden,
|
||||
.update-badge:not(.visible) {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
/* Mobile adjustments */
|
||||
@media (max-width: 768px) {
|
||||
.app-title {
|
||||
|
||||
@@ -109,7 +109,7 @@
|
||||
}
|
||||
|
||||
@media (prefers-reduced-motion: reduce) {
|
||||
.lora-card,
|
||||
.model-card,
|
||||
.progress-bar,
|
||||
.current-item-bar {
|
||||
transition: none;
|
||||
|
||||
@@ -123,42 +123,6 @@
|
||||
}
|
||||
}
|
||||
|
||||
/* 修改 back-to-top 按钮样式,使其固定在 modal 内部 */
|
||||
.modal-content .back-to-top {
|
||||
position: sticky; /* 改用 sticky 定位 */
|
||||
float: right; /* 使用 float 确保按钮在右侧 */
|
||||
bottom: 20px; /* 距离底部的距离 */
|
||||
margin-right: 20px; /* 右侧间距 */
|
||||
margin-top: -56px; /* 负边距确保不占用额外空间 */
|
||||
width: 36px;
|
||||
height: 36px;
|
||||
border-radius: 50%;
|
||||
background: var(--card-bg);
|
||||
border: 1px solid var(--border-color);
|
||||
color: var(--text-color);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
cursor: pointer;
|
||||
opacity: 0;
|
||||
visibility: hidden;
|
||||
transform: translateY(10px);
|
||||
transition: all 0.3s ease;
|
||||
z-index: 10;
|
||||
}
|
||||
|
||||
.modal-content .back-to-top.visible {
|
||||
opacity: 1;
|
||||
visibility: visible;
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
.modal-content .back-to-top:hover {
|
||||
background: var(--lora-accent);
|
||||
color: white;
|
||||
transform: translateY(-2px);
|
||||
}
|
||||
|
||||
/* File name copy styles */
|
||||
.file-name-wrapper {
|
||||
display: flex;
|
||||
@@ -436,22 +400,24 @@
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
margin-bottom: var(--space-1);
|
||||
padding: 6px 10px;
|
||||
padding: 2px 10px;
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||
border-radius: var(--border-radius-sm);
|
||||
max-width: fit-content;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .creator-info {
|
||||
[data-theme="dark"] .creator-info,
|
||||
[data-theme="dark"] .civitai-view {
|
||||
background: rgba(255, 255, 255, 0.03);
|
||||
border: 1px solid var(--lora-border);
|
||||
}
|
||||
|
||||
.creator-avatar {
|
||||
width: 28px;
|
||||
height: 28px;
|
||||
width: 26px;
|
||||
height: 26px;
|
||||
border-radius: 50%;
|
||||
overflow: hidden;
|
||||
flex-shrink: 0;
|
||||
@@ -482,8 +448,40 @@
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
/* Optional: add hover effect for creator info */
|
||||
.creator-info:hover {
|
||||
/* Add hover effect for creator info */
|
||||
.creator-info:hover,
|
||||
.civitai-view:hover {
|
||||
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1);
|
||||
border-color: var(--lora-accent);
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.creator-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
margin-bottom: var(--space-1);
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.civitai-view {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 6px 12px;
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||
border-radius: var(--border-radius-sm);
|
||||
color: var(--text-color);
|
||||
cursor: pointer;
|
||||
font-weight: 500;
|
||||
font-size: 0.9em;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.civitai-view i {
|
||||
font-size: 20px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
@@ -4,31 +4,254 @@
|
||||
margin-top: var(--space-4);
|
||||
}
|
||||
|
||||
.carousel {
|
||||
transition: max-height 0.3s ease-in-out;
|
||||
/* Main showcase container */
|
||||
.showcase-container {
|
||||
display: flex;
|
||||
height: 750px;
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-sm);
|
||||
overflow: hidden;
|
||||
background: var(--lora-surface);
|
||||
}
|
||||
|
||||
.carousel.collapsed {
|
||||
max-height: 0;
|
||||
.showcase-container.empty {
|
||||
height: 400px;
|
||||
}
|
||||
|
||||
.carousel-container {
|
||||
/* Thumbnail Sidebar */
|
||||
.thumbnail-sidebar {
|
||||
width: 200px;
|
||||
background: var(--bg-color);
|
||||
border-right: 1px solid var(--border-color);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.thumbnail-grid {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
scrollbar-width: none; /* Firefox */
|
||||
-ms-overflow-style: none; /* Internet Explorer 10+ */
|
||||
padding: var(--space-2);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--space-2);
|
||||
}
|
||||
|
||||
.thumbnail-grid::-webkit-scrollbar {
|
||||
display: none; /* WebKit */
|
||||
}
|
||||
|
||||
.thumbnail-item {
|
||||
position: relative;
|
||||
aspect-ratio: 1;
|
||||
border-radius: var(--border-radius-xs);
|
||||
overflow: hidden;
|
||||
cursor: pointer;
|
||||
border: 2px solid transparent;
|
||||
transition: all 0.2s ease;
|
||||
background: var(--lora-surface);
|
||||
}
|
||||
|
||||
.thumbnail-item:hover {
|
||||
border-color: var(--lora-accent);
|
||||
transform: scale(1.02);
|
||||
}
|
||||
|
||||
.thumbnail-item.active {
|
||||
border-color: var(--lora-accent);
|
||||
box-shadow: 0 0 0 1px var(--lora-accent);
|
||||
}
|
||||
|
||||
.thumbnail-media {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
object-fit: cover;
|
||||
display: block;
|
||||
}
|
||||
|
||||
.thumbnail-media.blurred {
|
||||
filter: blur(8px);
|
||||
}
|
||||
|
||||
.video-indicator {
|
||||
position: absolute;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
color: white;
|
||||
background: rgba(0, 0, 0, 0.6);
|
||||
border-radius: 50%;
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 0.7em;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.thumbnail-nsfw-overlay {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background: rgba(0, 0, 0, 0.8);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
color: white;
|
||||
font-size: 1.2em;
|
||||
}
|
||||
|
||||
/* Import Section */
|
||||
.import-section {
|
||||
padding: var(--space-2);
|
||||
border-top: 1px solid var(--border-color);
|
||||
background: var(--bg-color);
|
||||
}
|
||||
|
||||
.select-files-btn {
|
||||
width: 100%;
|
||||
background: var(--lora-accent);
|
||||
color: var(--lora-text);
|
||||
border: none;
|
||||
border-radius: var(--border-radius-xs);
|
||||
padding: var(--space-2);
|
||||
cursor: pointer;
|
||||
font-size: 0.9em;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 6px;
|
||||
transition: all 0.2s;
|
||||
margin-bottom: var(--space-2);
|
||||
}
|
||||
|
||||
.select-files-btn:hover {
|
||||
opacity: 0.9;
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.import-drop-zone {
|
||||
border: 2px dashed var(--border-color);
|
||||
border-radius: var(--border-radius-xs);
|
||||
padding: var(--space-2);
|
||||
text-align: center;
|
||||
transition: all 0.3s ease;
|
||||
background: var(--lora-surface);
|
||||
min-height: 60px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.import-drop-zone.highlight {
|
||||
border-color: var(--lora-accent);
|
||||
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1);
|
||||
}
|
||||
|
||||
.drop-zone-content {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
color: var(--text-color);
|
||||
opacity: 0.6;
|
||||
font-size: 0.8em;
|
||||
}
|
||||
|
||||
.drop-zone-content i {
|
||||
font-size: 1.2em;
|
||||
margin-bottom: 2px;
|
||||
}
|
||||
|
||||
/* Main Display Area */
|
||||
.main-display-area {
|
||||
flex: 1;
|
||||
position: relative;
|
||||
background: var(--card-bg);
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.main-display-area.empty {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.empty-state {
|
||||
text-align: center;
|
||||
color: var(--text-color);
|
||||
opacity: 0.6;
|
||||
}
|
||||
|
||||
.empty-state i {
|
||||
font-size: 3em;
|
||||
margin-bottom: var(--space-2);
|
||||
opacity: 0.5;
|
||||
}
|
||||
|
||||
.empty-state h3 {
|
||||
margin: 0 0 var(--space-1);
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.empty-state p {
|
||||
margin: 0;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
.navigation-controls {
|
||||
position: absolute;
|
||||
top: var(--space-2);
|
||||
right: var(--space-2);
|
||||
display: flex;
|
||||
gap: 6px;
|
||||
z-index: 10;
|
||||
}
|
||||
|
||||
.nav-btn {
|
||||
width: 36px;
|
||||
height: 36px;
|
||||
border-radius: 50%;
|
||||
background: var(--bg-color);
|
||||
border: 1px solid var(--border-color);
|
||||
color: var(--text-color);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.nav-btn:hover {
|
||||
opacity: 1;
|
||||
transform: translateY(-1px);
|
||||
box-shadow: 0 3px 7px rgba(0, 0, 0, 0.15);
|
||||
}
|
||||
|
||||
.nav-btn.info-btn.active {
|
||||
background: var(--lora-accent);
|
||||
color: var(--lora-text);
|
||||
border-color: var(--lora-accent);
|
||||
}
|
||||
|
||||
.main-media-container {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
.media-wrapper {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background: var(--lora-surface);
|
||||
margin-bottom: var(--space-2);
|
||||
overflow: hidden; /* Ensure metadata panel is contained */
|
||||
}
|
||||
|
||||
.media-wrapper:last-child {
|
||||
margin-bottom: 0;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.media-wrapper img,
|
||||
@@ -41,50 +264,11 @@
|
||||
object-fit: contain;
|
||||
}
|
||||
|
||||
.no-examples {
|
||||
text-align: center;
|
||||
padding: var(--space-3);
|
||||
color: var(--text-color);
|
||||
opacity: 0.7;
|
||||
}
|
||||
|
||||
/* Adjust the media wrapper for tab system */
|
||||
#showcase-tab .carousel-container {
|
||||
margin-top: var(--space-2);
|
||||
}
|
||||
|
||||
/* Add styles for blurred showcase content */
|
||||
.nsfw-media-wrapper {
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.media-wrapper img.blurred,
|
||||
.media-wrapper video.blurred {
|
||||
filter: blur(25px);
|
||||
}
|
||||
|
||||
.media-wrapper .nsfw-overlay {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
z-index: 2;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
/* Position the toggle button at the top left of showcase media */
|
||||
.showcase-toggle-btn {
|
||||
position: absolute;
|
||||
z-index: 3;
|
||||
}
|
||||
|
||||
/* Add styles for showcase media controls */
|
||||
/* Media Controls for main display */
|
||||
.media-controls {
|
||||
position: absolute;
|
||||
top: var(--space-2);
|
||||
left: var(--space-2);
|
||||
display: flex;
|
||||
gap: 6px;
|
||||
z-index: 4;
|
||||
@@ -94,15 +278,15 @@
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.media-controls.visible {
|
||||
.media-wrapper:hover .media-controls {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
pointer-events: auto;
|
||||
}
|
||||
|
||||
.media-control-btn {
|
||||
width: 28px;
|
||||
height: 28px;
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
border-radius: 50%;
|
||||
background: var(--bg-color);
|
||||
border: 1px solid var(--border-color);
|
||||
@@ -135,13 +319,11 @@
|
||||
border-color: var(--lora-error);
|
||||
}
|
||||
|
||||
/* Disabled state for delete button */
|
||||
.media-control-btn.example-delete-btn.disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
/* Two-step confirmation for delete button */
|
||||
.media-control-btn.example-delete-btn .confirm-icon {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
@@ -172,16 +354,29 @@
|
||||
border-color: var(--lora-error);
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0% {
|
||||
box-shadow: 0 0 0 0 rgba(220, 53, 69, 0.7);
|
||||
}
|
||||
70% {
|
||||
box-shadow: 0 0 0 5px rgba(220, 53, 69, 0);
|
||||
}
|
||||
100% {
|
||||
box-shadow: 0 0 0 0 rgba(220, 53, 69, 0);
|
||||
}
|
||||
/* Toggle blur button for main display */
|
||||
.showcase-toggle-btn {
|
||||
position: absolute;
|
||||
top: calc(var(--space-2) + 44px);
|
||||
left: var(--space-2);
|
||||
z-index: 3;
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
border-radius: 50%;
|
||||
background: var(--bg-color);
|
||||
border: 1px solid var(--border-color);
|
||||
color: var(--text-color);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.15);
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
.media-wrapper:hover .showcase-toggle-btn {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Image Metadata Panel Styles */
|
||||
@@ -195,22 +390,20 @@
|
||||
padding: var(--space-2);
|
||||
transform: translateY(100%);
|
||||
transition: transform 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275), opacity 0.25s ease;
|
||||
z-index: 5;
|
||||
max-height: 50%; /* Reduced to take less space */
|
||||
z-index: 15;
|
||||
max-height: 50%;
|
||||
overflow-y: auto;
|
||||
box-shadow: 0 -2px 8px rgba(0, 0, 0, 0.1);
|
||||
opacity: 0;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
/* Show metadata panel only when the 'visible' class is added */
|
||||
.media-wrapper .image-metadata-panel.visible {
|
||||
.image-metadata-panel.visible {
|
||||
transform: translateY(0);
|
||||
opacity: 0.98;
|
||||
pointer-events: auto;
|
||||
}
|
||||
|
||||
/* Adjust to dark theme */
|
||||
[data-theme="dark"] .image-metadata-panel {
|
||||
background: var(--card-bg);
|
||||
box-shadow: 0 -2px 8px rgba(0, 0, 0, 0.3);
|
||||
@@ -222,7 +415,6 @@
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
/* Styling for parameters tags */
|
||||
.params-tags {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
@@ -255,7 +447,6 @@
|
||||
color: var(--lora-accent);
|
||||
}
|
||||
|
||||
/* Special styling for prompt row */
|
||||
.metadata-row.prompt-row {
|
||||
flex-direction: column;
|
||||
padding-top: 0;
|
||||
@@ -281,7 +472,7 @@
|
||||
border-radius: var(--border-radius-xs);
|
||||
padding: 6px 30px 6px 8px;
|
||||
margin-top: 2px;
|
||||
max-height: 80px; /* Reduced from 120px */
|
||||
max-height: 80px;
|
||||
overflow-y: auto;
|
||||
word-break: break-word;
|
||||
width: 100%;
|
||||
@@ -313,27 +504,6 @@
|
||||
color: var(--lora-accent);
|
||||
}
|
||||
|
||||
/* Scrollbar styling for metadata panel */
|
||||
.image-metadata-panel::-webkit-scrollbar {
|
||||
width: 6px;
|
||||
}
|
||||
|
||||
.image-metadata-panel::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.image-metadata-panel::-webkit-scrollbar-thumb {
|
||||
background-color: var(--border-color);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
/* For Firefox */
|
||||
.image-metadata-panel {
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: var(--border-color) transparent;
|
||||
}
|
||||
|
||||
/* No metadata message styling */
|
||||
.no-metadata-message {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
@@ -352,31 +522,66 @@
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
/* Scroll Indicator */
|
||||
.scroll-indicator {
|
||||
cursor: pointer;
|
||||
padding: var(--space-2);
|
||||
background: var(--lora-surface);
|
||||
border: 1px solid var(--lora-border);
|
||||
border-radius: var(--border-radius-sm);
|
||||
/* Scrollbar styling for metadata panel */
|
||||
.image-metadata-panel::-webkit-scrollbar {
|
||||
width: 6px;
|
||||
}
|
||||
|
||||
.image-metadata-panel::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.image-metadata-panel::-webkit-scrollbar-thumb {
|
||||
background-color: var(--border-color);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.image-metadata-panel {
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: var(--border-color) transparent;
|
||||
}
|
||||
|
||||
/* NSFW Content Styles */
|
||||
.media-wrapper img.blurred,
|
||||
.media-wrapper video.blurred {
|
||||
filter: blur(25px);
|
||||
}
|
||||
|
||||
.media-wrapper .nsfw-overlay {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
z-index: 2;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
/* NSFW Filter Notification */
|
||||
.nsfw-filter-notification {
|
||||
background: var(--lora-warning);
|
||||
color: var(--lora-text);
|
||||
padding: var(--space-2);
|
||||
border-radius: var(--border-radius-xs);
|
||||
margin-bottom: var(--space-2);
|
||||
transition: background-color 0.2s, transform 0.2s;
|
||||
}
|
||||
|
||||
.scroll-indicator:hover {
|
||||
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1);
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.scroll-indicator span {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-size: 0.9em;
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
/* No examples message */
|
||||
.no-examples {
|
||||
text-align: center;
|
||||
padding: var(--space-4);
|
||||
color: var(--text-color);
|
||||
opacity: 0.7;
|
||||
}
|
||||
|
||||
/* Lazy loading */
|
||||
.lazy {
|
||||
opacity: 0;
|
||||
transition: opacity 0.3s;
|
||||
@@ -386,93 +591,24 @@
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Example Import Area */
|
||||
.example-import-area {
|
||||
margin-top: var(--space-4);
|
||||
padding: var(--space-2);
|
||||
}
|
||||
|
||||
.example-import-area.empty {
|
||||
margin-top: var(--space-2);
|
||||
padding: var(--space-4) var(--space-2);
|
||||
}
|
||||
|
||||
.import-container {
|
||||
border: 2px dashed var(--border-color);
|
||||
border-radius: var(--border-radius-sm);
|
||||
padding: var(--space-4);
|
||||
text-align: center;
|
||||
transition: all 0.3s ease;
|
||||
background: var(--lora-surface);
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.import-container.highlight {
|
||||
border-color: var(--lora-accent);
|
||||
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1);
|
||||
transform: scale(1.01);
|
||||
}
|
||||
|
||||
.import-placeholder {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: var(--space-1);
|
||||
padding-top: var(--space-1);
|
||||
}
|
||||
|
||||
.import-placeholder i {
|
||||
font-size: 2.5rem;
|
||||
/* color: var(--lora-accent); */
|
||||
opacity: 0.8;
|
||||
margin-bottom: var(--space-1);
|
||||
}
|
||||
|
||||
.import-placeholder h3 {
|
||||
margin: 0 0 var(--space-1);
|
||||
font-size: 1.2rem;
|
||||
font-weight: 500;
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.import-placeholder p {
|
||||
margin: var(--space-1) 0;
|
||||
color: var(--text-color);
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.import-placeholder .sub-text {
|
||||
font-size: 0.9em;
|
||||
opacity: 0.6;
|
||||
margin: var(--space-1) 0;
|
||||
}
|
||||
|
||||
.import-formats {
|
||||
font-size: 0.8em !important;
|
||||
opacity: 0.6 !important;
|
||||
margin-top: var(--space-2) !important;
|
||||
}
|
||||
|
||||
.select-files-btn {
|
||||
background: var(--lora-accent);
|
||||
color: var(--lora-text);
|
||||
border: none;
|
||||
border-radius: var(--border-radius-xs);
|
||||
padding: var(--space-2) var(--space-3);
|
||||
cursor: pointer;
|
||||
font-size: 0.9em;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.select-files-btn:hover {
|
||||
opacity: 0.9;
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
/* For dark theme */
|
||||
[data-theme="dark"] .import-container {
|
||||
[data-theme="dark"] .import-drop-zone {
|
||||
background: rgba(255, 255, 255, 0.03);
|
||||
}
|
||||
|
||||
/* Responsive design for smaller screens */
|
||||
@media (max-width: 768px) {
|
||||
.thumbnail-sidebar {
|
||||
width: 160px;
|
||||
}
|
||||
|
||||
.navigation-controls {
|
||||
top: var(--space-1);
|
||||
right: var(--space-1);
|
||||
}
|
||||
|
||||
.nav-btn {
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
274
static/css/components/modal/_base.css
Normal file
274
static/css/components/modal/_base.css
Normal file
@@ -0,0 +1,274 @@
|
||||
/* modal 基础样式 */
|
||||
.modal {
|
||||
display: none;
|
||||
position: fixed;
|
||||
top: 48px; /* Start below the header */
|
||||
left: 0;
|
||||
width: 100%;
|
||||
height: calc(100% - 48px); /* Adjust height to exclude header */
|
||||
background: rgba(0, 0, 0, 0.2); /* 调整为更淡的半透明黑色 */
|
||||
z-index: var(--z-modal);
|
||||
overflow: auto; /* Change from hidden to auto to allow scrolling */
|
||||
}
|
||||
|
||||
/* 当模态窗口打开时,禁止body滚动 */
|
||||
body.modal-open {
|
||||
position: fixed;
|
||||
width: 100%;
|
||||
padding-right: var(--scrollbar-width, 0px); /* 补偿滚动条消失导致的页面偏移 */
|
||||
}
|
||||
|
||||
/* modal-content 样式 */
|
||||
.modal-content {
|
||||
position: relative;
|
||||
max-width: 800px;
|
||||
height: auto;
|
||||
max-height: calc(90vh - 48px); /* Adjust to account for header height */
|
||||
margin: 1rem auto; /* Keep reduced top margin */
|
||||
background: var(--lora-surface);
|
||||
border-radius: var(--border-radius-base);
|
||||
padding: var(--space-3);
|
||||
border: 1px solid var(--lora-border);
|
||||
box-shadow:
|
||||
0 4px 6px -1px rgba(0, 0, 0, 0.1),
|
||||
0 2px 4px -1px rgba(0, 0, 0, 0.06),
|
||||
0 10px 15px -3px rgba(0, 0, 0, 0.05);
|
||||
overflow-y: auto;
|
||||
overflow-x: hidden; /* 防止水平滚动条 */
|
||||
}
|
||||
|
||||
/* 当 modal 打开时锁定 body */
|
||||
body.modal-open {
|
||||
overflow: hidden !important; /* 覆盖 base.css 中的 scroll */
|
||||
padding-right: var(--scrollbar-width, 8px); /* 使用滚动条宽度作为补偿 */
|
||||
}
|
||||
|
||||
@keyframes modalFadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(-20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.modal-actions {
|
||||
display: flex;
|
||||
gap: var(--space-2);
|
||||
justify-content: center;
|
||||
margin-top: var(--space-3);
|
||||
}
|
||||
|
||||
.cancel-btn, .delete-btn, .exclude-btn, .confirm-btn {
|
||||
padding: 8px var(--space-2);
|
||||
border-radius: 6px;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
font-weight: 500;
|
||||
min-width: 100px;
|
||||
}
|
||||
|
||||
.cancel-btn {
|
||||
background: var(--lora-surface);
|
||||
border: 1px solid var(--lora-border);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.delete-btn {
|
||||
background: var(--lora-error);
|
||||
color: white;
|
||||
}
|
||||
|
||||
/* Style for exclude button - different from delete button */
|
||||
.exclude-btn, .confirm-btn {
|
||||
background: var(--lora-accent, #4f46e5);
|
||||
color: white;
|
||||
}
|
||||
|
||||
.cancel-btn:hover {
|
||||
background: var(--lora-border);
|
||||
}
|
||||
|
||||
.delete-btn:hover {
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
.exclude-btn:hover, .confirm-btn:hover {
|
||||
opacity: 0.9;
|
||||
background: oklch(from var(--lora-accent, #4f46e5) l c h / 85%);
|
||||
}
|
||||
|
||||
.modal-content h2 {
|
||||
color: var(--text-color);
|
||||
margin-bottom: var(--space-1);
|
||||
font-size: 1.5em;
|
||||
}
|
||||
|
||||
.close {
|
||||
position: absolute;
|
||||
top: var(--space-2);
|
||||
right: var(--space-2);
|
||||
background: transparent;
|
||||
border: none;
|
||||
color: var(--text-color);
|
||||
font-size: 1.5em;
|
||||
cursor: pointer;
|
||||
opacity: 0.7;
|
||||
transition: opacity 0.2s;
|
||||
}
|
||||
|
||||
.close:hover {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* 统一各个 section 的样式 */
|
||||
.support-section,
|
||||
.changelog-section,
|
||||
.update-info,
|
||||
.info-item,
|
||||
.path-preview {
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||
border-radius: var(--border-radius-sm);
|
||||
padding: var(--space-2);
|
||||
}
|
||||
|
||||
/* 深色主题统一样式 */
|
||||
[data-theme="dark"] .modal-content {
|
||||
background: var(--lora-surface);
|
||||
border: 1px solid var(--lora-border);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .support-section,
|
||||
[data-theme="dark"] .changelog-section,
|
||||
[data-theme="dark"] .update-info,
|
||||
[data-theme="dark"] .info-item,
|
||||
[data-theme="dark"] .path-preview {
|
||||
background: rgba(255, 255, 255, 0.03);
|
||||
border: 1px solid var(--lora-border);
|
||||
}
|
||||
|
||||
.primary-btn {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
padding: 8px 16px;
|
||||
background-color: var(--lora-accent);
|
||||
color: var(--lora-text);
|
||||
border: none;
|
||||
border-radius: var(--border-radius-sm);
|
||||
cursor: pointer;
|
||||
transition: background-color 0.2s;
|
||||
font-size: 0.95em;
|
||||
}
|
||||
|
||||
.primary-btn:hover {
|
||||
background-color: oklch(from var(--lora-accent) l c h / 85%);
|
||||
color: var(--lora-text);
|
||||
}
|
||||
|
||||
/* Secondary button styles */
|
||||
.secondary-btn {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
padding: 8px 16px;
|
||||
background-color: var(--card-bg);
|
||||
color: var (--text-color);
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-sm);
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
font-size: 0.95em;
|
||||
}
|
||||
|
||||
.secondary-btn:hover {
|
||||
background-color: var(--border-color);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
/* Disabled button styles */
|
||||
.primary-btn.disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
background-color: var(--lora-accent);
|
||||
color: var(--lora-text);
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.secondary-btn.disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.restart-required-icon {
|
||||
color: var(--lora-warning);
|
||||
margin-left: 5px;
|
||||
font-size: 0.85em;
|
||||
vertical-align: text-bottom;
|
||||
}
|
||||
|
||||
/* Dark theme specific button adjustments */
|
||||
[data-theme="dark"] .primary-btn:hover {
|
||||
background-color: oklch(from var(--lora-accent) l c h / 75%);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .secondary-btn {
|
||||
background-color: var(--lora-surface);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .secondary-btn:hover {
|
||||
background-color: oklch(35% 0.02 256 / 0.98);
|
||||
}
|
||||
|
||||
.primary-btn.disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.primary-btn.disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
/* Add styles for delete preview image */
|
||||
.delete-preview {
|
||||
max-width: 150px;
|
||||
margin: 0 auto var(--space-2);
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.delete-preview img {
|
||||
width: 100%;
|
||||
height: auto;
|
||||
max-height: 150px;
|
||||
object-fit: contain;
|
||||
border-radius: var(--border-radius-sm);
|
||||
}
|
||||
|
||||
.delete-info {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.delete-info h3 {
|
||||
margin-bottom: var(--space-1);
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.delete-info p {
|
||||
margin: var(--space-1) 0;
|
||||
font-size: 0.9em;
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.delete-note {
|
||||
font-size: 0.85em;
|
||||
color: var(--text-color);
|
||||
opacity: 0.7;
|
||||
font-style: italic;
|
||||
margin-top: var(--space-1);
|
||||
text-align: center;
|
||||
}
|
||||
48
static/css/components/modal/delete-modal.css
Normal file
48
static/css/components/modal/delete-modal.css
Normal file
@@ -0,0 +1,48 @@
|
||||
/* Delete Modal specific styles */
|
||||
|
||||
.delete-message {
|
||||
color: var(--text-color);
|
||||
margin: var(--space-2) 0;
|
||||
}
|
||||
|
||||
/* Update delete modal styles */
|
||||
.delete-modal {
|
||||
display: none; /* Set initial display to none */
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background: rgba(0, 0, 0, 0.8);
|
||||
z-index: var(--z-overlay);
|
||||
}
|
||||
|
||||
/* Add new style for when modal is shown */
|
||||
.delete-modal.show {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.delete-modal-content {
|
||||
max-width: 500px;
|
||||
width: 90%;
|
||||
text-align: center;
|
||||
margin: 0 auto;
|
||||
position: relative;
|
||||
animation: modalFadeIn 0.2s ease-out;
|
||||
}
|
||||
|
||||
.delete-model-info,
|
||||
.exclude-model-info {
|
||||
/* Update info display styling */
|
||||
background: var(--lora-surface);
|
||||
border: 1px solid var(--lora-border);
|
||||
border-radius: var(--border-radius-sm);
|
||||
padding: var(--space-2);
|
||||
margin: var(--space-2) 0;
|
||||
color: var(--text-color);
|
||||
word-break: break-all;
|
||||
text-align: left;
|
||||
line-height: 1.5;
|
||||
}
|
||||
72
static/css/components/modal/example-access-modal.css
Normal file
72
static/css/components/modal/example-access-modal.css
Normal file
@@ -0,0 +1,72 @@
|
||||
/* Example Access Modal */
|
||||
.example-access-modal {
|
||||
max-width: 550px;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.example-access-options {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--space-2);
|
||||
margin: var(--space-3) 0;
|
||||
}
|
||||
|
||||
.example-option-btn {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
padding: var(--space-2);
|
||||
border-radius: var(--border-radius-sm);
|
||||
border: 1px solid var(--lora-border);
|
||||
background-color: var(--lora-surface);
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.example-option-btn:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
|
||||
border-color: var(--lora-accent);
|
||||
}
|
||||
|
||||
.example-option-btn i {
|
||||
font-size: 2em;
|
||||
margin-bottom: var(--space-1);
|
||||
color: var(--lora-accent);
|
||||
}
|
||||
|
||||
.option-title {
|
||||
font-weight: 500;
|
||||
margin-bottom: 4px;
|
||||
font-size: 1.1em;
|
||||
}
|
||||
|
||||
.option-desc {
|
||||
font-size: 0.9em;
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.example-option-btn.disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.example-option-btn.disabled i {
|
||||
color: var(--text-color);
|
||||
opacity: 0.5;
|
||||
}
|
||||
|
||||
.modal-footer-note {
|
||||
font-size: 0.9em;
|
||||
opacity: 0.7;
|
||||
margin-top: var(--space-2);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
/* Dark theme adjustments */
|
||||
[data-theme="dark"] .example-option-btn:hover {
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.25);
|
||||
}
|
||||
307
static/css/components/modal/help-modal.css
Normal file
307
static/css/components/modal/help-modal.css
Normal file
@@ -0,0 +1,307 @@
|
||||
/* Help Modal styles */
|
||||
.help-modal {
|
||||
max-width: 850px;
|
||||
}
|
||||
|
||||
.help-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-bottom: var(--space-2);
|
||||
}
|
||||
|
||||
.modal-help-icon {
|
||||
font-size: 24px;
|
||||
color: var(--lora-accent);
|
||||
margin-right: var(--space-2);
|
||||
vertical-align: text-bottom;
|
||||
}
|
||||
|
||||
/* Tab navigation styles */
|
||||
.help-tabs {
|
||||
display: flex;
|
||||
border-bottom: 1px solid var(--lora-border);
|
||||
margin-bottom: var(--space-2);
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.tab-btn {
|
||||
padding: 8px 16px;
|
||||
background: transparent;
|
||||
border: none;
|
||||
border-bottom: 2px solid transparent;
|
||||
color: var(--text-color);
|
||||
cursor: pointer;
|
||||
font-weight: 500;
|
||||
transition: all 0.2s;
|
||||
opacity: 0.7;
|
||||
}
|
||||
|
||||
.tab-btn:hover {
|
||||
background-color: rgba(0, 0, 0, 0.05);
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
.tab-btn.active {
|
||||
color: var(--lora-accent);
|
||||
border-bottom: 2px solid var(--lora-accent);
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Add styles for tab with new content indicator */
|
||||
.tab-btn.has-new-content {
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.tab-btn.has-new-content::after {
|
||||
content: "";
|
||||
position: absolute;
|
||||
top: 4px;
|
||||
right: 4px;
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
background-color: var(--lora-accent);
|
||||
border-radius: 50%;
|
||||
animation: pulse 2s infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0% { opacity: 1; transform: scale(1); }
|
||||
50% { opacity: 0.7; transform: scale(1.1); }
|
||||
100% { opacity: 1; transform: scale(1); }
|
||||
}
|
||||
|
||||
/* Tab content styles */
|
||||
.help-content {
|
||||
padding: var(--space-1) 0;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.tab-pane {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.tab-pane.active {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.help-text {
|
||||
margin: var(--space-2) 0;
|
||||
}
|
||||
|
||||
.help-text ul {
|
||||
padding-left: 20px;
|
||||
margin-top: 8px;
|
||||
}
|
||||
|
||||
.help-text li {
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
/* Documentation link styles */
|
||||
.docs-section {
|
||||
margin-bottom: var(--space-3);
|
||||
}
|
||||
|
||||
.docs-section h4 {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
margin-bottom: var(--space-1);
|
||||
}
|
||||
|
||||
.docs-links {
|
||||
list-style-type: none;
|
||||
padding-left: var(--space-3);
|
||||
}
|
||||
|
||||
.docs-links li {
|
||||
margin-bottom: var(--space-1);
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.docs-links li:before {
|
||||
content: "•";
|
||||
position: absolute;
|
||||
left: -15px;
|
||||
color: var(--lora-accent);
|
||||
}
|
||||
|
||||
.docs-links a {
|
||||
color: var(--lora-accent);
|
||||
text-decoration: none;
|
||||
transition: color 0.2s;
|
||||
}
|
||||
|
||||
.docs-links a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
/* New content badge styles */
|
||||
.new-content-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 0.7em;
|
||||
font-weight: 600;
|
||||
background-color: var(--lora-accent);
|
||||
color: var(--lora-text);
|
||||
padding: 2px 6px;
|
||||
border-radius: 10px;
|
||||
margin-left: 8px;
|
||||
vertical-align: middle;
|
||||
animation: fadeIn 0.5s ease-in-out;
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.new-content-badge.inline {
|
||||
font-size: 0.65em;
|
||||
padding: 1px 4px;
|
||||
margin-left: 6px;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
/* Dark theme adjustments for new content badge */
|
||||
[data-theme="dark"] .new-content-badge {
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
|
||||
/* Update video list styles */
|
||||
.video-list {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--space-3);
|
||||
}
|
||||
|
||||
.video-item {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.video-info {
|
||||
padding: var(--space-1);
|
||||
}
|
||||
|
||||
.video-info h4 {
|
||||
margin-bottom: var(--space-1);
|
||||
}
|
||||
|
||||
.video-info p {
|
||||
font-size: 0.9em;
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
/* Dark theme adjustments */
|
||||
[data-theme="dark"] .tab-btn:hover {
|
||||
background-color: rgba(255, 255, 255, 0.05);
|
||||
}
|
||||
|
||||
/* Update date badge styles */
|
||||
.update-date-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
font-size: 0.75em;
|
||||
font-weight: 500;
|
||||
background-color: var(--lora-accent);
|
||||
color: var(--lora-text);
|
||||
padding: 4px 8px;
|
||||
border-radius: 12px;
|
||||
margin-left: 10px;
|
||||
vertical-align: middle;
|
||||
animation: fadeIn 0.5s ease-in-out;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.update-date-badge i {
|
||||
margin-right: 5px;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; transform: translateY(-5px); }
|
||||
to { opacity: 1; transform: translateY(0); }
|
||||
}
|
||||
|
||||
/* Dark theme adjustments */
|
||||
[data-theme="dark"] .update-date-badge {
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.3);
|
||||
}
|
||||
|
||||
/* Privacy-friendly video embed styles */
|
||||
.video-container {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
padding-bottom: 56.25%; /* 16:9 aspect ratio */
|
||||
height: 0;
|
||||
margin-bottom: var(--space-2);
|
||||
border-radius: var(--border-radius-sm);
|
||||
overflow: hidden;
|
||||
background-color: rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
.video-thumbnail {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.video-thumbnail img {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
object-fit: cover;
|
||||
transition: filter 0.2s ease;
|
||||
}
|
||||
|
||||
.video-play-overlay {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background-color: rgba(0, 0, 0, 0.5);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
transition: opacity 0.2s ease;
|
||||
}
|
||||
|
||||
/* External link button styles */
|
||||
.external-link-btn {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
padding: 10px 20px;
|
||||
border-radius: var(--border-radius-sm);
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
background-color: var(--lora-accent);
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
border: none;
|
||||
}
|
||||
|
||||
.external-link-btn:hover {
|
||||
background-color: oklch(from var(--lora-accent) l c h / 85%);
|
||||
}
|
||||
|
||||
.video-thumbnail i {
|
||||
font-size: 1.2em;
|
||||
}
|
||||
|
||||
/* Smaller video container for the updates tab */
|
||||
.video-item .video-container {
|
||||
padding-bottom: 40%; /* Shorter height for the playlist */
|
||||
}
|
||||
|
||||
/* Dark theme adjustments */
|
||||
[data-theme="dark"] .video-container {
|
||||
background-color: rgba(255, 255, 255, 0.03);
|
||||
}
|
||||
53
static/css/components/modal/relink-civitai-modal.css
Normal file
53
static/css/components/modal/relink-civitai-modal.css
Normal file
@@ -0,0 +1,53 @@
|
||||
/* Re-link to Civitai Modal styles */
|
||||
.warning-box {
|
||||
background-color: rgba(255, 193, 7, 0.1);
|
||||
border: 1px solid rgba(255, 193, 7, 0.5);
|
||||
border-radius: var(--border-radius-sm);
|
||||
padding: var(--space-2);
|
||||
margin-bottom: var(--space-3);
|
||||
}
|
||||
|
||||
.warning-box i {
|
||||
color: var(--lora-warning);
|
||||
margin-right: var(--space-1);
|
||||
}
|
||||
|
||||
.warning-box ul {
|
||||
padding-left: 20px;
|
||||
margin: var(--space-1) 0;
|
||||
}
|
||||
|
||||
.warning-box li {
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
.input-group {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
margin-bottom: var(--space-2);
|
||||
}
|
||||
|
||||
.input-group label {
|
||||
margin-bottom: var(--space-1);
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.input-group input {
|
||||
padding: 8px 12px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.input-error {
|
||||
color: var(--lora-error);
|
||||
font-size: 0.9em;
|
||||
min-height: 20px;
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .warning-box {
|
||||
background-color: rgba(255, 193, 7, 0.05);
|
||||
border-color: rgba(255, 193, 7, 0.3);
|
||||
}
|
||||
485
static/css/components/modal/settings-modal.css
Normal file
485
static/css/components/modal/settings-modal.css
Normal file
@@ -0,0 +1,485 @@
|
||||
/* Settings styles */
|
||||
.settings-toggle {
|
||||
width: 36px;
|
||||
height: 36px;
|
||||
border-radius: 50%;
|
||||
background: var(--card-bg);
|
||||
border: 1px solid var(--border-color);
|
||||
color: var(--text-color);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.settings-toggle:hover {
|
||||
background: var(--lora-accent);
|
||||
color: white;
|
||||
transform: translateY(-2px);
|
||||
}
|
||||
|
||||
.settings-modal {
|
||||
max-width: 650px; /* Further increased from 600px for more space */
|
||||
}
|
||||
|
||||
/* Settings Links */
|
||||
.settings-links {
|
||||
margin-top: var(--space-3);
|
||||
padding-top: var(--space-2);
|
||||
border-top: 1px solid var(--lora-border);
|
||||
display: flex;
|
||||
gap: var(--space-2);
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.settings-link {
|
||||
width: 36px;
|
||||
height: 36px;
|
||||
border-radius: 50%;
|
||||
background: var(--card-bg);
|
||||
border: 1px solid var(--border-color);
|
||||
color: var(--text-color);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
text-decoration: none;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.settings-link:hover {
|
||||
background: var(--lora-accent);
|
||||
color: white;
|
||||
transform: translateY(-2px);
|
||||
}
|
||||
|
||||
.settings-link i {
|
||||
font-size: 1.1em;
|
||||
}
|
||||
|
||||
/* Tooltip styles */
|
||||
.settings-link::after {
|
||||
content: attr(title);
|
||||
position: absolute;
|
||||
bottom: calc(100% + 8px);
|
||||
left: 50%;
|
||||
transform: translateX(-50%);
|
||||
background: rgba(0, 0, 0, 0.8);
|
||||
color: white;
|
||||
padding: 4px 8px;
|
||||
border-radius: 4px;
|
||||
font-size: 0.8em;
|
||||
white-space: nowrap;
|
||||
opacity: 0;
|
||||
visibility: hidden;
|
||||
transition: opacity 0.2s, visibility 0.2s;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.settings-link:hover::after {
|
||||
opacity: 1;
|
||||
visibility: visible;
|
||||
}
|
||||
|
||||
/* Responsive adjustment */
|
||||
@media (max-width: 480px) {
|
||||
.settings-links {
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
}
|
||||
|
||||
/* API key input specific styles */
|
||||
.api-key-input {
|
||||
width: 100%; /* Take full width of parent */
|
||||
position: relative;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.api-key-input input {
|
||||
width: 100%;
|
||||
padding: 6px 40px 6px 10px; /* Add left padding */
|
||||
height: 32px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.api-key-input .toggle-visibility {
|
||||
position: absolute;
|
||||
right: 8px;
|
||||
background: none;
|
||||
border: none;
|
||||
color: var(--text-color);
|
||||
opacity: 0.6;
|
||||
cursor: pointer;
|
||||
padding: 4px 8px;
|
||||
}
|
||||
|
||||
.api-key-input .toggle-visibility:hover {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.input-help {
|
||||
font-size: 0.85em;
|
||||
color: var(--text-color);
|
||||
opacity: 0.7;
|
||||
margin-top: 8px; /* Space between control and help */
|
||||
line-height: 1.4;
|
||||
width: 100%; /* Full width */
|
||||
}
|
||||
|
||||
/* Settings Styles */
|
||||
.settings-section {
|
||||
margin-top: var(--space-3);
|
||||
border-top: 1px solid var(--lora-border);
|
||||
padding-top: var(--space-2);
|
||||
}
|
||||
|
||||
.settings-section h3 {
|
||||
font-size: 1.1em;
|
||||
margin-bottom: var(--space-2);
|
||||
color: var(--text-color);
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
.setting-item {
|
||||
display: flex;
|
||||
flex-direction: column; /* Changed to column for help text placement */
|
||||
margin-bottom: var(--space-3); /* Increased to provide more spacing between items */
|
||||
padding: var(--space-1);
|
||||
border-radius: var(--border-radius-xs);
|
||||
}
|
||||
|
||||
.setting-item:hover {
|
||||
background: rgba(0, 0, 0, 0.02);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .setting-item:hover {
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
}
|
||||
|
||||
/* Control row with label and input together */
|
||||
.setting-row {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.setting-info {
|
||||
margin-bottom: 0;
|
||||
width: 35%; /* Increased from 30% to prevent wrapping */
|
||||
flex-shrink: 0; /* Prevent shrinking */
|
||||
}
|
||||
|
||||
.setting-info label {
|
||||
display: block;
|
||||
font-weight: 500;
|
||||
margin-bottom: 0;
|
||||
white-space: nowrap; /* Prevent label wrapping */
|
||||
}
|
||||
|
||||
.setting-control {
|
||||
width: 60%; /* Decreased slightly from 65% */
|
||||
margin-bottom: 0;
|
||||
display: flex;
|
||||
justify-content: flex-end; /* Right-align all controls */
|
||||
}
|
||||
|
||||
/* Select Control Styles */
|
||||
.select-control {
|
||||
width: 100%;
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
}
|
||||
|
||||
.select-control select {
|
||||
width: 100%;
|
||||
max-width: 100%; /* Increased from 200px */
|
||||
padding: 6px 10px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
color: var(--text-color);
|
||||
font-size: 0.95em;
|
||||
height: 32px;
|
||||
}
|
||||
|
||||
/* Fix dark theme select dropdown text color */
|
||||
[data-theme="dark"] .select-control select {
|
||||
background-color: rgba(30, 30, 30, 0.9);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .select-control select option {
|
||||
background-color: #2d2d2d;
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.select-control select:focus {
|
||||
border-color: var(--lora-accent);
|
||||
outline: none;
|
||||
}
|
||||
|
||||
/* Toggle Switch */
|
||||
.toggle-switch {
|
||||
position: relative;
|
||||
display: inline-block;
|
||||
width: 50px;
|
||||
height: 24px;
|
||||
cursor: pointer;
|
||||
margin-left: auto; /* Push to right side */
|
||||
}
|
||||
|
||||
.toggle-switch input {
|
||||
opacity: 0;
|
||||
width: 0;
|
||||
height: 0;
|
||||
}
|
||||
|
||||
.toggle-slider {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-color: var(--border-color);
|
||||
transition: .3s;
|
||||
border-radius: 24px;
|
||||
}
|
||||
|
||||
.toggle-slider:before {
|
||||
position: absolute;
|
||||
content: "";
|
||||
height: 18px;
|
||||
width: 18px;
|
||||
left: 3px;
|
||||
bottom: 3px;
|
||||
background-color: white;
|
||||
transition: .3s;
|
||||
border-radius: 50%;
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider {
|
||||
background-color: var(--lora-accent);
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider:before {
|
||||
transform: translateX(26px);
|
||||
}
|
||||
|
||||
.toggle-label {
|
||||
margin-left: 60px;
|
||||
line-height: 24px;
|
||||
}
|
||||
|
||||
/* Add small animation for the toggle */
|
||||
.toggle-slider:active:before {
|
||||
width: 22px;
|
||||
}
|
||||
|
||||
/* Blur effect for NSFW content */
|
||||
.nsfw-blur {
|
||||
filter: blur(12px);
|
||||
transition: filter 0.3s ease;
|
||||
}
|
||||
|
||||
.nsfw-blur:hover {
|
||||
filter: blur(8px);
|
||||
}
|
||||
|
||||
/* Example Images Settings Styles */
|
||||
.download-buttons {
|
||||
justify-content: flex-start;
|
||||
gap: var(--space-2);
|
||||
}
|
||||
|
||||
.path-control {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
align-items: center;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.path-control input[type="text"] {
|
||||
flex: 1;
|
||||
padding: 6px 10px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
color: var (--text-color);
|
||||
font-size: 0.95em;
|
||||
height: 32px;
|
||||
}
|
||||
|
||||
/* Add warning text style for settings */
|
||||
.warning-text {
|
||||
color: var(--lora-warning, #e67e22);
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .warning-text {
|
||||
color: var(--lora-warning, #f39c12);
|
||||
}
|
||||
|
||||
/* Add styles for list description */
|
||||
.list-description {
|
||||
margin: 8px 0;
|
||||
padding-left: 20px;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
.list-description li {
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
/* Path Template Settings Styles */
|
||||
.template-preview {
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||
border-radius: var(--border-radius-xs);
|
||||
padding: var(--space-1);
|
||||
margin-top: 8px;
|
||||
font-family: monospace;
|
||||
font-size: 1.1em;
|
||||
color: var(--lora-accent);
|
||||
display: none;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .template-preview {
|
||||
background: rgba(255, 255, 255, 0.03);
|
||||
border: 1px solid var(--lora-border);
|
||||
}
|
||||
|
||||
.template-preview:before {
|
||||
content: "Preview: ";
|
||||
opacity: 0.7;
|
||||
color: var(--text-color);
|
||||
font-family: inherit;
|
||||
}
|
||||
|
||||
/* Base Model Mappings Styles - Updated to match other settings */
|
||||
.mappings-container {
|
||||
border: 1px solid var(--lora-border);
|
||||
border-radius: var(--border-radius-sm);
|
||||
padding: var(--space-2);
|
||||
background: rgba(0, 0, 0, 0.02);
|
||||
margin-top: 8px; /* Add consistent spacing */
|
||||
}
|
||||
|
||||
[data-theme="dark"] .mappings-container {
|
||||
background: rgba(255, 255, 255, 0.02);
|
||||
}
|
||||
|
||||
.add-mapping-btn {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 6px 12px;
|
||||
background: var(--lora-accent);
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: var(--border-radius-xs);
|
||||
cursor: pointer;
|
||||
font-size: 0.9em;
|
||||
transition: all 0.2s;
|
||||
height: 32px; /* Match other control heights */
|
||||
}
|
||||
|
||||
.add-mapping-btn:hover {
|
||||
background: oklch(from var(--lora-accent) l c h / 85%);
|
||||
}
|
||||
|
||||
.mapping-row {
|
||||
margin-bottom: var(--space-2);
|
||||
}
|
||||
|
||||
.mapping-row:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.mapping-controls {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr auto;
|
||||
gap: var(--space-1);
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.base-model-select,
|
||||
.path-value-input {
|
||||
padding: 6px 10px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
color: var(--text-color);
|
||||
font-size: 0.9em;
|
||||
height: 32px;
|
||||
}
|
||||
|
||||
.path-value-input {
|
||||
height: 18px;
|
||||
}
|
||||
|
||||
.base-model-select:focus,
|
||||
.path-value-input:focus {
|
||||
border-color: var(--lora-accent);
|
||||
outline: none;
|
||||
box-shadow: 0 0 0 2px rgba(var(--lora-accent-rgb, 79, 70, 229), 0.1);
|
||||
}
|
||||
|
||||
.remove-mapping-btn {
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--lora-error);
|
||||
background: transparent;
|
||||
color: var(--lora-error);
|
||||
cursor: pointer;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.remove-mapping-btn:hover {
|
||||
background: var(--lora-error);
|
||||
color: white;
|
||||
}
|
||||
|
||||
.mapping-empty-state {
|
||||
text-align: center;
|
||||
padding: var(--space-3);
|
||||
color: var(--text-color);
|
||||
opacity: 0.6;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
/* Responsive adjustments for mapping controls */
|
||||
@media (max-width: 768px) {
|
||||
.mapping-controls {
|
||||
grid-template-columns: 1fr;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.remove-mapping-btn {
|
||||
width: 100%;
|
||||
height: 36px;
|
||||
justify-self: stretch;
|
||||
}
|
||||
}
|
||||
|
||||
/* Dark theme specific adjustments */
|
||||
[data-theme="dark"] .base-model-select,
|
||||
[data-theme="dark"] .path-value-input {
|
||||
background-color: rgba(30, 30, 30, 0.9);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .base-model-select option {
|
||||
background-color: #2d2d2d;
|
||||
color: var(--text-color);
|
||||
}
|
||||
124
static/css/components/modal/update-modal.css
Normal file
124
static/css/components/modal/update-modal.css
Normal file
@@ -0,0 +1,124 @@
|
||||
/* Update Modal specific styles */
|
||||
.update-actions {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--space-2);
|
||||
align-items: stretch;
|
||||
flex-wrap: nowrap;
|
||||
}
|
||||
|
||||
.update-link {
|
||||
color: var(--lora-accent);
|
||||
text-decoration: none;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-size: 0.95em;
|
||||
}
|
||||
|
||||
.update-link:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
/* Update progress styles */
|
||||
.update-progress {
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
border: 1px solid var(--lora-border);
|
||||
border-radius: var(--border-radius-sm);
|
||||
padding: var(--space-2);
|
||||
margin: var(--space-2) 0;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .update-progress {
|
||||
background: rgba(255, 255, 255, 0.03);
|
||||
}
|
||||
|
||||
.progress-info {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--space-1);
|
||||
}
|
||||
|
||||
.progress-text {
|
||||
font-size: 0.9em;
|
||||
color: var(--text-color);
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.update-progress-bar {
|
||||
width: 100%;
|
||||
height: 8px;
|
||||
background-color: rgba(0, 0, 0, 0.1);
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .update-progress-bar {
|
||||
background-color: rgba(255, 255, 255, 0.1);
|
||||
}
|
||||
|
||||
.progress-fill {
|
||||
height: 100%;
|
||||
background-color: var(--lora-accent);
|
||||
width: 0%;
|
||||
transition: width 0.3s ease;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
/* Update button states */
|
||||
#updateBtn {
|
||||
min-width: 120px;
|
||||
}
|
||||
|
||||
#updateBtn.updating {
|
||||
background-color: var(--lora-warning);
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
#updateBtn.success {
|
||||
background-color: var(--lora-success);
|
||||
}
|
||||
|
||||
#updateBtn.error {
|
||||
background-color: var(--lora-error);
|
||||
}
|
||||
|
||||
/* Add styles for markdown elements in changelog */
|
||||
.changelog-item ul {
|
||||
padding-left: 20px;
|
||||
margin-top: 8px;
|
||||
}
|
||||
|
||||
.changelog-item li {
|
||||
margin-bottom: 6px;
|
||||
line-height: 1.4;
|
||||
}
|
||||
|
||||
.changelog-item strong {
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.changelog-item em {
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.changelog-item code {
|
||||
background: rgba(0, 0, 0, 0.05);
|
||||
padding: 2px 4px;
|
||||
border-radius: 3px;
|
||||
font-family: monospace;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .changelog-item code {
|
||||
background: rgba(255, 255, 255, 0.1);
|
||||
}
|
||||
|
||||
.changelog-item a {
|
||||
color: var(--lora-accent);
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
.changelog-item a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
@@ -124,6 +124,43 @@
|
||||
border-color: var(--lora-accent);
|
||||
}
|
||||
|
||||
/* Keyboard shortcut indicator styling */
|
||||
.shortcut-key {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
margin-left: 6px;
|
||||
min-width: 16px;
|
||||
height: 16px;
|
||||
padding: 0 3px;
|
||||
font-size: 11px;
|
||||
font-weight: 600;
|
||||
line-height: 1;
|
||||
border-radius: var(--border-radius-xs);
|
||||
background-color: var(--shortcut-bg);
|
||||
border: 1px solid var(--shortcut-border);
|
||||
color: var(--shortcut-text);
|
||||
vertical-align: middle;
|
||||
opacity: 0.8;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.control-group button:hover .shortcut-key {
|
||||
opacity: 1;
|
||||
background-color: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.2);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .shortcut-key {
|
||||
--shortcut-bg: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.15);
|
||||
--shortcut-border: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.3);
|
||||
}
|
||||
|
||||
/* Ensure correct vertical alignment for text+shortcut */
|
||||
.control-group button span {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
/* Select dropdown styling */
|
||||
.control-group select {
|
||||
min-width: 100px;
|
||||
@@ -145,6 +182,31 @@
|
||||
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
/* Style for optgroups */
|
||||
.control-group select optgroup {
|
||||
font-weight: 600;
|
||||
font-style: normal;
|
||||
color: var(--text-color);
|
||||
background-color: var(--card-bg);
|
||||
}
|
||||
|
||||
.control-group select option {
|
||||
padding: 4px 8px;
|
||||
background-color: var(--card-bg);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
/* Dark theme optgroup styling */
|
||||
[data-theme="dark"] .control-group select optgroup {
|
||||
background-color: var(--card-bg);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .control-group select option {
|
||||
background-color: var(--card-bg);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.control-group select:hover {
|
||||
border-color: var(--lora-accent);
|
||||
background-color: var(--bg-color);
|
||||
|
||||
@@ -7,7 +7,14 @@
|
||||
/* Import Components */
|
||||
@import 'components/header.css';
|
||||
@import 'components/card.css';
|
||||
@import 'components/modal.css';
|
||||
@import 'components/modal/_base.css';
|
||||
@import 'components/modal/delete-modal.css';
|
||||
@import 'components/modal/update-modal.css';
|
||||
@import 'components/modal/settings-modal.css';
|
||||
@import 'components/modal/help-modal.css';
|
||||
@import 'components/modal/relink-civitai-modal.css';
|
||||
@import 'components/modal/example-access-modal.css';
|
||||
@import 'components/modal/support-modal.css';
|
||||
@import 'components/download-modal.css';
|
||||
@import 'components/toast.css';
|
||||
@import 'components/loading.css';
|
||||
@@ -20,7 +27,6 @@
|
||||
@import 'components/lora-modal/showcase.css';
|
||||
@import 'components/lora-modal/triggerwords.css';
|
||||
@import 'components/shared/edit-metadata.css';
|
||||
@import 'components/support-modal.css';
|
||||
@import 'components/search-filter.css';
|
||||
@import 'components/bulk.css';
|
||||
@import 'components/shared.css';
|
||||
|
||||
168
static/js/api/apiConfig.js
Normal file
168
static/js/api/apiConfig.js
Normal file
@@ -0,0 +1,168 @@
|
||||
import { state } from '../state/index.js';
|
||||
|
||||
/**
|
||||
* API Configuration
|
||||
* Centralized configuration for all model types and their endpoints
|
||||
*/
|
||||
|
||||
// Model type definitions
|
||||
export const MODEL_TYPES = {
|
||||
LORA: 'loras',
|
||||
CHECKPOINT: 'checkpoints',
|
||||
EMBEDDING: 'embeddings' // Future model type
|
||||
};
|
||||
|
||||
// Base API configuration for each model type
|
||||
export const MODEL_CONFIG = {
|
||||
[MODEL_TYPES.LORA]: {
|
||||
displayName: 'LoRA',
|
||||
singularName: 'lora',
|
||||
defaultPageSize: 100,
|
||||
supportsLetterFilter: true,
|
||||
supportsBulkOperations: true,
|
||||
supportsMove: true,
|
||||
templateName: 'loras.html'
|
||||
},
|
||||
[MODEL_TYPES.CHECKPOINT]: {
|
||||
displayName: 'Checkpoint',
|
||||
singularName: 'checkpoint',
|
||||
defaultPageSize: 100,
|
||||
supportsLetterFilter: false,
|
||||
supportsBulkOperations: true,
|
||||
supportsMove: false,
|
||||
templateName: 'checkpoints.html'
|
||||
},
|
||||
[MODEL_TYPES.EMBEDDING]: {
|
||||
displayName: 'Embedding',
|
||||
singularName: 'embedding',
|
||||
defaultPageSize: 100,
|
||||
supportsLetterFilter: true,
|
||||
supportsBulkOperations: true,
|
||||
supportsMove: true,
|
||||
templateName: 'embeddings.html'
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Generate API endpoints for a given model type
|
||||
* @param {string} modelType - The model type (e.g., 'loras', 'checkpoints')
|
||||
* @returns {Object} Object containing all API endpoints for the model type
|
||||
*/
|
||||
export function getApiEndpoints(modelType) {
|
||||
if (!Object.values(MODEL_TYPES).includes(modelType)) {
|
||||
throw new Error(`Invalid model type: ${modelType}`);
|
||||
}
|
||||
|
||||
return {
|
||||
// Base CRUD operations
|
||||
list: `/api/${modelType}`,
|
||||
delete: `/api/${modelType}/delete`,
|
||||
exclude: `/api/${modelType}/exclude`,
|
||||
rename: `/api/${modelType}/rename`,
|
||||
save: `/api/${modelType}/save-metadata`,
|
||||
|
||||
// Bulk operations
|
||||
bulkDelete: `/api/${modelType}/bulk-delete`,
|
||||
|
||||
// CivitAI integration
|
||||
fetchCivitai: `/api/${modelType}/fetch-civitai`,
|
||||
fetchAllCivitai: `/api/${modelType}/fetch-all-civitai`,
|
||||
relinkCivitai: `/api/${modelType}/relink-civitai`,
|
||||
civitaiVersions: `/api/${modelType}/civitai/versions`,
|
||||
|
||||
// Preview management
|
||||
replacePreview: `/api/${modelType}/replace-preview`,
|
||||
|
||||
// Query operations
|
||||
scan: `/api/${modelType}/scan`,
|
||||
topTags: `/api/${modelType}/top-tags`,
|
||||
baseModels: `/api/${modelType}/base-models`,
|
||||
roots: `/api/${modelType}/roots`,
|
||||
folders: `/api/${modelType}/folders`,
|
||||
duplicates: `/api/${modelType}/find-duplicates`,
|
||||
conflicts: `/api/${modelType}/find-filename-conflicts`,
|
||||
verify: `/api/${modelType}/verify-duplicates`,
|
||||
|
||||
// Model-specific endpoints (will be merged with specific configs)
|
||||
specific: {}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Model-specific endpoint configurations
|
||||
*/
|
||||
export const MODEL_SPECIFIC_ENDPOINTS = {
|
||||
[MODEL_TYPES.LORA]: {
|
||||
letterCounts: `/api/${MODEL_TYPES.LORA}/letter-counts`,
|
||||
notes: `/api/${MODEL_TYPES.LORA}/get-notes`,
|
||||
triggerWords: `/api/${MODEL_TYPES.LORA}/get-trigger-words`,
|
||||
previewUrl: `/api/${MODEL_TYPES.LORA}/preview-url`,
|
||||
civitaiUrl: `/api/${MODEL_TYPES.LORA}/civitai-url`,
|
||||
modelDescription: `/api/${MODEL_TYPES.LORA}/model-description`,
|
||||
moveModel: `/api/${MODEL_TYPES.LORA}/move_model`,
|
||||
moveBulk: `/api/${MODEL_TYPES.LORA}/move_models_bulk`,
|
||||
getTriggerWordsPost: `/api/${MODEL_TYPES.LORA}/get_trigger_words`,
|
||||
civitaiModelByVersion: `/api/${MODEL_TYPES.LORA}/civitai/model/version`,
|
||||
civitaiModelByHash: `/api/${MODEL_TYPES.LORA}/civitai/model/hash`,
|
||||
},
|
||||
[MODEL_TYPES.CHECKPOINT]: {
|
||||
info: `/api/${MODEL_TYPES.CHECKPOINT}/info`,
|
||||
},
|
||||
[MODEL_TYPES.EMBEDDING]: {
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get complete API configuration for a model type
|
||||
* @param {string} modelType - The model type
|
||||
* @returns {Object} Complete API configuration
|
||||
*/
|
||||
export function getCompleteApiConfig(modelType) {
|
||||
const baseEndpoints = getApiEndpoints(modelType);
|
||||
const specificEndpoints = MODEL_SPECIFIC_ENDPOINTS[modelType] || {};
|
||||
const config = MODEL_CONFIG[modelType];
|
||||
|
||||
return {
|
||||
modelType,
|
||||
config,
|
||||
endpoints: {
|
||||
...baseEndpoints,
|
||||
specific: specificEndpoints
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate if a model type is supported
|
||||
* @param {string} modelType - The model type to validate
|
||||
* @returns {boolean} True if valid, false otherwise
|
||||
*/
|
||||
export function isValidModelType(modelType) {
|
||||
return Object.values(MODEL_TYPES).includes(modelType);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get model type from current page or explicit parameter
|
||||
* @param {string} [explicitType] - Explicitly provided model type
|
||||
* @returns {string} The model type
|
||||
*/
|
||||
export function getCurrentModelType(explicitType = null) {
|
||||
if (explicitType && isValidModelType(explicitType)) {
|
||||
return explicitType;
|
||||
}
|
||||
|
||||
return state.currentPageType || MODEL_TYPES.LORA;
|
||||
}
|
||||
|
||||
// Download API endpoints (shared across all model types)
|
||||
export const DOWNLOAD_ENDPOINTS = {
|
||||
download: '/api/download-model',
|
||||
downloadGet: '/api/download-model-get',
|
||||
cancelGet: '/api/cancel-download-get',
|
||||
progress: '/api/download-progress'
|
||||
};
|
||||
|
||||
// WebSocket endpoints
|
||||
export const WS_ENDPOINTS = {
|
||||
fetchProgress: '/ws/fetch-progress'
|
||||
};
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,165 +0,0 @@
|
||||
import {
|
||||
fetchModelsPage,
|
||||
resetAndReloadWithVirtualScroll,
|
||||
loadMoreWithVirtualScroll,
|
||||
refreshModels as baseRefreshModels,
|
||||
deleteModel as baseDeleteModel,
|
||||
replaceModelPreview,
|
||||
fetchCivitaiMetadata,
|
||||
refreshSingleModelMetadata,
|
||||
excludeModel as baseExcludeModel
|
||||
} from './baseModelApi.js';
|
||||
import { state } from '../state/index.js';
|
||||
|
||||
/**
|
||||
* Fetch checkpoints with pagination for virtual scrolling
|
||||
* @param {number} page - Page number to fetch
|
||||
* @param {number} pageSize - Number of items per page
|
||||
* @returns {Promise<Object>} Object containing items, total count, and pagination info
|
||||
*/
|
||||
export async function fetchCheckpointsPage(page = 1, pageSize = 100) {
|
||||
return fetchModelsPage({
|
||||
modelType: 'checkpoint',
|
||||
page,
|
||||
pageSize,
|
||||
endpoint: '/api/checkpoints'
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Load more checkpoints with pagination - updated to work with VirtualScroller
|
||||
* @param {boolean} resetPage - Whether to reset to the first page
|
||||
* @param {boolean} updateFolders - Whether to update folder tags
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
export async function loadMoreCheckpoints(resetPage = false, updateFolders = false) {
|
||||
return loadMoreWithVirtualScroll({
|
||||
modelType: 'checkpoint',
|
||||
resetPage,
|
||||
updateFolders,
|
||||
fetchPageFunction: fetchCheckpointsPage
|
||||
});
|
||||
}
|
||||
|
||||
// Reset and reload checkpoints
|
||||
export async function resetAndReload(updateFolders = false) {
|
||||
return resetAndReloadWithVirtualScroll({
|
||||
modelType: 'checkpoint',
|
||||
updateFolders,
|
||||
fetchPageFunction: fetchCheckpointsPage
|
||||
});
|
||||
}
|
||||
|
||||
// Refresh checkpoints
|
||||
export async function refreshCheckpoints(fullRebuild = false) {
|
||||
return baseRefreshModels({
|
||||
modelType: 'checkpoint',
|
||||
scanEndpoint: '/api/checkpoints/scan',
|
||||
resetAndReloadFunction: resetAndReload,
|
||||
fullRebuild: fullRebuild
|
||||
});
|
||||
}
|
||||
|
||||
// Delete a checkpoint
|
||||
export function deleteCheckpoint(filePath) {
|
||||
return baseDeleteModel(filePath, 'checkpoint');
|
||||
}
|
||||
|
||||
// Replace checkpoint preview
|
||||
export function replaceCheckpointPreview(filePath) {
|
||||
return replaceModelPreview(filePath, 'checkpoint');
|
||||
}
|
||||
|
||||
// Fetch metadata from Civitai for checkpoints
|
||||
export async function fetchCivitai() {
|
||||
return fetchCivitaiMetadata({
|
||||
modelType: 'checkpoint',
|
||||
fetchEndpoint: '/api/checkpoints/fetch-all-civitai',
|
||||
resetAndReloadFunction: resetAndReload
|
||||
});
|
||||
}
|
||||
|
||||
// Refresh single checkpoint metadata
|
||||
export async function refreshSingleCheckpointMetadata(filePath) {
|
||||
await refreshSingleModelMetadata(filePath, 'checkpoint');
|
||||
}
|
||||
|
||||
/**
|
||||
* Save model metadata to the server
|
||||
* @param {string} filePath - Path to the model file
|
||||
* @param {Object} data - Metadata to save
|
||||
* @returns {Promise} - Promise that resolves with the server response
|
||||
*/
|
||||
export async function saveModelMetadata(filePath, data) {
|
||||
try {
|
||||
// Show loading indicator
|
||||
state.loadingManager.showSimpleLoading('Saving metadata...');
|
||||
|
||||
const response = await fetch('/api/checkpoints/save-metadata', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
file_path: filePath,
|
||||
...data
|
||||
})
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to save metadata');
|
||||
}
|
||||
|
||||
// Update the virtual scroller with the new metadata
|
||||
state.virtualScroller.updateSingleItem(filePath, data);
|
||||
|
||||
return response.json();
|
||||
} finally {
|
||||
// Always hide the loading indicator when done
|
||||
state.loadingManager.hide();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Exclude a checkpoint model from being shown in the UI
|
||||
* @param {string} filePath - File path of the checkpoint to exclude
|
||||
* @returns {Promise<boolean>} Promise resolving to success status
|
||||
*/
|
||||
export function excludeCheckpoint(filePath) {
|
||||
return baseExcludeModel(filePath, 'checkpoint');
|
||||
}
|
||||
|
||||
/**
|
||||
* Rename a checkpoint file
|
||||
* @param {string} filePath - Current file path
|
||||
* @param {string} newFileName - New file name (without path)
|
||||
* @returns {Promise<Object>} - Promise that resolves with the server response
|
||||
*/
|
||||
export async function renameCheckpointFile(filePath, newFileName) {
|
||||
try {
|
||||
// Show loading indicator
|
||||
state.loadingManager.showSimpleLoading('Renaming checkpoint file...');
|
||||
|
||||
const response = await fetch('/api/checkpoints/rename', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
file_path: filePath,
|
||||
new_file_name: newFileName
|
||||
})
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Server returned ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
} catch (error) {
|
||||
console.error('Error renaming checkpoint file:', error);
|
||||
throw error;
|
||||
} finally {
|
||||
state.loadingManager.hide();
|
||||
}
|
||||
}
|
||||
@@ -1,175 +0,0 @@
|
||||
import {
|
||||
fetchModelsPage,
|
||||
resetAndReloadWithVirtualScroll,
|
||||
loadMoreWithVirtualScroll,
|
||||
refreshModels as baseRefreshModels,
|
||||
deleteModel as baseDeleteModel,
|
||||
replaceModelPreview,
|
||||
fetchCivitaiMetadata,
|
||||
refreshSingleModelMetadata,
|
||||
excludeModel as baseExcludeModel
|
||||
} from './baseModelApi.js';
|
||||
import { state } from '../state/index.js';
|
||||
|
||||
/**
|
||||
* Save model metadata to the server
|
||||
* @param {string} filePath - File path
|
||||
* @param {Object} data - Data to save
|
||||
* @returns {Promise} Promise of the save operation
|
||||
*/
|
||||
export async function saveModelMetadata(filePath, data) {
|
||||
try {
|
||||
// Show loading indicator
|
||||
state.loadingManager.showSimpleLoading('Saving metadata...');
|
||||
|
||||
const response = await fetch('/api/loras/save-metadata', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
file_path: filePath,
|
||||
...data
|
||||
})
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to save metadata');
|
||||
}
|
||||
|
||||
// Update the virtual scroller with the new data
|
||||
state.virtualScroller.updateSingleItem(filePath, data);
|
||||
|
||||
return response.json();
|
||||
} finally {
|
||||
// Always hide the loading indicator when done
|
||||
state.loadingManager.hide();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Exclude a lora model from being shown in the UI
|
||||
* @param {string} filePath - File path of the model to exclude
|
||||
* @returns {Promise<boolean>} Promise resolving to success status
|
||||
*/
|
||||
export async function excludeLora(filePath) {
|
||||
return baseExcludeModel(filePath, 'lora');
|
||||
}
|
||||
|
||||
/**
|
||||
* Load more loras with pagination - updated to work with VirtualScroller
|
||||
* @param {boolean} resetPage - Whether to reset to the first page
|
||||
* @param {boolean} updateFolders - Whether to update folder tags
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
export async function loadMoreLoras(resetPage = false, updateFolders = false) {
|
||||
return loadMoreWithVirtualScroll({
|
||||
modelType: 'lora',
|
||||
resetPage,
|
||||
updateFolders,
|
||||
fetchPageFunction: fetchLorasPage
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch loras with pagination for virtual scrolling
|
||||
* @param {number} page - Page number to fetch
|
||||
* @param {number} pageSize - Number of items per page
|
||||
* @returns {Promise<Object>} Object containing items, total count, and pagination info
|
||||
*/
|
||||
export async function fetchLorasPage(page = 1, pageSize = 100) {
|
||||
return fetchModelsPage({
|
||||
modelType: 'lora',
|
||||
page,
|
||||
pageSize,
|
||||
endpoint: '/api/loras'
|
||||
});
|
||||
}
|
||||
|
||||
export async function fetchCivitai() {
|
||||
return fetchCivitaiMetadata({
|
||||
modelType: 'lora',
|
||||
fetchEndpoint: '/api/fetch-all-civitai',
|
||||
resetAndReloadFunction: resetAndReload
|
||||
});
|
||||
}
|
||||
|
||||
export async function deleteModel(filePath) {
|
||||
return baseDeleteModel(filePath, 'lora');
|
||||
}
|
||||
|
||||
export async function replacePreview(filePath) {
|
||||
return replaceModelPreview(filePath, 'lora');
|
||||
}
|
||||
|
||||
export async function resetAndReload(updateFolders = false) {
|
||||
return resetAndReloadWithVirtualScroll({
|
||||
modelType: 'lora',
|
||||
updateFolders,
|
||||
fetchPageFunction: fetchLorasPage
|
||||
});
|
||||
}
|
||||
|
||||
export async function refreshLoras(fullRebuild = false) {
|
||||
return baseRefreshModels({
|
||||
modelType: 'lora',
|
||||
scanEndpoint: '/api/loras/scan',
|
||||
resetAndReloadFunction: resetAndReload,
|
||||
fullRebuild: fullRebuild
|
||||
});
|
||||
}
|
||||
|
||||
export async function refreshSingleLoraMetadata(filePath) {
|
||||
await refreshSingleModelMetadata(filePath, 'lora');
|
||||
}
|
||||
|
||||
export async function fetchModelDescription(modelId, filePath) {
|
||||
try {
|
||||
const response = await fetch(`/api/lora-model-description?model_id=${modelId}&file_path=${encodeURIComponent(filePath)}`);
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch model description: ${response.statusText}`);
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
} catch (error) {
|
||||
console.error('Error fetching model description:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Rename a LoRA file
|
||||
* @param {string} filePath - Current file path
|
||||
* @param {string} newFileName - New file name (without path)
|
||||
* @returns {Promise<Object>} - Promise that resolves with the server response
|
||||
*/
|
||||
export async function renameLoraFile(filePath, newFileName) {
|
||||
try {
|
||||
// Show loading indicator
|
||||
state.loadingManager.showSimpleLoading('Renaming LoRA file...');
|
||||
|
||||
const response = await fetch('/api/loras/rename', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
file_path: filePath,
|
||||
new_file_name: newFileName
|
||||
})
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Server returned ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
} catch (error) {
|
||||
console.error('Error renaming LoRA file:', error);
|
||||
throw error;
|
||||
} finally {
|
||||
// Hide loading indicator
|
||||
state.loadingManager.hide();
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,4 @@
|
||||
import { RecipeCard } from '../components/RecipeCard.js';
|
||||
import {
|
||||
resetAndReloadWithVirtualScroll,
|
||||
loadMoreWithVirtualScroll
|
||||
} from './baseModelApi.js';
|
||||
import { state, getCurrentPageState } from '../state/index.js';
|
||||
import { showToast } from '../utils/uiHelpers.js';
|
||||
|
||||
@@ -98,6 +94,98 @@ export async function fetchRecipesPage(page = 1, pageSize = 100) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset and reload models using virtual scrolling
|
||||
* @param {Object} options - Operation options
|
||||
* @returns {Promise<Object>} The fetch result
|
||||
*/
|
||||
export async function resetAndReloadWithVirtualScroll(options = {}) {
|
||||
const {
|
||||
modelType = 'lora',
|
||||
updateFolders = false,
|
||||
fetchPageFunction
|
||||
} = options;
|
||||
|
||||
const pageState = getCurrentPageState();
|
||||
|
||||
try {
|
||||
pageState.isLoading = true;
|
||||
|
||||
// Reset page counter
|
||||
pageState.currentPage = 1;
|
||||
|
||||
// Fetch the first page
|
||||
const result = await fetchPageFunction(1, pageState.pageSize || 50);
|
||||
|
||||
// Update the virtual scroller
|
||||
state.virtualScroller.refreshWithData(
|
||||
result.items,
|
||||
result.totalItems,
|
||||
result.hasMore
|
||||
);
|
||||
|
||||
// Update state
|
||||
pageState.hasMore = result.hasMore;
|
||||
pageState.currentPage = 2; // Next page will be 2
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
console.error(`Error reloading ${modelType}s:`, error);
|
||||
showToast(`Failed to reload ${modelType}s: ${error.message}`, 'error');
|
||||
throw error;
|
||||
} finally {
|
||||
pageState.isLoading = false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load more models using virtual scrolling
|
||||
* @param {Object} options - Operation options
|
||||
* @returns {Promise<Object>} The fetch result
|
||||
*/
|
||||
export async function loadMoreWithVirtualScroll(options = {}) {
|
||||
const {
|
||||
modelType = 'lora',
|
||||
resetPage = false,
|
||||
updateFolders = false,
|
||||
fetchPageFunction
|
||||
} = options;
|
||||
|
||||
const pageState = getCurrentPageState();
|
||||
|
||||
try {
|
||||
// Start loading state
|
||||
pageState.isLoading = true;
|
||||
|
||||
// Reset to first page if requested
|
||||
if (resetPage) {
|
||||
pageState.currentPage = 1;
|
||||
}
|
||||
|
||||
// Fetch the first page of data
|
||||
const result = await fetchPageFunction(pageState.currentPage, pageState.pageSize || 50);
|
||||
|
||||
// Update virtual scroller with the new data
|
||||
state.virtualScroller.refreshWithData(
|
||||
result.items,
|
||||
result.totalItems,
|
||||
result.hasMore
|
||||
);
|
||||
|
||||
// Update state
|
||||
pageState.hasMore = result.hasMore;
|
||||
pageState.currentPage = 2; // Next page to load would be 2
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
console.error(`Error loading ${modelType}s:`, error);
|
||||
showToast(`Failed to load ${modelType}s: ${error.message}`, 'error');
|
||||
throw error;
|
||||
} finally {
|
||||
pageState.isLoading = false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset and reload recipes using virtual scrolling
|
||||
* @param {boolean} updateFolders - Whether to update folder tags
|
||||
|
||||
@@ -1,22 +1,18 @@
|
||||
import { appCore } from './core.js';
|
||||
import { confirmDelete, closeDeleteModal, confirmExclude, closeExcludeModal } from './utils/modalUtils.js';
|
||||
import { createPageControls } from './components/controls/index.js';
|
||||
import { loadMoreCheckpoints } from './api/checkpointApi.js';
|
||||
import { CheckpointDownloadManager } from './managers/CheckpointDownloadManager.js';
|
||||
import { CheckpointContextMenu } from './components/ContextMenu/index.js';
|
||||
import { ModelDuplicatesManager } from './components/ModelDuplicatesManager.js';
|
||||
import { MODEL_TYPES } from './api/apiConfig.js';
|
||||
|
||||
// Initialize the Checkpoints page
|
||||
class CheckpointsPageManager {
|
||||
constructor() {
|
||||
// Initialize page controls
|
||||
this.pageControls = createPageControls('checkpoints');
|
||||
|
||||
// Initialize checkpoint download manager
|
||||
window.checkpointDownloadManager = new CheckpointDownloadManager();
|
||||
this.pageControls = createPageControls(MODEL_TYPES.CHECKPOINT);
|
||||
|
||||
// Initialize the ModelDuplicatesManager
|
||||
this.duplicatesManager = new ModelDuplicatesManager(this, 'checkpoints');
|
||||
this.duplicatesManager = new ModelDuplicatesManager(this, MODEL_TYPES.CHECKPOINT);
|
||||
|
||||
// Expose only necessary functions to global scope
|
||||
this._exposeRequiredGlobalFunctions();
|
||||
@@ -29,11 +25,6 @@ class CheckpointsPageManager {
|
||||
window.confirmExclude = confirmExclude;
|
||||
window.closeExcludeModal = closeExcludeModal;
|
||||
|
||||
// Add loadCheckpoints function to window for FilterManager compatibility
|
||||
window.checkpointManager = {
|
||||
loadCheckpoints: (reset) => loadMoreCheckpoints(reset)
|
||||
};
|
||||
|
||||
// Expose duplicates manager
|
||||
window.modelDuplicatesManager = this.duplicatesManager;
|
||||
}
|
||||
|
||||
@@ -1,334 +0,0 @@
|
||||
import { showToast, copyToClipboard, openExampleImagesFolder, openCivitai } from '../utils/uiHelpers.js';
|
||||
import { state } from '../state/index.js';
|
||||
import { showCheckpointModal } from './checkpointModal/index.js';
|
||||
import { NSFW_LEVELS } from '../utils/constants.js';
|
||||
import { replaceCheckpointPreview as apiReplaceCheckpointPreview, saveModelMetadata } from '../api/checkpointApi.js';
|
||||
import { showDeleteModal } from '../utils/modalUtils.js';
|
||||
|
||||
// Add a global event delegation handler
|
||||
export function setupCheckpointCardEventDelegation() {
|
||||
const gridElement = document.getElementById('checkpointGrid');
|
||||
if (!gridElement) return;
|
||||
|
||||
// Remove any existing event listener to prevent duplication
|
||||
gridElement.removeEventListener('click', handleCheckpointCardEvent);
|
||||
|
||||
// Add the event delegation handler
|
||||
gridElement.addEventListener('click', handleCheckpointCardEvent);
|
||||
}
|
||||
|
||||
// Event delegation handler for all checkpoint card events
|
||||
function handleCheckpointCardEvent(event) {
|
||||
// Find the closest card element
|
||||
const card = event.target.closest('.lora-card');
|
||||
if (!card) return;
|
||||
|
||||
// Handle specific elements within the card
|
||||
if (event.target.closest('.toggle-blur-btn')) {
|
||||
event.stopPropagation();
|
||||
toggleBlurContent(card);
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.target.closest('.show-content-btn')) {
|
||||
event.stopPropagation();
|
||||
showBlurredContent(card);
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.target.closest('.fa-star')) {
|
||||
event.stopPropagation();
|
||||
toggleFavorite(card);
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.target.closest('.fa-globe')) {
|
||||
event.stopPropagation();
|
||||
if (card.dataset.from_civitai === 'true') {
|
||||
openCivitai(card.dataset.filepath);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.target.closest('.fa-copy')) {
|
||||
event.stopPropagation();
|
||||
copyCheckpointName(card);
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.target.closest('.fa-trash')) {
|
||||
event.stopPropagation();
|
||||
showDeleteModal(card.dataset.filepath);
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.target.closest('.fa-image')) {
|
||||
event.stopPropagation();
|
||||
replaceCheckpointPreview(card.dataset.filepath);
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.target.closest('.fa-folder-open')) {
|
||||
event.stopPropagation();
|
||||
openExampleImagesFolder(card.dataset.sha256);
|
||||
return;
|
||||
}
|
||||
|
||||
// If no specific element was clicked, handle the card click (show modal)
|
||||
showCheckpointModalFromCard(card);
|
||||
}
|
||||
|
||||
// Helper functions for event handling
|
||||
function toggleBlurContent(card) {
|
||||
const preview = card.querySelector('.card-preview');
|
||||
const isBlurred = preview.classList.toggle('blurred');
|
||||
const icon = card.querySelector('.toggle-blur-btn i');
|
||||
|
||||
// Update the icon based on blur state
|
||||
if (isBlurred) {
|
||||
icon.className = 'fas fa-eye';
|
||||
} else {
|
||||
icon.className = 'fas fa-eye-slash';
|
||||
}
|
||||
|
||||
// Toggle the overlay visibility
|
||||
const overlay = card.querySelector('.nsfw-overlay');
|
||||
if (overlay) {
|
||||
overlay.style.display = isBlurred ? 'flex' : 'none';
|
||||
}
|
||||
}
|
||||
|
||||
function showBlurredContent(card) {
|
||||
const preview = card.querySelector('.card-preview');
|
||||
preview.classList.remove('blurred');
|
||||
|
||||
// Update the toggle button icon
|
||||
const toggleBtn = card.querySelector('.toggle-blur-btn');
|
||||
if (toggleBtn) {
|
||||
toggleBtn.querySelector('i').className = 'fas fa-eye-slash';
|
||||
}
|
||||
|
||||
// Hide the overlay
|
||||
const overlay = card.querySelector('.nsfw-overlay');
|
||||
if (overlay) {
|
||||
overlay.style.display = 'none';
|
||||
}
|
||||
}
|
||||
|
||||
async function toggleFavorite(card) {
|
||||
const starIcon = card.querySelector('.fa-star');
|
||||
const isFavorite = starIcon.classList.contains('fas');
|
||||
const newFavoriteState = !isFavorite;
|
||||
|
||||
try {
|
||||
// Save the new favorite state to the server
|
||||
await saveModelMetadata(card.dataset.filepath, {
|
||||
favorite: newFavoriteState
|
||||
});
|
||||
|
||||
if (newFavoriteState) {
|
||||
showToast('Added to favorites', 'success');
|
||||
} else {
|
||||
showToast('Removed from favorites', 'success');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to update favorite status:', error);
|
||||
showToast('Failed to update favorite status', 'error');
|
||||
}
|
||||
}
|
||||
|
||||
async function copyCheckpointName(card) {
|
||||
const checkpointName = card.dataset.file_name;
|
||||
|
||||
try {
|
||||
await copyToClipboard(checkpointName, 'Checkpoint name copied');
|
||||
} catch (err) {
|
||||
console.error('Copy failed:', err);
|
||||
showToast('Copy failed', 'error');
|
||||
}
|
||||
}
|
||||
|
||||
function showCheckpointModalFromCard(card) {
|
||||
// Get the page-specific previewVersions map
|
||||
const previewVersions = state.pages.checkpoints.previewVersions || new Map();
|
||||
const version = previewVersions.get(card.dataset.filepath);
|
||||
const previewUrl = card.dataset.preview_url || '/loras_static/images/no-preview.png';
|
||||
const versionedPreviewUrl = version ? `${previewUrl}?t=${version}` : previewUrl;
|
||||
|
||||
// Show checkpoint details modal
|
||||
const checkpointMeta = {
|
||||
sha256: card.dataset.sha256,
|
||||
file_path: card.dataset.filepath,
|
||||
model_name: card.dataset.name,
|
||||
file_name: card.dataset.file_name,
|
||||
folder: card.dataset.folder,
|
||||
modified: card.dataset.modified,
|
||||
file_size: parseInt(card.dataset.file_size || '0'),
|
||||
from_civitai: card.dataset.from_civitai === 'true',
|
||||
base_model: card.dataset.base_model,
|
||||
notes: card.dataset.notes || '',
|
||||
preview_url: versionedPreviewUrl,
|
||||
// Parse civitai metadata from the card's dataset
|
||||
civitai: (() => {
|
||||
try {
|
||||
return JSON.parse(card.dataset.meta || '{}');
|
||||
} catch (e) {
|
||||
console.error('Failed to parse civitai metadata:', e);
|
||||
return {}; // Return empty object on error
|
||||
}
|
||||
})(),
|
||||
tags: (() => {
|
||||
try {
|
||||
return JSON.parse(card.dataset.tags || '[]');
|
||||
} catch (e) {
|
||||
console.error('Failed to parse tags:', e);
|
||||
return []; // Return empty array on error
|
||||
}
|
||||
})(),
|
||||
modelDescription: card.dataset.modelDescription || ''
|
||||
};
|
||||
showCheckpointModal(checkpointMeta);
|
||||
}
|
||||
|
||||
function replaceCheckpointPreview(filePath) {
|
||||
if (window.replaceCheckpointPreview) {
|
||||
window.replaceCheckpointPreview(filePath);
|
||||
} else {
|
||||
apiReplaceCheckpointPreview(filePath);
|
||||
}
|
||||
}
|
||||
|
||||
export function createCheckpointCard(checkpoint) {
|
||||
const card = document.createElement('div');
|
||||
card.className = 'lora-card'; // Reuse the same class for styling
|
||||
card.dataset.sha256 = checkpoint.sha256;
|
||||
card.dataset.filepath = checkpoint.file_path;
|
||||
card.dataset.name = checkpoint.model_name;
|
||||
card.dataset.file_name = checkpoint.file_name;
|
||||
card.dataset.folder = checkpoint.folder;
|
||||
card.dataset.modified = checkpoint.modified;
|
||||
card.dataset.file_size = checkpoint.file_size;
|
||||
card.dataset.from_civitai = checkpoint.from_civitai;
|
||||
card.dataset.notes = checkpoint.notes || '';
|
||||
card.dataset.base_model = checkpoint.base_model || 'Unknown';
|
||||
card.dataset.favorite = checkpoint.favorite ? 'true' : 'false';
|
||||
|
||||
// Store metadata if available
|
||||
if (checkpoint.civitai) {
|
||||
card.dataset.meta = JSON.stringify(checkpoint.civitai || {});
|
||||
}
|
||||
|
||||
// Store tags if available
|
||||
if (checkpoint.tags && Array.isArray(checkpoint.tags)) {
|
||||
card.dataset.tags = JSON.stringify(checkpoint.tags);
|
||||
}
|
||||
|
||||
if (checkpoint.modelDescription) {
|
||||
card.dataset.modelDescription = checkpoint.modelDescription;
|
||||
}
|
||||
|
||||
// Store NSFW level if available
|
||||
const nsfwLevel = checkpoint.preview_nsfw_level !== undefined ? checkpoint.preview_nsfw_level : 0;
|
||||
card.dataset.nsfwLevel = nsfwLevel;
|
||||
|
||||
// Determine if the preview should be blurred based on NSFW level and user settings
|
||||
const shouldBlur = state.settings.blurMatureContent && nsfwLevel > NSFW_LEVELS.PG13;
|
||||
if (shouldBlur) {
|
||||
card.classList.add('nsfw-content');
|
||||
}
|
||||
|
||||
// Determine preview URL
|
||||
const previewUrl = checkpoint.preview_url || '/loras_static/images/no-preview.png';
|
||||
|
||||
// Get the page-specific previewVersions map
|
||||
const previewVersions = state.pages.checkpoints.previewVersions || new Map();
|
||||
const version = previewVersions.get(checkpoint.file_path);
|
||||
const versionedPreviewUrl = version ? `${previewUrl}?t=${version}` : previewUrl;
|
||||
|
||||
// Determine NSFW warning text based on level
|
||||
let nsfwText = "Mature Content";
|
||||
if (nsfwLevel >= NSFW_LEVELS.XXX) {
|
||||
nsfwText = "XXX-rated Content";
|
||||
} else if (nsfwLevel >= NSFW_LEVELS.X) {
|
||||
nsfwText = "X-rated Content";
|
||||
} else if (nsfwLevel >= NSFW_LEVELS.R) {
|
||||
nsfwText = "R-rated Content";
|
||||
}
|
||||
|
||||
// Check if autoplayOnHover is enabled for video previews
|
||||
const autoplayOnHover = state.global?.settings?.autoplayOnHover || false;
|
||||
const isVideo = previewUrl.endsWith('.mp4');
|
||||
const videoAttrs = autoplayOnHover ? 'controls muted loop' : 'controls autoplay muted loop';
|
||||
|
||||
// Get favorite status from checkpoint data
|
||||
const isFavorite = checkpoint.favorite === true;
|
||||
|
||||
card.innerHTML = `
|
||||
<div class="card-preview ${shouldBlur ? 'blurred' : ''}">
|
||||
${isVideo ?
|
||||
`<video ${videoAttrs}>
|
||||
<source src="${versionedPreviewUrl}" type="video/mp4">
|
||||
</video>` :
|
||||
`<img src="${versionedPreviewUrl}" alt="${checkpoint.model_name}">`
|
||||
}
|
||||
<div class="card-header">
|
||||
${shouldBlur ?
|
||||
`<button class="toggle-blur-btn" title="Toggle blur">
|
||||
<i class="fas fa-eye"></i>
|
||||
</button>` : ''}
|
||||
<span class="base-model-label ${shouldBlur ? 'with-toggle' : ''}" title="${checkpoint.base_model}">
|
||||
${checkpoint.base_model}
|
||||
</span>
|
||||
<div class="card-actions">
|
||||
<i class="${isFavorite ? 'fas fa-star favorite-active' : 'far fa-star'}"
|
||||
title="${isFavorite ? 'Remove from favorites' : 'Add to favorites'}">
|
||||
</i>
|
||||
<i class="fas fa-globe"
|
||||
title="${checkpoint.from_civitai ? 'View on Civitai' : 'Not available from Civitai'}"
|
||||
${!checkpoint.from_civitai ? 'style="opacity: 0.5; cursor: not-allowed"' : ''}>
|
||||
</i>
|
||||
<i class="fas fa-copy"
|
||||
title="Copy Checkpoint Name">
|
||||
</i>
|
||||
<i class="fas fa-trash"
|
||||
title="Delete Model">
|
||||
</i>
|
||||
</div>
|
||||
</div>
|
||||
${shouldBlur ? `
|
||||
<div class="nsfw-overlay">
|
||||
<div class="nsfw-warning">
|
||||
<p>${nsfwText}</p>
|
||||
<button class="show-content-btn">Show</button>
|
||||
</div>
|
||||
</div>
|
||||
` : ''}
|
||||
<div class="card-footer">
|
||||
<div class="model-info">
|
||||
<span class="model-name">${checkpoint.model_name}</span>
|
||||
</div>
|
||||
<div class="card-actions">
|
||||
<i class="fas fa-folder-open"
|
||||
title="Open Example Images Folder">
|
||||
</i>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
|
||||
// Add video auto-play on hover functionality if needed
|
||||
const videoElement = card.querySelector('video');
|
||||
if (videoElement && autoplayOnHover) {
|
||||
const cardPreview = card.querySelector('.card-preview');
|
||||
|
||||
// Remove autoplay attribute and pause initially
|
||||
videoElement.removeAttribute('autoplay');
|
||||
videoElement.pause();
|
||||
|
||||
// Add mouse events to trigger play/pause using event attributes
|
||||
cardPreview.setAttribute('onmouseenter', 'this.querySelector("video")?.play()');
|
||||
cardPreview.setAttribute('onmouseleave', 'const v=this.querySelector("video"); if(v){v.pause();v.currentTime=0;}');
|
||||
}
|
||||
|
||||
return card;
|
||||
}
|
||||
@@ -1,12 +1,12 @@
|
||||
import { BaseContextMenu } from './BaseContextMenu.js';
|
||||
import { ModelContextMenuMixin } from './ModelContextMenuMixin.js';
|
||||
import { refreshSingleCheckpointMetadata, saveModelMetadata, replaceCheckpointPreview, resetAndReload } from '../../api/checkpointApi.js';
|
||||
import { getModelApiClient, resetAndReload } from '../../api/baseModelApi.js';
|
||||
import { showToast } from '../../utils/uiHelpers.js';
|
||||
import { showExcludeModal } from '../../utils/modalUtils.js';
|
||||
import { showDeleteModal, showExcludeModal } from '../../utils/modalUtils.js';
|
||||
|
||||
export class CheckpointContextMenu extends BaseContextMenu {
|
||||
constructor() {
|
||||
super('checkpointContextMenu', '.lora-card');
|
||||
super('checkpointContextMenu', '.model-card');
|
||||
this.nsfwSelector = document.getElementById('nsfwLevelSelector');
|
||||
this.modelType = 'checkpoint';
|
||||
this.resetAndReload = resetAndReload;
|
||||
@@ -19,7 +19,7 @@ export class CheckpointContextMenu extends BaseContextMenu {
|
||||
|
||||
// Implementation needed by the mixin
|
||||
async saveModelMetadata(filePath, data) {
|
||||
return saveModelMetadata(filePath, data);
|
||||
return getModelApiClient().saveModelMetadata(filePath, data);
|
||||
}
|
||||
|
||||
handleMenuAction(action) {
|
||||
@@ -28,6 +28,8 @@ export class CheckpointContextMenu extends BaseContextMenu {
|
||||
return;
|
||||
}
|
||||
|
||||
const apiClient = getModelApiClient();
|
||||
|
||||
// Otherwise handle checkpoint-specific actions
|
||||
switch(action) {
|
||||
case 'details':
|
||||
@@ -36,13 +38,10 @@ export class CheckpointContextMenu extends BaseContextMenu {
|
||||
break;
|
||||
case 'replace-preview':
|
||||
// Add new action for replacing preview images
|
||||
replaceCheckpointPreview(this.currentCard.dataset.filepath);
|
||||
apiClient.replaceModelPreview(this.currentCard.dataset.filepath);
|
||||
break;
|
||||
case 'delete':
|
||||
// Delete checkpoint
|
||||
if (this.currentCard.querySelector('.fa-trash')) {
|
||||
this.currentCard.querySelector('.fa-trash').click();
|
||||
}
|
||||
showDeleteModal(this.currentCard.dataset.filepath);
|
||||
break;
|
||||
case 'copyname':
|
||||
// Copy checkpoint name
|
||||
@@ -52,14 +51,14 @@ export class CheckpointContextMenu extends BaseContextMenu {
|
||||
break;
|
||||
case 'refresh-metadata':
|
||||
// Refresh metadata from CivitAI
|
||||
refreshSingleCheckpointMetadata(this.currentCard.dataset.filepath);
|
||||
apiClient.refreshSingleModelMetadata(this.currentCard.dataset.filepath);
|
||||
break;
|
||||
case 'move':
|
||||
// Move to folder (placeholder)
|
||||
showToast('Move to folder feature coming soon', 'info');
|
||||
break;
|
||||
case 'exclude':
|
||||
showExcludeModal(this.currentCard.dataset.filepath, 'checkpoint');
|
||||
showExcludeModal(this.currentCard.dataset.filepath);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
68
static/js/components/ContextMenu/EmbeddingContextMenu.js
Normal file
68
static/js/components/ContextMenu/EmbeddingContextMenu.js
Normal file
@@ -0,0 +1,68 @@
|
||||
import { BaseContextMenu } from './BaseContextMenu.js';
|
||||
import { ModelContextMenuMixin } from './ModelContextMenuMixin.js';
|
||||
import { getModelApiClient, resetAndReload } from '../../api/baseModelApi.js';
|
||||
import { showToast } from '../../utils/uiHelpers.js';
|
||||
import { showDeleteModal, showExcludeModal } from '../../utils/modalUtils.js';
|
||||
|
||||
export class EmbeddingContextMenu extends BaseContextMenu {
|
||||
constructor() {
|
||||
super('embeddingContextMenu', '.model-card');
|
||||
this.nsfwSelector = document.getElementById('nsfwLevelSelector');
|
||||
this.modelType = 'embedding';
|
||||
this.resetAndReload = resetAndReload;
|
||||
|
||||
// Initialize NSFW Level Selector events
|
||||
if (this.nsfwSelector) {
|
||||
this.initNSFWSelector();
|
||||
}
|
||||
}
|
||||
|
||||
// Implementation needed by the mixin
|
||||
async saveModelMetadata(filePath, data) {
|
||||
return getModelApiClient().saveModelMetadata(filePath, data);
|
||||
}
|
||||
|
||||
handleMenuAction(action) {
|
||||
// First try to handle with common actions
|
||||
if (ModelContextMenuMixin.handleCommonMenuActions.call(this, action)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const apiClient = getModelApiClient();
|
||||
|
||||
// Otherwise handle embedding-specific actions
|
||||
switch(action) {
|
||||
case 'details':
|
||||
// Show embedding details
|
||||
this.currentCard.click();
|
||||
break;
|
||||
case 'replace-preview':
|
||||
// Add new action for replacing preview images
|
||||
apiClient.replaceModelPreview(this.currentCard.dataset.filepath);
|
||||
break;
|
||||
case 'delete':
|
||||
showDeleteModal(this.currentCard.dataset.filepath);
|
||||
break;
|
||||
case 'copyname':
|
||||
// Copy embedding name
|
||||
if (this.currentCard.querySelector('.fa-copy')) {
|
||||
this.currentCard.querySelector('.fa-copy').click();
|
||||
}
|
||||
break;
|
||||
case 'refresh-metadata':
|
||||
// Refresh metadata from CivitAI
|
||||
apiClient.refreshSingleModelMetadata(this.currentCard.dataset.filepath);
|
||||
break;
|
||||
case 'move':
|
||||
// Move to folder (placeholder)
|
||||
showToast('Move to folder feature coming soon', 'info');
|
||||
break;
|
||||
case 'exclude':
|
||||
showExcludeModal(this.currentCard.dataset.filepath);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mix in shared methods
|
||||
Object.assign(EmbeddingContextMenu.prototype, ModelContextMenuMixin);
|
||||
@@ -1,12 +1,12 @@
|
||||
import { BaseContextMenu } from './BaseContextMenu.js';
|
||||
import { ModelContextMenuMixin } from './ModelContextMenuMixin.js';
|
||||
import { refreshSingleLoraMetadata, saveModelMetadata, replacePreview, resetAndReload } from '../../api/loraApi.js';
|
||||
import { getModelApiClient, resetAndReload } from '../../api/baseModelApi.js';
|
||||
import { copyToClipboard, sendLoraToWorkflow } from '../../utils/uiHelpers.js';
|
||||
import { showExcludeModal, showDeleteModal } from '../../utils/modalUtils.js';
|
||||
|
||||
export class LoraContextMenu extends BaseContextMenu {
|
||||
constructor() {
|
||||
super('loraContextMenu', '.lora-card');
|
||||
super('loraContextMenu', '.model-card');
|
||||
this.nsfwSelector = document.getElementById('nsfwLevelSelector');
|
||||
this.modelType = 'lora';
|
||||
this.resetAndReload = resetAndReload;
|
||||
@@ -19,7 +19,7 @@ export class LoraContextMenu extends BaseContextMenu {
|
||||
|
||||
// Use the saveModelMetadata implementation from loraApi
|
||||
async saveModelMetadata(filePath, data) {
|
||||
return saveModelMetadata(filePath, data);
|
||||
return getModelApiClient().saveModelMetadata(filePath, data);
|
||||
}
|
||||
|
||||
handleMenuAction(action, menuItem) {
|
||||
@@ -48,7 +48,7 @@ export class LoraContextMenu extends BaseContextMenu {
|
||||
break;
|
||||
case 'replace-preview':
|
||||
// Add a new action for replacing preview images
|
||||
replacePreview(this.currentCard.dataset.filepath);
|
||||
getModelApiClient().replaceModelPreview(this.currentCard.dataset.filepath);
|
||||
break;
|
||||
case 'delete':
|
||||
// Call showDeleteModal directly instead of clicking the trash button
|
||||
@@ -58,7 +58,7 @@ export class LoraContextMenu extends BaseContextMenu {
|
||||
moveManager.showMoveModal(this.currentCard.dataset.filepath);
|
||||
break;
|
||||
case 'refresh-metadata':
|
||||
refreshSingleLoraMetadata(this.currentCard.dataset.filepath);
|
||||
getModelApiClient().refreshSingleModelMetadata(this.currentCard.dataset.filepath);
|
||||
break;
|
||||
case 'exclude':
|
||||
showExcludeModal(this.currentCard.dataset.filepath);
|
||||
|
||||
@@ -125,7 +125,7 @@ export const ModelContextMenuMixin = {
|
||||
|
||||
const endpoint = this.modelType === 'checkpoint' ?
|
||||
'/api/checkpoints/relink-civitai' :
|
||||
'/api/relink-civitai';
|
||||
'/api/loras/relink-civitai';
|
||||
|
||||
const response = await fetch(endpoint, {
|
||||
method: 'POST',
|
||||
|
||||
@@ -7,7 +7,7 @@ import { state } from '../../state/index.js';
|
||||
|
||||
export class RecipeContextMenu extends BaseContextMenu {
|
||||
constructor() {
|
||||
super('recipeContextMenu', '.lora-card');
|
||||
super('recipeContextMenu', '.model-card');
|
||||
this.nsfwSelector = document.getElementById('nsfwLevelSelector');
|
||||
this.modelType = 'recipe';
|
||||
|
||||
@@ -209,9 +209,9 @@ export class RecipeContextMenu extends BaseContextMenu {
|
||||
|
||||
// Determine which endpoint to use based on available data
|
||||
if (lora.modelVersionId) {
|
||||
endpoint = `/api/civitai/model/version/${lora.modelVersionId}`;
|
||||
endpoint = `/api/loras/civitai/model/version/${lora.modelVersionId}`;
|
||||
} else if (lora.hash) {
|
||||
endpoint = `/api/civitai/model/hash/${lora.hash}`;
|
||||
endpoint = `/api/loras/civitai/model/hash/${lora.hash}`;
|
||||
} else {
|
||||
console.error("Missing both hash and modelVersionId for lora:", lora);
|
||||
return null;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
export { LoraContextMenu } from './LoraContextMenu.js';
|
||||
export { RecipeContextMenu } from './RecipeContextMenu.js';
|
||||
export { CheckpointContextMenu } from './CheckpointContextMenu.js';
|
||||
export { EmbeddingContextMenu } from './EmbeddingContextMenu.js';
|
||||
export { ModelContextMenuMixin } from './ModelContextMenuMixin.js';
|
||||
@@ -243,7 +243,7 @@ export class DuplicatesManager {
|
||||
checkboxes.forEach(checkbox => {
|
||||
checkbox.checked = !allSelected;
|
||||
const recipeId = checkbox.dataset.recipeId;
|
||||
const card = checkbox.closest('.lora-card');
|
||||
const card = checkbox.closest('.model-card');
|
||||
|
||||
if (!allSelected) {
|
||||
this.selectedForDeletion.add(recipeId);
|
||||
@@ -268,7 +268,7 @@ export class DuplicatesManager {
|
||||
checkboxes.forEach(checkbox => {
|
||||
checkbox.checked = true;
|
||||
this.selectedForDeletion.add(checkbox.dataset.recipeId);
|
||||
checkbox.closest('.lora-card').classList.add('duplicate-selected');
|
||||
checkbox.closest('.model-card').classList.add('duplicate-selected');
|
||||
});
|
||||
|
||||
// Update the button text
|
||||
@@ -299,7 +299,7 @@ export class DuplicatesManager {
|
||||
if (checkbox) {
|
||||
checkbox.checked = true;
|
||||
this.selectedForDeletion.add(recipeId);
|
||||
checkbox.closest('.lora-card').classList.add('duplicate-selected');
|
||||
checkbox.closest('.model-card').classList.add('duplicate-selected');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,7 +310,7 @@ export class DuplicatesManager {
|
||||
if (latestCheckbox) {
|
||||
latestCheckbox.checked = false;
|
||||
this.selectedForDeletion.delete(latestId);
|
||||
latestCheckbox.closest('.lora-card').classList.remove('duplicate-selected');
|
||||
latestCheckbox.closest('.model-card').classList.remove('duplicate-selected');
|
||||
}
|
||||
|
||||
this.updateSelectedCount();
|
||||
|
||||
@@ -26,6 +26,7 @@ export class HeaderManager {
|
||||
const path = window.location.pathname;
|
||||
if (path.includes('/loras/recipes')) return 'recipes';
|
||||
if (path.includes('/checkpoints')) return 'checkpoints';
|
||||
if (path.includes('/embeddings')) return 'embeddings';
|
||||
if (path.includes('/statistics')) return 'statistics';
|
||||
if (path.includes('/loras')) return 'loras';
|
||||
return 'unknown';
|
||||
|
||||
@@ -2,8 +2,7 @@
|
||||
import { showToast } from '../utils/uiHelpers.js';
|
||||
import { state, getCurrentPageState } from '../state/index.js';
|
||||
import { formatDate } from '../utils/formatters.js';
|
||||
import { resetAndReload as resetAndReloadLoras } from '../api/loraApi.js';
|
||||
import { resetAndReload as resetAndReloadCheckpoints } from '../api/checkpointApi.js';
|
||||
import { resetAndReload} from '../api/baseModelApi.js';
|
||||
import { LoadingManager } from '../managers/LoadingManager.js';
|
||||
|
||||
export class ModelDuplicatesManager {
|
||||
@@ -184,7 +183,7 @@ export class ModelDuplicatesManager {
|
||||
document.body.classList.remove('duplicate-mode');
|
||||
|
||||
// Clear the model grid first
|
||||
const modelGrid = document.getElementById(this.modelType === 'loras' ? 'loraGrid' : 'checkpointGrid');
|
||||
const modelGrid = document.getElementById('modelGrid');
|
||||
if (modelGrid) {
|
||||
modelGrid.innerHTML = '';
|
||||
}
|
||||
@@ -241,7 +240,7 @@ export class ModelDuplicatesManager {
|
||||
}
|
||||
|
||||
renderDuplicateGroups() {
|
||||
const modelGrid = document.getElementById(this.modelType === 'loras' ? 'loraGrid' : 'checkpointGrid');
|
||||
const modelGrid = document.getElementById('modelGrid');
|
||||
if (!modelGrid) return;
|
||||
|
||||
// Clear existing content
|
||||
@@ -331,7 +330,7 @@ export class ModelDuplicatesManager {
|
||||
renderModelCard(model, groupHash) {
|
||||
// Create basic card structure
|
||||
const card = document.createElement('div');
|
||||
card.className = 'lora-card duplicate';
|
||||
card.className = 'model-card duplicate';
|
||||
card.dataset.hash = model.sha256;
|
||||
card.dataset.filePath = model.file_path;
|
||||
|
||||
@@ -550,7 +549,7 @@ export class ModelDuplicatesManager {
|
||||
checkboxes.forEach(checkbox => {
|
||||
checkbox.checked = !allSelected;
|
||||
const filePath = checkbox.dataset.filePath;
|
||||
const card = checkbox.closest('.lora-card');
|
||||
const card = checkbox.closest('.model-card');
|
||||
|
||||
if (!allSelected) {
|
||||
this.selectedForDeletion.add(filePath);
|
||||
@@ -622,12 +621,7 @@ export class ModelDuplicatesManager {
|
||||
|
||||
// If models were successfully deleted
|
||||
if (data.total_deleted > 0) {
|
||||
// Reload model data with updated folders
|
||||
if (this.modelType === 'loras') {
|
||||
await resetAndReloadLoras(true);
|
||||
} else {
|
||||
await resetAndReloadCheckpoints(true);
|
||||
}
|
||||
await resetAndReload(true);
|
||||
|
||||
// Check if there are still duplicates
|
||||
try {
|
||||
|
||||
@@ -17,7 +17,7 @@ class RecipeCard {
|
||||
|
||||
createCardElement() {
|
||||
const card = document.createElement('div');
|
||||
card.className = 'lora-card';
|
||||
card.className = 'model-card';
|
||||
card.dataset.filepath = this.recipe.file_path;
|
||||
card.dataset.title = this.recipe.title;
|
||||
card.dataset.nsfwLevel = this.recipe.preview_nsfw_level || 0;
|
||||
|
||||
@@ -831,9 +831,9 @@ class RecipeModal {
|
||||
|
||||
// Determine which endpoint to use based on available data
|
||||
if (lora.modelVersionId) {
|
||||
endpoint = `/api/civitai/model/version/${lora.modelVersionId}`;
|
||||
endpoint = `/api/loras/civitai/model/version/${lora.modelVersionId}`;
|
||||
} else if (lora.hash) {
|
||||
endpoint = `/api/civitai/model/hash/${lora.hash}`;
|
||||
endpoint = `/api/loras/civitai/model/hash/${lora.hash}`;
|
||||
} else {
|
||||
console.error("Missing both hash and modelVersionId for lora:", lora);
|
||||
return null;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// AlphabetBar.js - Component for alphabet filtering
|
||||
import { getCurrentPageState, setCurrentPageType } from '../../state/index.js';
|
||||
import { getCurrentPageState } from '../../state/index.js';
|
||||
import { getStorageItem, setStorageItem } from '../../utils/storageHelpers.js';
|
||||
import { resetAndReload } from '../../api/loraApi.js';
|
||||
import { resetAndReload } from '../../api/baseModelApi.js';
|
||||
|
||||
/**
|
||||
* AlphabetBar class - Handles the alphabet filtering UI and interactions
|
||||
@@ -227,7 +227,7 @@ export class AlphabetBar {
|
||||
this.updateToggleIndicator();
|
||||
|
||||
// Trigger a reload with the new filter
|
||||
resetAndReload(true);
|
||||
resetAndReload(false);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
/**
|
||||
* ModelDescription.js
|
||||
* Handles checkpoint model descriptions
|
||||
*/
|
||||
import { showToast } from '../../utils/uiHelpers.js';
|
||||
|
||||
/**
|
||||
* Set up tab switching functionality
|
||||
*/
|
||||
export function setupTabSwitching() {
|
||||
const tabButtons = document.querySelectorAll('.showcase-tabs .tab-btn');
|
||||
|
||||
tabButtons.forEach(button => {
|
||||
button.addEventListener('click', () => {
|
||||
// Remove active class from all tabs
|
||||
document.querySelectorAll('.showcase-tabs .tab-btn').forEach(btn =>
|
||||
btn.classList.remove('active')
|
||||
);
|
||||
document.querySelectorAll('.tab-content .tab-pane').forEach(tab =>
|
||||
tab.classList.remove('active')
|
||||
);
|
||||
|
||||
// Add active class to clicked tab
|
||||
button.classList.add('active');
|
||||
const tabId = `${button.dataset.tab}-tab`;
|
||||
document.getElementById(tabId).classList.add('active');
|
||||
|
||||
// If switching to description tab, make sure content is properly loaded and displayed
|
||||
if (button.dataset.tab === 'description') {
|
||||
const descriptionContent = document.querySelector('.model-description-content');
|
||||
if (descriptionContent) {
|
||||
const hasContent = descriptionContent.innerHTML.trim() !== '';
|
||||
document.querySelector('.model-description-loading')?.classList.add('hidden');
|
||||
|
||||
// If no content, show a message
|
||||
if (!hasContent) {
|
||||
descriptionContent.innerHTML = '<div class="no-description">No model description available</div>';
|
||||
descriptionContent.classList.remove('hidden');
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Load model description from API
|
||||
* @param {string} modelId - The Civitai model ID
|
||||
* @param {string} filePath - File path for the model
|
||||
*/
|
||||
export async function loadModelDescription(modelId, filePath) {
|
||||
try {
|
||||
const descriptionContainer = document.querySelector('.model-description-content');
|
||||
const loadingElement = document.querySelector('.model-description-loading');
|
||||
|
||||
if (!descriptionContainer || !loadingElement) return;
|
||||
|
||||
// Show loading indicator
|
||||
loadingElement.classList.remove('hidden');
|
||||
descriptionContainer.classList.add('hidden');
|
||||
|
||||
// Try to get model description from API
|
||||
const response = await fetch(`/api/checkpoint-model-description?model_id=${modelId}&file_path=${encodeURIComponent(filePath)}`);
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch model description: ${response.statusText}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success && data.description) {
|
||||
// Update the description content
|
||||
descriptionContainer.innerHTML = data.description;
|
||||
|
||||
// Process any links in the description to open in new tab
|
||||
const links = descriptionContainer.querySelectorAll('a');
|
||||
links.forEach(link => {
|
||||
link.setAttribute('target', '_blank');
|
||||
link.setAttribute('rel', 'noopener noreferrer');
|
||||
});
|
||||
|
||||
// Show the description and hide loading indicator
|
||||
descriptionContainer.classList.remove('hidden');
|
||||
loadingElement.classList.add('hidden');
|
||||
} else {
|
||||
throw new Error(data.error || 'No description available');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error loading model description:', error);
|
||||
const loadingElement = document.querySelector('.model-description-loading');
|
||||
if (loadingElement) {
|
||||
loadingElement.innerHTML = `<div class="error-message">Failed to load model description. ${error.message}</div>`;
|
||||
}
|
||||
|
||||
// Show empty state message in the description container
|
||||
const descriptionContainer = document.querySelector('.model-description-content');
|
||||
if (descriptionContainer) {
|
||||
descriptionContainer.innerHTML = '<div class="no-description">No model description available</div>';
|
||||
descriptionContainer.classList.remove('hidden');
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,471 +0,0 @@
|
||||
/**
|
||||
* ModelTags.js
|
||||
* Module for handling checkpoint model tag editing functionality
|
||||
*/
|
||||
import { showToast } from '../../utils/uiHelpers.js';
|
||||
import { saveModelMetadata } from '../../api/checkpointApi.js';
|
||||
|
||||
// Preset tag suggestions
|
||||
const PRESET_TAGS = [
|
||||
'character', 'style', 'concept', 'clothing', 'base model',
|
||||
'poses', 'background', 'vehicle', 'buildings',
|
||||
'objects', 'animal'
|
||||
];
|
||||
|
||||
// Create a named function so we can remove it later
|
||||
let saveTagsHandler = null;
|
||||
|
||||
/**
|
||||
* Set up tag editing mode
|
||||
*/
|
||||
export function setupTagEditMode() {
|
||||
const editBtn = document.querySelector('.edit-tags-btn');
|
||||
if (!editBtn) return;
|
||||
|
||||
// Store original tags for restoring on cancel
|
||||
let originalTags = [];
|
||||
|
||||
// Remove any previously attached click handler
|
||||
if (editBtn._hasClickHandler) {
|
||||
editBtn.removeEventListener('click', editBtn._clickHandler);
|
||||
}
|
||||
|
||||
// Create new handler and store reference
|
||||
const editBtnClickHandler = function() {
|
||||
const tagsSection = document.querySelector('.model-tags-container');
|
||||
const isEditMode = tagsSection.classList.toggle('edit-mode');
|
||||
const filePath = this.dataset.filePath;
|
||||
|
||||
// Toggle edit mode UI elements
|
||||
const compactTagsDisplay = tagsSection.querySelector('.model-tags-compact');
|
||||
const tagsEditContainer = tagsSection.querySelector('.metadata-edit-container');
|
||||
|
||||
if (isEditMode) {
|
||||
// Enter edit mode
|
||||
this.innerHTML = '<i class="fas fa-times"></i>'; // Change to cancel icon
|
||||
this.title = "Cancel editing";
|
||||
|
||||
// Get all tags from tooltip, not just the visible ones in compact display
|
||||
originalTags = Array.from(
|
||||
tagsSection.querySelectorAll('.tooltip-tag')
|
||||
).map(tag => tag.textContent);
|
||||
|
||||
// Hide compact display, show edit container
|
||||
compactTagsDisplay.style.display = 'none';
|
||||
|
||||
// If edit container doesn't exist yet, create it
|
||||
if (!tagsEditContainer) {
|
||||
const editContainer = document.createElement('div');
|
||||
editContainer.className = 'metadata-edit-container';
|
||||
|
||||
// Move the edit button inside the container header for better visibility
|
||||
const editBtnClone = editBtn.cloneNode(true);
|
||||
editBtnClone.classList.add('metadata-header-btn');
|
||||
|
||||
// Create edit UI with edit button in the header
|
||||
editContainer.innerHTML = createTagEditUI(originalTags, editBtnClone.outerHTML);
|
||||
tagsSection.appendChild(editContainer);
|
||||
|
||||
// Setup the tag input field behavior
|
||||
setupTagInput();
|
||||
|
||||
// Create and add preset suggestions dropdown
|
||||
const tagForm = editContainer.querySelector('.metadata-add-form');
|
||||
const suggestionsDropdown = createSuggestionsDropdown(originalTags);
|
||||
tagForm.appendChild(suggestionsDropdown);
|
||||
|
||||
// Setup delete buttons for existing tags
|
||||
setupDeleteButtons();
|
||||
|
||||
// Transfer click event from original button to the cloned one
|
||||
const newEditBtn = editContainer.querySelector('.metadata-header-btn');
|
||||
if (newEditBtn) {
|
||||
newEditBtn.addEventListener('click', function() {
|
||||
editBtn.click();
|
||||
});
|
||||
}
|
||||
|
||||
// Hide the original button when in edit mode
|
||||
editBtn.style.display = 'none';
|
||||
} else {
|
||||
// Just show the existing edit container
|
||||
tagsEditContainer.style.display = 'block';
|
||||
editBtn.style.display = 'none';
|
||||
}
|
||||
} else {
|
||||
// Exit edit mode
|
||||
this.innerHTML = '<i class="fas fa-pencil-alt"></i>'; // Change back to edit icon
|
||||
this.title = "Edit tags";
|
||||
editBtn.style.display = 'block';
|
||||
|
||||
// Show compact display, hide edit container
|
||||
compactTagsDisplay.style.display = 'flex';
|
||||
if (tagsEditContainer) tagsEditContainer.style.display = 'none';
|
||||
|
||||
// Check if we're exiting edit mode due to "Save" or "Cancel"
|
||||
if (!this.dataset.skipRestore) {
|
||||
// If canceling, restore original tags
|
||||
restoreOriginalTags(tagsSection, originalTags);
|
||||
} else {
|
||||
// Reset the skip restore flag
|
||||
delete this.dataset.skipRestore;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Store the handler reference on the button itself
|
||||
editBtn._clickHandler = editBtnClickHandler;
|
||||
editBtn._hasClickHandler = true;
|
||||
editBtn.addEventListener('click', editBtnClickHandler);
|
||||
|
||||
// Clean up any previous document click handler
|
||||
if (saveTagsHandler) {
|
||||
document.removeEventListener('click', saveTagsHandler);
|
||||
}
|
||||
|
||||
// Create new save handler and store reference
|
||||
saveTagsHandler = function(e) {
|
||||
if (e.target.classList.contains('save-tags-btn') ||
|
||||
e.target.closest('.save-tags-btn')) {
|
||||
saveTags();
|
||||
}
|
||||
};
|
||||
|
||||
// Add the new handler
|
||||
document.addEventListener('click', saveTagsHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the tag editing UI
|
||||
* @param {Array} currentTags - Current tags
|
||||
* @param {string} editBtnHTML - HTML for the edit button to include in header
|
||||
* @returns {string} HTML markup for tag editing UI
|
||||
*/
|
||||
function createTagEditUI(currentTags, editBtnHTML = '') {
|
||||
return `
|
||||
<div class="metadata-edit-content">
|
||||
<div class="metadata-edit-header">
|
||||
<label>Edit Tags</label>
|
||||
${editBtnHTML}
|
||||
</div>
|
||||
<div class="metadata-items">
|
||||
${currentTags.map(tag => `
|
||||
<div class="metadata-item" data-tag="${tag}">
|
||||
<span class="metadata-item-content">${tag}</span>
|
||||
<button class="metadata-delete-btn">
|
||||
<i class="fas fa-times"></i>
|
||||
</button>
|
||||
</div>
|
||||
`).join('')}
|
||||
</div>
|
||||
<div class="metadata-edit-controls">
|
||||
<button class="save-tags-btn" title="Save changes">
|
||||
<i class="fas fa-save"></i> Save
|
||||
</button>
|
||||
</div>
|
||||
<div class="metadata-add-form">
|
||||
<input type="text" class="metadata-input" placeholder="Type to add or click suggestions below">
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create suggestions dropdown with preset tags
|
||||
* @param {Array} existingTags - Already added tags
|
||||
* @returns {HTMLElement} - Dropdown element
|
||||
*/
|
||||
function createSuggestionsDropdown(existingTags = []) {
|
||||
const dropdown = document.createElement('div');
|
||||
dropdown.className = 'metadata-suggestions-dropdown';
|
||||
|
||||
// Create header
|
||||
const header = document.createElement('div');
|
||||
header.className = 'metadata-suggestions-header';
|
||||
header.innerHTML = `
|
||||
<span>Suggested Tags</span>
|
||||
<small>Click to add</small>
|
||||
`;
|
||||
dropdown.appendChild(header);
|
||||
|
||||
// Create tag container
|
||||
const container = document.createElement('div');
|
||||
container.className = 'metadata-suggestions-container';
|
||||
|
||||
// Add each preset tag as a suggestion
|
||||
PRESET_TAGS.forEach(tag => {
|
||||
const isAdded = existingTags.includes(tag);
|
||||
|
||||
const item = document.createElement('div');
|
||||
item.className = `metadata-suggestion-item ${isAdded ? 'already-added' : ''}`;
|
||||
item.title = tag;
|
||||
item.innerHTML = `
|
||||
<span class="metadata-suggestion-text">${tag}</span>
|
||||
${isAdded ? '<span class="added-indicator"><i class="fas fa-check"></i></span>' : ''}
|
||||
`;
|
||||
|
||||
if (!isAdded) {
|
||||
item.addEventListener('click', () => {
|
||||
addNewTag(tag);
|
||||
|
||||
// Also populate the input field for potential editing
|
||||
const input = document.querySelector('.metadata-input');
|
||||
if (input) input.value = tag;
|
||||
|
||||
// Focus on the input
|
||||
if (input) input.focus();
|
||||
|
||||
// Update dropdown without removing it
|
||||
updateSuggestionsDropdown();
|
||||
});
|
||||
}
|
||||
|
||||
container.appendChild(item);
|
||||
});
|
||||
|
||||
dropdown.appendChild(container);
|
||||
return dropdown;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set up tag input behavior
|
||||
*/
|
||||
function setupTagInput() {
|
||||
const tagInput = document.querySelector('.metadata-input');
|
||||
|
||||
if (tagInput) {
|
||||
tagInput.addEventListener('keydown', function(e) {
|
||||
if (e.key === 'Enter') {
|
||||
e.preventDefault();
|
||||
addNewTag(this.value);
|
||||
this.value = ''; // Clear input after adding
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Set up delete buttons for tags
|
||||
*/
|
||||
function setupDeleteButtons() {
|
||||
document.querySelectorAll('.metadata-delete-btn').forEach(btn => {
|
||||
btn.addEventListener('click', function(e) {
|
||||
e.stopPropagation();
|
||||
const tag = this.closest('.metadata-item');
|
||||
tag.remove();
|
||||
|
||||
// Update status of items in the suggestion dropdown
|
||||
updateSuggestionsDropdown();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a new tag
|
||||
* @param {string} tag - Tag to add
|
||||
*/
|
||||
function addNewTag(tag) {
|
||||
tag = tag.trim().toLowerCase();
|
||||
if (!tag) return;
|
||||
|
||||
const tagsContainer = document.querySelector('.metadata-items');
|
||||
if (!tagsContainer) return;
|
||||
|
||||
// Validation: Check length
|
||||
if (tag.length > 30) {
|
||||
showToast('Tag should not exceed 30 characters', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
// Validation: Check total number
|
||||
const currentTags = tagsContainer.querySelectorAll('.metadata-item');
|
||||
if (currentTags.length >= 30) {
|
||||
showToast('Maximum 30 tags allowed', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
// Validation: Check for duplicates
|
||||
const existingTags = Array.from(currentTags).map(tag => tag.dataset.tag);
|
||||
if (existingTags.includes(tag)) {
|
||||
showToast('This tag already exists', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
// Create new tag
|
||||
const newTag = document.createElement('div');
|
||||
newTag.className = 'metadata-item';
|
||||
newTag.dataset.tag = tag;
|
||||
newTag.innerHTML = `
|
||||
<span class="metadata-item-content">${tag}</span>
|
||||
<button class="metadata-delete-btn">
|
||||
<i class="fas fa-times"></i>
|
||||
</button>
|
||||
`;
|
||||
|
||||
// Add event listener to delete button
|
||||
const deleteBtn = newTag.querySelector('.metadata-delete-btn');
|
||||
deleteBtn.addEventListener('click', function(e) {
|
||||
e.stopPropagation();
|
||||
newTag.remove();
|
||||
|
||||
// Update status of items in the suggestion dropdown
|
||||
updateSuggestionsDropdown();
|
||||
});
|
||||
|
||||
tagsContainer.appendChild(newTag);
|
||||
|
||||
// Update status of items in the suggestions dropdown
|
||||
updateSuggestionsDropdown();
|
||||
}
|
||||
|
||||
/**
|
||||
* Update status of items in the suggestions dropdown
|
||||
*/
|
||||
function updateSuggestionsDropdown() {
|
||||
const dropdown = document.querySelector('.metadata-suggestions-dropdown');
|
||||
if (!dropdown) return;
|
||||
|
||||
// Get all current tags
|
||||
const currentTags = document.querySelectorAll('.metadata-item');
|
||||
const existingTags = Array.from(currentTags).map(tag => tag.dataset.tag);
|
||||
|
||||
// Update status of each item in dropdown
|
||||
dropdown.querySelectorAll('.metadata-suggestion-item').forEach(item => {
|
||||
const tagText = item.querySelector('.metadata-suggestion-text').textContent;
|
||||
const isAdded = existingTags.includes(tagText);
|
||||
|
||||
if (isAdded) {
|
||||
item.classList.add('already-added');
|
||||
|
||||
// Add indicator if it doesn't exist
|
||||
let indicator = item.querySelector('.added-indicator');
|
||||
if (!indicator) {
|
||||
indicator = document.createElement('span');
|
||||
indicator.className = 'added-indicator';
|
||||
indicator.innerHTML = '<i class="fas fa-check"></i>';
|
||||
item.appendChild(indicator);
|
||||
}
|
||||
|
||||
// Remove click event
|
||||
item.onclick = null;
|
||||
} else {
|
||||
// Re-enable items that are no longer in the list
|
||||
item.classList.remove('already-added');
|
||||
|
||||
// Remove indicator if it exists
|
||||
const indicator = item.querySelector('.added-indicator');
|
||||
if (indicator) indicator.remove();
|
||||
|
||||
// Restore click event if not already set
|
||||
if (!item.onclick) {
|
||||
item.onclick = () => {
|
||||
const tag = item.querySelector('.metadata-suggestion-text').textContent;
|
||||
addNewTag(tag);
|
||||
|
||||
// Also populate the input field
|
||||
const input = document.querySelector('.metadata-input');
|
||||
if (input) input.value = tag;
|
||||
|
||||
// Focus the input
|
||||
if (input) input.focus();
|
||||
};
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Restore original tags when canceling edit
|
||||
* @param {HTMLElement} section - The tags section
|
||||
* @param {Array} originalTags - Original tags array
|
||||
*/
|
||||
function restoreOriginalTags(section, originalTags) {
|
||||
// Nothing to do here as we're just hiding the edit UI
|
||||
// and showing the original compact tags which weren't modified
|
||||
}
|
||||
|
||||
/**
|
||||
* Save tags
|
||||
*/
|
||||
async function saveTags() {
|
||||
const editBtn = document.querySelector('.edit-tags-btn');
|
||||
if (!editBtn) return;
|
||||
|
||||
const filePath = editBtn.dataset.filePath;
|
||||
const tagElements = document.querySelectorAll('.metadata-item');
|
||||
const tags = Array.from(tagElements).map(tag => tag.dataset.tag);
|
||||
|
||||
// Get original tags to compare
|
||||
const originalTagElements = document.querySelectorAll('.tooltip-tag');
|
||||
const originalTags = Array.from(originalTagElements).map(tag => tag.textContent);
|
||||
|
||||
// Check if tags have actually changed
|
||||
const tagsChanged = JSON.stringify(tags) !== JSON.stringify(originalTags);
|
||||
|
||||
if (!tagsChanged) {
|
||||
// No changes made, just exit edit mode without API call
|
||||
editBtn.dataset.skipRestore = "true";
|
||||
editBtn.click();
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// Save tags metadata
|
||||
await saveModelMetadata(filePath, { tags: tags });
|
||||
|
||||
// Set flag to skip restoring original tags when exiting edit mode
|
||||
editBtn.dataset.skipRestore = "true";
|
||||
|
||||
// Update the compact tags display
|
||||
const compactTagsContainer = document.querySelector('.model-tags-container');
|
||||
if (compactTagsContainer) {
|
||||
// Generate new compact tags HTML
|
||||
const compactTagsDisplay = compactTagsContainer.querySelector('.model-tags-compact');
|
||||
|
||||
if (compactTagsDisplay) {
|
||||
// Clear current tags
|
||||
compactTagsDisplay.innerHTML = '';
|
||||
|
||||
// Add visible tags (up to 5)
|
||||
const visibleTags = tags.slice(0, 5);
|
||||
visibleTags.forEach(tag => {
|
||||
const span = document.createElement('span');
|
||||
span.className = 'model-tag-compact';
|
||||
span.textContent = tag;
|
||||
compactTagsDisplay.appendChild(span);
|
||||
});
|
||||
|
||||
// Add more indicator if needed
|
||||
const remainingCount = Math.max(0, tags.length - 5);
|
||||
if (remainingCount > 0) {
|
||||
const more = document.createElement('span');
|
||||
more.className = 'model-tag-more';
|
||||
more.dataset.count = remainingCount;
|
||||
more.textContent = `+${remainingCount}`;
|
||||
compactTagsDisplay.appendChild(more);
|
||||
}
|
||||
}
|
||||
|
||||
// Update tooltip content
|
||||
const tooltipContent = compactTagsContainer.querySelector('.tooltip-content');
|
||||
if (tooltipContent) {
|
||||
tooltipContent.innerHTML = '';
|
||||
|
||||
tags.forEach(tag => {
|
||||
const span = document.createElement('span');
|
||||
span.className = 'tooltip-tag';
|
||||
span.textContent = tag;
|
||||
tooltipContent.appendChild(span);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Exit edit mode
|
||||
editBtn.click();
|
||||
|
||||
showToast('Tags updated successfully', 'success');
|
||||
} catch (error) {
|
||||
console.error('Error saving tags:', error);
|
||||
showToast('Failed to update tags', 'error');
|
||||
}
|
||||
}
|
||||
@@ -1,239 +0,0 @@
|
||||
/**
|
||||
* CheckpointModal - Main entry point
|
||||
*
|
||||
* Modularized checkpoint modal component that handles checkpoint model details display
|
||||
*/
|
||||
import { showToast } from '../../utils/uiHelpers.js';
|
||||
import { modalManager } from '../../managers/ModalManager.js';
|
||||
import {
|
||||
toggleShowcase,
|
||||
setupShowcaseScroll,
|
||||
scrollToTop,
|
||||
loadExampleImages
|
||||
} from '../shared/showcase/ShowcaseView.js';
|
||||
import { setupTabSwitching, loadModelDescription } from './ModelDescription.js';
|
||||
import {
|
||||
setupModelNameEditing,
|
||||
setupBaseModelEditing,
|
||||
setupFileNameEditing
|
||||
} from './ModelMetadata.js';
|
||||
import { setupTagEditMode } from './ModelTags.js'; // Add import for tag editing
|
||||
import { saveModelMetadata } from '../../api/checkpointApi.js';
|
||||
import { renderCompactTags, setupTagTooltip, formatFileSize } from './utils.js';
|
||||
|
||||
/**
|
||||
* Display the checkpoint modal with the given checkpoint data
|
||||
* @param {Object} checkpoint - Checkpoint data object
|
||||
*/
|
||||
export function showCheckpointModal(checkpoint) {
|
||||
const content = `
|
||||
<div class="modal-content">
|
||||
<button class="close" onclick="modalManager.closeModal('checkpointModal')">×</button>
|
||||
<header class="modal-header">
|
||||
<div class="model-name-header">
|
||||
<h2 class="model-name-content">${checkpoint.model_name || 'Checkpoint Details'}</h2>
|
||||
<button class="edit-model-name-btn" title="Edit model name">
|
||||
<i class="fas fa-pencil-alt"></i>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
${checkpoint.civitai?.creator ? `
|
||||
<div class="creator-info">
|
||||
${checkpoint.civitai.creator.image ?
|
||||
`<div class="creator-avatar">
|
||||
<img src="${checkpoint.civitai.creator.image}" alt="${checkpoint.civitai.creator.username}" onerror="this.onerror=null; this.src='static/icons/user-placeholder.png';">
|
||||
</div>` :
|
||||
`<div class="creator-avatar creator-placeholder">
|
||||
<i class="fas fa-user"></i>
|
||||
</div>`
|
||||
}
|
||||
<span class="creator-username">${checkpoint.civitai.creator.username}</span>
|
||||
</div>` : ''}
|
||||
|
||||
${renderCompactTags(checkpoint.tags || [], checkpoint.file_path)}
|
||||
</header>
|
||||
|
||||
<div class="modal-body">
|
||||
<div class="info-section">
|
||||
<div class="info-grid">
|
||||
<div class="info-item">
|
||||
<label>Version</label>
|
||||
<span>${checkpoint.civitai?.name || 'N/A'}</span>
|
||||
</div>
|
||||
<div class="info-item">
|
||||
<label>File Name</label>
|
||||
<div class="file-name-wrapper">
|
||||
<span id="file-name" class="file-name-content">${checkpoint.file_name || 'N/A'}</span>
|
||||
<button class="edit-file-name-btn" title="Edit file name">
|
||||
<i class="fas fa-pencil-alt"></i>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="info-item location-size">
|
||||
<div class="location-wrapper">
|
||||
<label>Location</label>
|
||||
<span class="file-path">${checkpoint.file_path.replace(/[^/]+$/, '')}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="info-item base-size">
|
||||
<div class="base-wrapper">
|
||||
<label>Base Model</label>
|
||||
<div class="base-model-display">
|
||||
<span class="base-model-content">${checkpoint.base_model || 'Unknown'}</span>
|
||||
<button class="edit-base-model-btn" title="Edit base model">
|
||||
<i class="fas fa-pencil-alt"></i>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="size-wrapper">
|
||||
<label>Size</label>
|
||||
<span>${formatFileSize(checkpoint.file_size)}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="info-item notes">
|
||||
<label>Additional Notes</label>
|
||||
<div class="editable-field">
|
||||
<div class="notes-content" contenteditable="true" spellcheck="false">${checkpoint.notes || 'Add your notes here...'}</div>
|
||||
<button class="save-btn" onclick="saveCheckpointNotes('${checkpoint.file_path}')">
|
||||
<i class="fas fa-save"></i>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="info-item full-width">
|
||||
<label>About this version</label>
|
||||
<div class="description-text">${checkpoint.civitai?.description || 'N/A'}</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="showcase-section" data-model-hash="${checkpoint.sha256 || ''}" data-filepath="${checkpoint.file_path}">
|
||||
<div class="showcase-tabs">
|
||||
<button class="tab-btn active" data-tab="showcase">Examples</button>
|
||||
<button class="tab-btn" data-tab="description">Model Description</button>
|
||||
</div>
|
||||
|
||||
<div class="tab-content">
|
||||
<div id="showcase-tab" class="tab-pane active">
|
||||
<div class="recipes-loading">
|
||||
<i class="fas fa-spinner fa-spin"></i> Loading recipes...
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div id="description-tab" class="tab-pane">
|
||||
<div class="model-description-container">
|
||||
<div class="model-description-loading">
|
||||
<i class="fas fa-spinner fa-spin"></i> Loading model description...
|
||||
</div>
|
||||
<div class="model-description-content">
|
||||
${checkpoint.modelDescription || ''}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button class="back-to-top" onclick="scrollToTopCheckpoint(this)">
|
||||
<i class="fas fa-arrow-up"></i>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
|
||||
modalManager.showModal('checkpointModal', content);
|
||||
setupEditableFields(checkpoint.file_path);
|
||||
setupShowcaseScroll('checkpointModal');
|
||||
setupTabSwitching();
|
||||
setupTagTooltip();
|
||||
setupTagEditMode(); // Initialize tag editing functionality
|
||||
setupModelNameEditing(checkpoint.file_path);
|
||||
setupBaseModelEditing(checkpoint.file_path);
|
||||
setupFileNameEditing(checkpoint.file_path);
|
||||
|
||||
// If we have a model ID but no description, fetch it
|
||||
if (checkpoint.civitai?.modelId && !checkpoint.modelDescription) {
|
||||
loadModelDescription(checkpoint.civitai.modelId, checkpoint.file_path);
|
||||
}
|
||||
|
||||
// Load example images asynchronously - merge regular and custom images
|
||||
const regularImages = checkpoint.civitai?.images || [];
|
||||
const customImages = checkpoint.civitai?.customImages || [];
|
||||
// Combine images - regular images first, then custom images
|
||||
const allImages = [...regularImages, ...customImages];
|
||||
loadExampleImages(allImages, checkpoint.sha256);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set up editable fields in the checkpoint modal
|
||||
* @param {string} filePath - The full file path of the model.
|
||||
*/
|
||||
function setupEditableFields(filePath) {
|
||||
const editableFields = document.querySelectorAll('.editable-field [contenteditable]');
|
||||
|
||||
editableFields.forEach(field => {
|
||||
field.addEventListener('focus', function() {
|
||||
if (this.textContent === 'Add your notes here...') {
|
||||
this.textContent = '';
|
||||
}
|
||||
});
|
||||
|
||||
field.addEventListener('blur', function() {
|
||||
if (this.textContent.trim() === '') {
|
||||
if (this.classList.contains('notes-content')) {
|
||||
this.textContent = 'Add your notes here...';
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Add keydown event listeners for notes
|
||||
const notesContent = document.querySelector('.notes-content');
|
||||
if (notesContent) {
|
||||
notesContent.addEventListener('keydown', async function(e) {
|
||||
if (e.key === 'Enter') {
|
||||
if (e.shiftKey) {
|
||||
// Allow shift+enter for new line
|
||||
return;
|
||||
}
|
||||
e.preventDefault();
|
||||
await saveNotes(filePath);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Save checkpoint notes
|
||||
* @param {string} filePath - Path to the checkpoint file
|
||||
*/
|
||||
async function saveNotes(filePath) {
|
||||
const content = document.querySelector('.notes-content').textContent;
|
||||
try {
|
||||
await saveModelMetadata(filePath, { notes: content });
|
||||
|
||||
showToast('Notes saved successfully', 'success');
|
||||
} catch (error) {
|
||||
showToast('Failed to save notes', 'error');
|
||||
}
|
||||
}
|
||||
|
||||
// Export the checkpoint modal API
|
||||
const checkpointModal = {
|
||||
show: showCheckpointModal,
|
||||
toggleShowcase,
|
||||
scrollToTop
|
||||
};
|
||||
|
||||
export { checkpointModal };
|
||||
|
||||
// Define global functions for use in HTML
|
||||
window.toggleShowcase = function(element) {
|
||||
toggleShowcase(element);
|
||||
};
|
||||
|
||||
window.scrollToTopCheckpoint = function(button) {
|
||||
scrollToTop(button);
|
||||
};
|
||||
|
||||
window.saveCheckpointNotes = function(filePath) {
|
||||
saveNotes(filePath);
|
||||
};
|
||||
@@ -1,8 +1,8 @@
|
||||
// CheckpointsControls.js - Specific implementation for the Checkpoints page
|
||||
import { PageControls } from './PageControls.js';
|
||||
import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints, fetchCivitai } from '../../api/checkpointApi.js';
|
||||
import { getModelApiClient, resetAndReload } from '../../api/baseModelApi.js';
|
||||
import { showToast } from '../../utils/uiHelpers.js';
|
||||
import { CheckpointDownloadManager } from '../../managers/CheckpointDownloadManager.js';
|
||||
import { downloadManager } from '../../managers/DownloadManager.js';
|
||||
|
||||
/**
|
||||
* CheckpointsControls class - Extends PageControls for Checkpoint-specific functionality
|
||||
@@ -12,9 +12,6 @@ export class CheckpointsControls extends PageControls {
|
||||
// Initialize with 'checkpoints' page type
|
||||
super('checkpoints');
|
||||
|
||||
// Initialize checkpoint download manager
|
||||
this.downloadManager = new CheckpointDownloadManager();
|
||||
|
||||
// Register API methods specific to the Checkpoints page
|
||||
this.registerCheckpointsAPI();
|
||||
}
|
||||
@@ -26,7 +23,7 @@ export class CheckpointsControls extends PageControls {
|
||||
const checkpointsAPI = {
|
||||
// Core API functions
|
||||
loadMoreModels: async (resetPage = false, updateFolders = false) => {
|
||||
return await loadMoreCheckpoints(resetPage, updateFolders);
|
||||
return await getModelApiClient().loadMoreWithVirtualScroll(resetPage, updateFolders);
|
||||
},
|
||||
|
||||
resetAndReload: async (updateFolders = false) => {
|
||||
@@ -34,17 +31,17 @@ export class CheckpointsControls extends PageControls {
|
||||
},
|
||||
|
||||
refreshModels: async (fullRebuild = false) => {
|
||||
return await refreshCheckpoints(fullRebuild);
|
||||
return await getModelApiClient().refreshModels(fullRebuild);
|
||||
},
|
||||
|
||||
// Add fetch from Civitai functionality for checkpoints
|
||||
fetchFromCivitai: async () => {
|
||||
return await fetchCivitai();
|
||||
return await getModelApiClient().fetchCivitaiMetadata();
|
||||
},
|
||||
|
||||
// Add show download modal functionality
|
||||
showDownloadModal: () => {
|
||||
this.downloadManager.showDownloadModal();
|
||||
downloadManager.showDownloadModal();
|
||||
},
|
||||
|
||||
// No clearCustomFilter implementation is needed for checkpoints
|
||||
|
||||
57
static/js/components/controls/EmbeddingsControls.js
Normal file
57
static/js/components/controls/EmbeddingsControls.js
Normal file
@@ -0,0 +1,57 @@
|
||||
// EmbeddingsControls.js - Specific implementation for the Embeddings page
|
||||
import { PageControls } from './PageControls.js';
|
||||
import { getModelApiClient, resetAndReload } from '../../api/baseModelApi.js';
|
||||
import { showToast } from '../../utils/uiHelpers.js';
|
||||
import { downloadManager } from '../../managers/DownloadManager.js';
|
||||
|
||||
/**
|
||||
* EmbeddingsControls class - Extends PageControls for Embedding-specific functionality
|
||||
*/
|
||||
export class EmbeddingsControls extends PageControls {
|
||||
constructor() {
|
||||
// Initialize with 'embeddings' page type
|
||||
super('embeddings');
|
||||
|
||||
// Register API methods specific to the Embeddings page
|
||||
this.registerEmbeddingsAPI();
|
||||
}
|
||||
|
||||
/**
|
||||
* Register Embedding-specific API methods
|
||||
*/
|
||||
registerEmbeddingsAPI() {
|
||||
const embeddingsAPI = {
|
||||
// Core API functions
|
||||
loadMoreModels: async (resetPage = false, updateFolders = false) => {
|
||||
return await getModelApiClient().loadMoreWithVirtualScroll(resetPage, updateFolders);
|
||||
},
|
||||
|
||||
resetAndReload: async (updateFolders = false) => {
|
||||
return await resetAndReload(updateFolders);
|
||||
},
|
||||
|
||||
refreshModels: async (fullRebuild = false) => {
|
||||
return await getModelApiClient().refreshModels(fullRebuild);
|
||||
},
|
||||
|
||||
// Add fetch from Civitai functionality for embeddings
|
||||
fetchFromCivitai: async () => {
|
||||
return await getModelApiClient().fetchCivitaiMetadata();
|
||||
},
|
||||
|
||||
// Add show download modal functionality
|
||||
showDownloadModal: () => {
|
||||
downloadManager.showDownloadModal();
|
||||
},
|
||||
|
||||
// No clearCustomFilter implementation is needed for embeddings
|
||||
// as custom filters are currently only used for LoRAs
|
||||
clearCustomFilter: async () => {
|
||||
showToast('No custom filter to clear', 'info');
|
||||
}
|
||||
};
|
||||
|
||||
// Register the API
|
||||
this.registerAPI(embeddingsAPI);
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
// LorasControls.js - Specific implementation for the LoRAs page
|
||||
import { PageControls } from './PageControls.js';
|
||||
import { loadMoreLoras, fetchCivitai, resetAndReload, refreshLoras } from '../../api/loraApi.js';
|
||||
import { getModelApiClient, resetAndReload } from '../../api/baseModelApi.js';
|
||||
import { getSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
|
||||
import { createAlphabetBar } from '../alphabet/index.js';
|
||||
import { downloadManager } from '../../managers/DownloadManager.js';
|
||||
|
||||
/**
|
||||
* LorasControls class - Extends PageControls for LoRA-specific functionality
|
||||
@@ -29,7 +30,7 @@ export class LorasControls extends PageControls {
|
||||
const lorasAPI = {
|
||||
// Core API functions
|
||||
loadMoreModels: async (resetPage = false, updateFolders = false) => {
|
||||
return await loadMoreLoras(resetPage, updateFolders);
|
||||
return await getModelApiClient().loadMoreWithVirtualScroll(resetPage, updateFolders);
|
||||
},
|
||||
|
||||
resetAndReload: async (updateFolders = false) => {
|
||||
@@ -37,20 +38,16 @@ export class LorasControls extends PageControls {
|
||||
},
|
||||
|
||||
refreshModels: async (fullRebuild = false) => {
|
||||
return await refreshLoras(fullRebuild);
|
||||
return await getModelApiClient().refreshModels(fullRebuild);
|
||||
},
|
||||
|
||||
// LoRA-specific API functions
|
||||
fetchFromCivitai: async () => {
|
||||
return await fetchCivitai();
|
||||
return await getModelApiClient().fetchCivitaiMetadata();
|
||||
},
|
||||
|
||||
showDownloadModal: () => {
|
||||
if (window.downloadManager) {
|
||||
window.downloadManager.showDownloadModal();
|
||||
} else {
|
||||
console.error('Download manager not available');
|
||||
}
|
||||
downloadManager.showDownloadModal();
|
||||
},
|
||||
|
||||
toggleBulkMode: () => {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// PageControls.js - Manages controls for both LoRAs and Checkpoints pages
|
||||
import { state, getCurrentPageState, setCurrentPageType } from '../../state/index.js';
|
||||
import { getCurrentPageState, setCurrentPageType } from '../../state/index.js';
|
||||
import { getStorageItem, setStorageItem, getSessionItem, setSessionItem } from '../../utils/storageHelpers.js';
|
||||
import { showToast } from '../../utils/uiHelpers.js';
|
||||
|
||||
@@ -41,6 +41,9 @@ export class PageControls {
|
||||
this.pageState.isLoading = false;
|
||||
this.pageState.hasMore = true;
|
||||
|
||||
// Set default sort based on page type
|
||||
this.pageState.sortBy = this.pageType === 'loras' ? 'name:asc' : 'name:asc';
|
||||
|
||||
// Load sort preference
|
||||
this.loadSortPreference();
|
||||
}
|
||||
@@ -239,7 +242,7 @@ export class PageControls {
|
||||
* @param {string} folderPath - Folder path to filter by
|
||||
*/
|
||||
filterByFolder(folderPath) {
|
||||
const cardSelector = this.pageType === 'loras' ? '.lora-card' : '.checkpoint-card';
|
||||
const cardSelector = this.pageType === 'loras' ? '.model-card' : '.checkpoint-card';
|
||||
document.querySelectorAll(cardSelector).forEach(card => {
|
||||
card.style.display = card.dataset.folder === folderPath ? '' : 'none';
|
||||
});
|
||||
@@ -326,14 +329,36 @@ export class PageControls {
|
||||
loadSortPreference() {
|
||||
const savedSort = getStorageItem(`${this.pageType}_sort`);
|
||||
if (savedSort) {
|
||||
this.pageState.sortBy = savedSort;
|
||||
// Handle legacy format conversion
|
||||
const convertedSort = this.convertLegacySortFormat(savedSort);
|
||||
this.pageState.sortBy = convertedSort;
|
||||
const sortSelect = document.getElementById('sortSelect');
|
||||
if (sortSelect) {
|
||||
sortSelect.value = savedSort;
|
||||
sortSelect.value = convertedSort;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert legacy sort format to new format
|
||||
* @param {string} sortValue - The sort value to convert
|
||||
* @returns {string} - Converted sort value
|
||||
*/
|
||||
convertLegacySortFormat(sortValue) {
|
||||
// Convert old format to new format with direction
|
||||
switch (sortValue) {
|
||||
case 'name':
|
||||
return 'name:asc';
|
||||
case 'date':
|
||||
return 'date:desc'; // Newest first is more intuitive default
|
||||
case 'size':
|
||||
return 'size:desc'; // Largest first is more intuitive default
|
||||
default:
|
||||
// If it's already in new format or unknown, return as is
|
||||
return sortValue.includes(':') ? sortValue : 'name:asc';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Save sort preference to storage
|
||||
* @param {string} sortValue - The sort value to save
|
||||
@@ -349,7 +374,7 @@ export class PageControls {
|
||||
openCivitai(modelName) {
|
||||
// Get card selector based on page type
|
||||
const cardSelector = this.pageType === 'loras'
|
||||
? `.lora-card[data-name="${modelName}"]`
|
||||
? `.model-card[data-name="${modelName}"]`
|
||||
: `.checkpoint-card[data-name="${modelName}"]`;
|
||||
|
||||
const card = document.querySelector(cardSelector);
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
import { PageControls } from './PageControls.js';
|
||||
import { LorasControls } from './LorasControls.js';
|
||||
import { CheckpointsControls } from './CheckpointsControls.js';
|
||||
import { EmbeddingsControls } from './EmbeddingsControls.js';
|
||||
|
||||
// Export the classes
|
||||
export { PageControls, LorasControls, CheckpointsControls };
|
||||
export { PageControls, LorasControls, CheckpointsControls, EmbeddingsControls };
|
||||
|
||||
/**
|
||||
* Factory function to create the appropriate controls based on page type
|
||||
* @param {string} pageType - The type of page ('loras' or 'checkpoints')
|
||||
* @param {string} pageType - The type of page ('loras', 'checkpoints', or 'embeddings')
|
||||
* @returns {PageControls} - The appropriate controls instance
|
||||
*/
|
||||
export function createPageControls(pageType) {
|
||||
@@ -16,6 +17,8 @@ export function createPageControls(pageType) {
|
||||
return new LorasControls();
|
||||
} else if (pageType === 'checkpoints') {
|
||||
return new CheckpointsControls();
|
||||
} else if (pageType === 'embeddings') {
|
||||
return new EmbeddingsControls();
|
||||
} else {
|
||||
console.error(`Unknown page type: ${pageType}`);
|
||||
return null;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user