mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bd83f7520e | ||
|
|
b9a4e7a09b | ||
|
|
c30e57ede8 | ||
|
|
0dba1b336d | ||
|
|
820afe9319 | ||
|
|
5a97f4bc75 | ||
|
|
94da404cc5 | ||
|
|
1da476d858 | ||
|
|
1daaff6bd4 | ||
|
|
e252e44403 | ||
|
|
778ad8abd2 | ||
|
|
68cf381b50 | ||
|
|
337f73e711 | ||
|
|
04ba966a6e | ||
|
|
71c8cf84e0 | ||
|
|
db1aec94e5 | ||
|
|
553e1868e1 | ||
|
|
938ceb49b2 | ||
|
|
c0f03b79a8 | ||
|
|
a492638133 | ||
|
|
e17d6c8ebf | ||
|
|
ffcfe5ea3e | ||
|
|
719e18adb6 | ||
|
|
92d471daf5 | ||
|
|
66babf9ee1 | ||
|
|
60df2df324 |
201
.agents/skills/lora-manager-e2e/SKILL.md
Normal file
201
.agents/skills/lora-manager-e2e/SKILL.md
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
---
|
||||||
|
name: lora-manager-e2e
|
||||||
|
description: End-to-end testing and validation for LoRa Manager features. Use when performing automated E2E validation of LoRa Manager standalone mode, including starting/restarting the server, using Chrome DevTools MCP to interact with the web UI at http://127.0.0.1:8188/loras, and verifying frontend-to-backend functionality. Covers workflow validation, UI interaction testing, and integration testing between the standalone Python backend and the browser frontend.
|
||||||
|
---
|
||||||
|
|
||||||
|
# LoRa Manager E2E Testing
|
||||||
|
|
||||||
|
This skill provides workflows and utilities for end-to-end testing of LoRa Manager using Chrome DevTools MCP.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- LoRa Manager project cloned and dependencies installed (`pip install -r requirements.txt`)
|
||||||
|
- Chrome browser available for debugging
|
||||||
|
- Chrome DevTools MCP connected
|
||||||
|
|
||||||
|
## Quick Start Workflow
|
||||||
|
|
||||||
|
### 1. Start LoRa Manager Standalone
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Use the provided script to start the server
|
||||||
|
python .agents/skills/lora-manager-e2e/scripts/start_server.py --port 8188
|
||||||
|
```
|
||||||
|
|
||||||
|
Or manually:
|
||||||
|
```bash
|
||||||
|
cd /home/miao/workspace/ComfyUI/custom_nodes/ComfyUI-Lora-Manager
|
||||||
|
python standalone.py --port 8188
|
||||||
|
```
|
||||||
|
|
||||||
|
Wait for server ready message before proceeding.
|
||||||
|
|
||||||
|
### 2. Open Chrome Debug Mode
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Chrome with remote debugging on port 9222
|
||||||
|
google-chrome --remote-debugging-port=9222 --user-data-dir=/tmp/chrome-lora-manager http://127.0.0.1:8188/loras
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Connect Chrome DevTools MCP
|
||||||
|
|
||||||
|
Ensure the MCP server is connected to Chrome at `http://localhost:9222`.
|
||||||
|
|
||||||
|
### 4. Navigate and Interact
|
||||||
|
|
||||||
|
Use Chrome DevTools MCP tools to:
|
||||||
|
- Take snapshots: `take_snapshot`
|
||||||
|
- Click elements: `click`
|
||||||
|
- Fill forms: `fill` or `fill_form`
|
||||||
|
- Evaluate scripts: `evaluate_script`
|
||||||
|
- Wait for elements: `wait_for`
|
||||||
|
|
||||||
|
## Common E2E Test Patterns
|
||||||
|
|
||||||
|
### Pattern: Full Page Load Verification
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Navigate to LoRA list page
|
||||||
|
navigate_page(type="url", url="http://127.0.0.1:8188/loras")
|
||||||
|
|
||||||
|
# Wait for page to load
|
||||||
|
wait_for(text="LoRAs", timeout=10000)
|
||||||
|
|
||||||
|
# Take snapshot to verify UI state
|
||||||
|
snapshot = take_snapshot()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern: Restart Server for Configuration Changes
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Stop current server (if running)
|
||||||
|
# Start with new configuration
|
||||||
|
python .agents/skills/lora-manager-e2e/scripts/start_server.py --port 8188 --restart
|
||||||
|
|
||||||
|
# Wait and refresh browser
|
||||||
|
navigate_page(type="reload", ignoreCache=True)
|
||||||
|
wait_for(text="LoRAs", timeout=15000)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern: Verify Backend API via Frontend
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Execute script in browser to call backend API
|
||||||
|
result = evaluate_script(function="""
|
||||||
|
async () => {
|
||||||
|
const response = await fetch('/loras/api/list');
|
||||||
|
const data = await response.json();
|
||||||
|
return { count: data.length, firstItem: data[0]?.name };
|
||||||
|
}
|
||||||
|
""")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern: Form Submission Flow
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Fill a form (e.g., search or filter)
|
||||||
|
fill_form(elements=[
|
||||||
|
{"uid": "search-input", "value": "character"},
|
||||||
|
])
|
||||||
|
|
||||||
|
# Click submit button
|
||||||
|
click(uid="search-button")
|
||||||
|
|
||||||
|
# Wait for results
|
||||||
|
wait_for(text="Results", timeout=5000)
|
||||||
|
|
||||||
|
# Verify results via snapshot
|
||||||
|
snapshot = take_snapshot()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pattern: Modal Dialog Interaction
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Open modal (e.g., add LoRA)
|
||||||
|
click(uid="add-lora-button")
|
||||||
|
|
||||||
|
# Wait for modal to appear
|
||||||
|
wait_for(text="Add LoRA", timeout=3000)
|
||||||
|
|
||||||
|
# Fill modal form
|
||||||
|
fill_form(elements=[
|
||||||
|
{"uid": "lora-name", "value": "Test LoRA"},
|
||||||
|
{"uid": "lora-path", "value": "/path/to/lora.safetensors"},
|
||||||
|
])
|
||||||
|
|
||||||
|
# Submit
|
||||||
|
click(uid="modal-submit-button")
|
||||||
|
|
||||||
|
# Wait for success message or close
|
||||||
|
wait_for(text="Success", timeout=5000)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Available Scripts
|
||||||
|
|
||||||
|
### scripts/start_server.py
|
||||||
|
|
||||||
|
Starts or restarts the LoRa Manager standalone server.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/start_server.py [--port PORT] [--restart] [--wait]
|
||||||
|
```
|
||||||
|
|
||||||
|
Options:
|
||||||
|
- `--port`: Server port (default: 8188)
|
||||||
|
- `--restart`: Kill existing server before starting
|
||||||
|
- `--wait`: Wait for server to be ready before exiting
|
||||||
|
|
||||||
|
### scripts/wait_for_server.py
|
||||||
|
|
||||||
|
Polls server until ready or timeout.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/wait_for_server.py [--port PORT] [--timeout SECONDS]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Test Scenarios Reference
|
||||||
|
|
||||||
|
See [references/test-scenarios.md](references/test-scenarios.md) for detailed test scenarios including:
|
||||||
|
- LoRA list display and filtering
|
||||||
|
- Model metadata editing
|
||||||
|
- Recipe creation and management
|
||||||
|
- Settings configuration
|
||||||
|
- Import/export functionality
|
||||||
|
|
||||||
|
## Network Request Verification
|
||||||
|
|
||||||
|
Use `list_network_requests` and `get_network_request` to verify API calls:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# List recent XHR/fetch requests
|
||||||
|
requests = list_network_requests(resourceTypes=["xhr", "fetch"])
|
||||||
|
|
||||||
|
# Get details of specific request
|
||||||
|
details = get_network_request(reqid=123)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Console Message Monitoring
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Check for errors or warnings
|
||||||
|
messages = list_console_messages(types=["error", "warn"])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Testing
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Start performance trace
|
||||||
|
performance_start_trace(reload=True, autoStop=False)
|
||||||
|
|
||||||
|
# Perform actions...
|
||||||
|
|
||||||
|
# Stop and analyze
|
||||||
|
results = performance_stop_trace()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Cleanup
|
||||||
|
|
||||||
|
Always ensure proper cleanup after tests:
|
||||||
|
1. Stop the standalone server
|
||||||
|
2. Close browser pages (keep at least one open)
|
||||||
|
3. Clear temporary data if needed
|
||||||
324
.agents/skills/lora-manager-e2e/references/mcp-cheatsheet.md
Normal file
324
.agents/skills/lora-manager-e2e/references/mcp-cheatsheet.md
Normal file
@@ -0,0 +1,324 @@
|
|||||||
|
# Chrome DevTools MCP Cheatsheet for LoRa Manager
|
||||||
|
|
||||||
|
Quick reference for common MCP commands used in LoRa Manager E2E testing.
|
||||||
|
|
||||||
|
## Navigation
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Navigate to LoRA list page
|
||||||
|
navigate_page(type="url", url="http://127.0.0.1:8188/loras")
|
||||||
|
|
||||||
|
# Reload page with cache clear
|
||||||
|
navigate_page(type="reload", ignoreCache=True)
|
||||||
|
|
||||||
|
# Go back/forward
|
||||||
|
navigate_page(type="back")
|
||||||
|
navigate_page(type="forward")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Waiting
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Wait for text to appear
|
||||||
|
wait_for(text="LoRAs", timeout=10000)
|
||||||
|
|
||||||
|
# Wait for specific element (via evaluate_script)
|
||||||
|
evaluate_script(function="""
|
||||||
|
() => {
|
||||||
|
return new Promise((resolve) => {
|
||||||
|
const check = () => {
|
||||||
|
if (document.querySelector('.lora-card')) {
|
||||||
|
resolve(true);
|
||||||
|
} else {
|
||||||
|
setTimeout(check, 100);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
check();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
""")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Taking Snapshots
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Full page snapshot
|
||||||
|
snapshot = take_snapshot()
|
||||||
|
|
||||||
|
# Verbose snapshot (more details)
|
||||||
|
snapshot = take_snapshot(verbose=True)
|
||||||
|
|
||||||
|
# Save to file
|
||||||
|
take_snapshot(filePath="test-snapshots/page-load.json")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Element Interaction
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Click element
|
||||||
|
click(uid="element-uid-from-snapshot")
|
||||||
|
|
||||||
|
# Double click
|
||||||
|
click(uid="element-uid", dblClick=True)
|
||||||
|
|
||||||
|
# Fill input
|
||||||
|
fill(uid="search-input", value="test query")
|
||||||
|
|
||||||
|
# Fill multiple inputs
|
||||||
|
fill_form(elements=[
|
||||||
|
{"uid": "input-1", "value": "value 1"},
|
||||||
|
{"uid": "input-2", "value": "value 2"},
|
||||||
|
])
|
||||||
|
|
||||||
|
# Hover
|
||||||
|
hover(uid="lora-card-1")
|
||||||
|
|
||||||
|
# Upload file
|
||||||
|
upload_file(uid="file-input", filePath="/path/to/file.safetensors")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Keyboard Input
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Press key
|
||||||
|
press_key(key="Enter")
|
||||||
|
press_key(key="Escape")
|
||||||
|
press_key(key="Tab")
|
||||||
|
|
||||||
|
# Keyboard shortcuts
|
||||||
|
press_key(key="Control+A") # Select all
|
||||||
|
press_key(key="Control+F") # Find
|
||||||
|
```
|
||||||
|
|
||||||
|
## JavaScript Evaluation
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Simple evaluation
|
||||||
|
result = evaluate_script(function="() => document.title")
|
||||||
|
|
||||||
|
# Async evaluation
|
||||||
|
result = evaluate_script(function="""
|
||||||
|
async () => {
|
||||||
|
const response = await fetch('/loras/api/list');
|
||||||
|
return await response.json();
|
||||||
|
}
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Check element existence
|
||||||
|
exists = evaluate_script(function="""
|
||||||
|
() => document.querySelector('.lora-card') !== null
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Get element count
|
||||||
|
count = evaluate_script(function="""
|
||||||
|
() => document.querySelectorAll('.lora-card').length
|
||||||
|
""")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Network Monitoring
|
||||||
|
|
||||||
|
```python
|
||||||
|
# List all network requests
|
||||||
|
requests = list_network_requests()
|
||||||
|
|
||||||
|
# Filter by resource type
|
||||||
|
xhr_requests = list_network_requests(resourceTypes=["xhr", "fetch"])
|
||||||
|
|
||||||
|
# Get specific request details
|
||||||
|
details = get_network_request(reqid=123)
|
||||||
|
|
||||||
|
# Include preserved requests from previous navigations
|
||||||
|
all_requests = list_network_requests(includePreservedRequests=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Console Monitoring
|
||||||
|
|
||||||
|
```python
|
||||||
|
# List all console messages
|
||||||
|
messages = list_console_messages()
|
||||||
|
|
||||||
|
# Filter by type
|
||||||
|
errors = list_console_messages(types=["error", "warn"])
|
||||||
|
|
||||||
|
# Include preserved messages
|
||||||
|
all_messages = list_console_messages(includePreservedMessages=True)
|
||||||
|
|
||||||
|
# Get specific message
|
||||||
|
details = get_console_message(msgid=1)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Testing
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Start trace with page reload
|
||||||
|
performance_start_trace(reload=True, autoStop=False)
|
||||||
|
|
||||||
|
# Start trace without reload
|
||||||
|
performance_start_trace(reload=False, autoStop=True, filePath="trace.json.gz")
|
||||||
|
|
||||||
|
# Stop trace
|
||||||
|
results = performance_stop_trace()
|
||||||
|
|
||||||
|
# Stop and save
|
||||||
|
performance_stop_trace(filePath="trace-results.json.gz")
|
||||||
|
|
||||||
|
# Analyze specific insight
|
||||||
|
insight = performance_analyze_insight(
|
||||||
|
insightSetId="results.insightSets[0].id",
|
||||||
|
insightName="LCPBreakdown"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Page Management
|
||||||
|
|
||||||
|
```python
|
||||||
|
# List open pages
|
||||||
|
pages = list_pages()
|
||||||
|
|
||||||
|
# Select a page
|
||||||
|
select_page(pageId=0, bringToFront=True)
|
||||||
|
|
||||||
|
# Create new page
|
||||||
|
new_page(url="http://127.0.0.1:8188/loras")
|
||||||
|
|
||||||
|
# Close page (keep at least one open!)
|
||||||
|
close_page(pageId=1)
|
||||||
|
|
||||||
|
# Resize page
|
||||||
|
resize_page(width=1920, height=1080)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Screenshots
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Full page screenshot
|
||||||
|
take_screenshot(fullPage=True)
|
||||||
|
|
||||||
|
# Viewport screenshot
|
||||||
|
take_screenshot()
|
||||||
|
|
||||||
|
# Element screenshot
|
||||||
|
take_screenshot(uid="lora-card-1")
|
||||||
|
|
||||||
|
# Save to file
|
||||||
|
take_screenshot(filePath="screenshots/page.png", format="png")
|
||||||
|
|
||||||
|
# JPEG with quality
|
||||||
|
take_screenshot(filePath="screenshots/page.jpg", format="jpeg", quality=90)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dialog Handling
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Accept dialog
|
||||||
|
handle_dialog(action="accept")
|
||||||
|
|
||||||
|
# Accept with text input
|
||||||
|
handle_dialog(action="accept", promptText="user input")
|
||||||
|
|
||||||
|
# Dismiss dialog
|
||||||
|
handle_dialog(action="dismiss")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Device Emulation
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Mobile viewport
|
||||||
|
emulate(viewport={"width": 375, "height": 667, "isMobile": True, "hasTouch": True})
|
||||||
|
|
||||||
|
# Tablet viewport
|
||||||
|
emulate(viewport={"width": 768, "height": 1024, "isMobile": True, "hasTouch": True})
|
||||||
|
|
||||||
|
# Desktop viewport
|
||||||
|
emulate(viewport={"width": 1920, "height": 1080})
|
||||||
|
|
||||||
|
# Network throttling
|
||||||
|
emulate(networkConditions="Slow 3G")
|
||||||
|
emulate(networkConditions="Fast 4G")
|
||||||
|
|
||||||
|
# CPU throttling
|
||||||
|
emulate(cpuThrottlingRate=4) # 4x slowdown
|
||||||
|
|
||||||
|
# Geolocation
|
||||||
|
emulate(geolocation={"latitude": 37.7749, "longitude": -122.4194})
|
||||||
|
|
||||||
|
# User agent
|
||||||
|
emulate(userAgent="Mozilla/5.0 (Custom)")
|
||||||
|
|
||||||
|
# Reset emulation
|
||||||
|
emulate(viewport=None, networkConditions="No emulation", userAgent=None)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Drag and Drop
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Drag element to another
|
||||||
|
drag(from_uid="draggable-item", to_uid="drop-zone")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Common LoRa Manager Test Patterns
|
||||||
|
|
||||||
|
### Verify LoRA Cards Loaded
|
||||||
|
|
||||||
|
```python
|
||||||
|
navigate_page(type="url", url="http://127.0.0.1:8188/loras")
|
||||||
|
wait_for(text="LoRAs", timeout=10000)
|
||||||
|
|
||||||
|
# Check if cards loaded
|
||||||
|
result = evaluate_script(function="""
|
||||||
|
() => {
|
||||||
|
const cards = document.querySelectorAll('.lora-card');
|
||||||
|
return {
|
||||||
|
count: cards.length,
|
||||||
|
hasData: cards.length > 0
|
||||||
|
};
|
||||||
|
}
|
||||||
|
""")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Search and Verify Results
|
||||||
|
|
||||||
|
```python
|
||||||
|
fill(uid="search-input", value="character")
|
||||||
|
press_key(key="Enter")
|
||||||
|
wait_for(timeout=2000) # Wait for debounce
|
||||||
|
|
||||||
|
# Check results
|
||||||
|
result = evaluate_script(function="""
|
||||||
|
() => {
|
||||||
|
const cards = document.querySelectorAll('.lora-card');
|
||||||
|
const names = Array.from(cards).map(c => c.dataset.name || c.textContent);
|
||||||
|
return { count: cards.length, names };
|
||||||
|
}
|
||||||
|
""")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Check API Response
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Trigger API call
|
||||||
|
evaluate_script(function="""
|
||||||
|
() => window.loraApiCallPromise = fetch('/loras/api/list').then(r => r.json())
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Wait and get result
|
||||||
|
import time
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
result = evaluate_script(function="""
|
||||||
|
async () => await window.loraApiCallPromise
|
||||||
|
""")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Monitor Console for Errors
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Before test: clear console (navigate reloads)
|
||||||
|
navigate_page(type="reload")
|
||||||
|
|
||||||
|
# ... perform actions ...
|
||||||
|
|
||||||
|
# Check for errors
|
||||||
|
errors = list_console_messages(types=["error"])
|
||||||
|
assert len(errors) == 0, f"Console errors: {errors}"
|
||||||
|
```
|
||||||
272
.agents/skills/lora-manager-e2e/references/test-scenarios.md
Normal file
272
.agents/skills/lora-manager-e2e/references/test-scenarios.md
Normal file
@@ -0,0 +1,272 @@
|
|||||||
|
# LoRa Manager E2E Test Scenarios
|
||||||
|
|
||||||
|
This document provides detailed test scenarios for end-to-end validation of LoRa Manager features.
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
1. [LoRA List Page](#lora-list-page)
|
||||||
|
2. [Model Details](#model-details)
|
||||||
|
3. [Recipes](#recipes)
|
||||||
|
4. [Settings](#settings)
|
||||||
|
5. [Import/Export](#importexport)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## LoRA List Page
|
||||||
|
|
||||||
|
### Scenario: Page Load and Display
|
||||||
|
|
||||||
|
**Objective**: Verify the LoRA list page loads correctly and displays models.
|
||||||
|
|
||||||
|
**Steps**:
|
||||||
|
1. Navigate to `http://127.0.0.1:8188/loras`
|
||||||
|
2. Wait for page title "LoRAs" to appear
|
||||||
|
3. Take snapshot to verify:
|
||||||
|
- Header with "LoRAs" title is visible
|
||||||
|
- Search/filter controls are present
|
||||||
|
- Grid/list view toggle exists
|
||||||
|
- LoRA cards are displayed (if models exist)
|
||||||
|
- Pagination controls (if applicable)
|
||||||
|
|
||||||
|
**Expected Result**: Page loads without errors, UI elements are present.
|
||||||
|
|
||||||
|
### Scenario: Search Functionality
|
||||||
|
|
||||||
|
**Objective**: Verify search filters LoRA models correctly.
|
||||||
|
|
||||||
|
**Steps**:
|
||||||
|
1. Ensure at least one LoRA exists with known name (e.g., "test-character")
|
||||||
|
2. Navigate to LoRA list page
|
||||||
|
3. Enter search term in search box: "test"
|
||||||
|
4. Press Enter or click search button
|
||||||
|
5. Wait for results to update
|
||||||
|
|
||||||
|
**Expected Result**: Only LoRAs matching search term are displayed.
|
||||||
|
|
||||||
|
**Verification Script**:
|
||||||
|
```python
|
||||||
|
# After search, verify filtered results
|
||||||
|
evaluate_script(function="""
|
||||||
|
() => {
|
||||||
|
const cards = document.querySelectorAll('.lora-card');
|
||||||
|
const names = Array.from(cards).map(c => c.dataset.name);
|
||||||
|
return { count: cards.length, names };
|
||||||
|
}
|
||||||
|
""")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Scenario: Filter by Tags
|
||||||
|
|
||||||
|
**Objective**: Verify tag filtering works correctly.
|
||||||
|
|
||||||
|
**Steps**:
|
||||||
|
1. Navigate to LoRA list page
|
||||||
|
2. Click on a tag (e.g., "character", "style")
|
||||||
|
3. Wait for filtered results
|
||||||
|
|
||||||
|
**Expected Result**: Only LoRAs with selected tag are displayed.
|
||||||
|
|
||||||
|
### Scenario: View Mode Toggle
|
||||||
|
|
||||||
|
**Objective**: Verify grid/list view toggle works.
|
||||||
|
|
||||||
|
**Steps**:
|
||||||
|
1. Navigate to LoRA list page
|
||||||
|
2. Click list view button
|
||||||
|
3. Verify list layout
|
||||||
|
4. Click grid view button
|
||||||
|
5. Verify grid layout
|
||||||
|
|
||||||
|
**Expected Result**: View mode changes correctly, layout updates.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Model Details
|
||||||
|
|
||||||
|
### Scenario: Open Model Details
|
||||||
|
|
||||||
|
**Objective**: Verify clicking a LoRA opens its details.
|
||||||
|
|
||||||
|
**Steps**:
|
||||||
|
1. Navigate to LoRA list page
|
||||||
|
2. Click on a LoRA card
|
||||||
|
3. Wait for details panel/modal to open
|
||||||
|
|
||||||
|
**Expected Result**: Details panel shows:
|
||||||
|
- Model name
|
||||||
|
- Preview image
|
||||||
|
- Metadata (trigger words, tags, etc.)
|
||||||
|
- Action buttons (edit, delete, etc.)
|
||||||
|
|
||||||
|
### Scenario: Edit Model Metadata
|
||||||
|
|
||||||
|
**Objective**: Verify metadata editing works end-to-end.
|
||||||
|
|
||||||
|
**Steps**:
|
||||||
|
1. Open a LoRA's details
|
||||||
|
2. Click "Edit" button
|
||||||
|
3. Modify trigger words field
|
||||||
|
4. Add/remove tags
|
||||||
|
5. Save changes
|
||||||
|
6. Refresh page
|
||||||
|
7. Reopen the same LoRA
|
||||||
|
|
||||||
|
**Expected Result**: Changes persist after refresh.
|
||||||
|
|
||||||
|
### Scenario: Delete Model
|
||||||
|
|
||||||
|
**Objective**: Verify model deletion works.
|
||||||
|
|
||||||
|
**Steps**:
|
||||||
|
1. Open a LoRA's details
|
||||||
|
2. Click "Delete" button
|
||||||
|
3. Confirm deletion in dialog
|
||||||
|
4. Wait for removal
|
||||||
|
|
||||||
|
**Expected Result**: Model removed from list, success message shown.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Recipes
|
||||||
|
|
||||||
|
### Scenario: Recipe List Display
|
||||||
|
|
||||||
|
**Objective**: Verify recipes page loads and displays recipes.
|
||||||
|
|
||||||
|
**Steps**:
|
||||||
|
1. Navigate to `http://127.0.0.1:8188/recipes`
|
||||||
|
2. Wait for "Recipes" title
|
||||||
|
3. Take snapshot
|
||||||
|
|
||||||
|
**Expected Result**: Recipe list displayed with cards/items.
|
||||||
|
|
||||||
|
### Scenario: Create New Recipe
|
||||||
|
|
||||||
|
**Objective**: Verify recipe creation workflow.
|
||||||
|
|
||||||
|
**Steps**:
|
||||||
|
1. Navigate to recipes page
|
||||||
|
2. Click "New Recipe" button
|
||||||
|
3. Fill recipe form:
|
||||||
|
- Name: "Test Recipe"
|
||||||
|
- Description: "E2E test recipe"
|
||||||
|
- Add LoRA models
|
||||||
|
4. Save recipe
|
||||||
|
5. Verify recipe appears in list
|
||||||
|
|
||||||
|
**Expected Result**: New recipe created and displayed.
|
||||||
|
|
||||||
|
### Scenario: Apply Recipe
|
||||||
|
|
||||||
|
**Objective**: Verify applying a recipe to ComfyUI.
|
||||||
|
|
||||||
|
**Steps**:
|
||||||
|
1. Open a recipe
|
||||||
|
2. Click "Apply" or "Load in ComfyUI"
|
||||||
|
3. Verify action completes
|
||||||
|
|
||||||
|
**Expected Result**: Recipe applied successfully.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Settings
|
||||||
|
|
||||||
|
### Scenario: Settings Page Load
|
||||||
|
|
||||||
|
**Objective**: Verify settings page displays correctly.
|
||||||
|
|
||||||
|
**Steps**:
|
||||||
|
1. Navigate to `http://127.0.0.1:8188/settings`
|
||||||
|
2. Wait for "Settings" title
|
||||||
|
3. Take snapshot
|
||||||
|
|
||||||
|
**Expected Result**: Settings form with various options displayed.
|
||||||
|
|
||||||
|
### Scenario: Change Setting and Restart
|
||||||
|
|
||||||
|
**Objective**: Verify settings persist after restart.
|
||||||
|
|
||||||
|
**Steps**:
|
||||||
|
1. Navigate to settings page
|
||||||
|
2. Change a setting (e.g., default view mode)
|
||||||
|
3. Save settings
|
||||||
|
4. Restart server: `python scripts/start_server.py --restart --wait`
|
||||||
|
5. Refresh browser page
|
||||||
|
6. Navigate to settings
|
||||||
|
|
||||||
|
**Expected Result**: Changed setting value persists.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Import/Export
|
||||||
|
|
||||||
|
### Scenario: Export Models List
|
||||||
|
|
||||||
|
**Objective**: Verify export functionality.
|
||||||
|
|
||||||
|
**Steps**:
|
||||||
|
1. Navigate to LoRA list
|
||||||
|
2. Click "Export" button
|
||||||
|
3. Select format (JSON/CSV)
|
||||||
|
4. Download file
|
||||||
|
|
||||||
|
**Expected Result**: File downloaded with correct data.
|
||||||
|
|
||||||
|
### Scenario: Import Models
|
||||||
|
|
||||||
|
**Objective**: Verify import functionality.
|
||||||
|
|
||||||
|
**Steps**:
|
||||||
|
1. Prepare import file
|
||||||
|
2. Navigate to import page
|
||||||
|
3. Upload file
|
||||||
|
4. Verify import results
|
||||||
|
|
||||||
|
**Expected Result**: Models imported successfully, confirmation shown.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API Integration Tests
|
||||||
|
|
||||||
|
### Scenario: Verify API Endpoints
|
||||||
|
|
||||||
|
**Objective**: Verify backend API responds correctly.
|
||||||
|
|
||||||
|
**Test via browser console**:
|
||||||
|
```javascript
|
||||||
|
// List LoRAs
|
||||||
|
fetch('/loras/api/list').then(r => r.json()).then(console.log)
|
||||||
|
|
||||||
|
// Get LoRA details
|
||||||
|
fetch('/loras/api/detail/<id>').then(r => r.json()).then(console.log)
|
||||||
|
|
||||||
|
// Search LoRAs
|
||||||
|
fetch('/loras/api/search?q=test').then(r => r.json()).then(console.log)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Expected Result**: APIs return valid JSON with expected structure.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Console Error Monitoring
|
||||||
|
|
||||||
|
During all tests, monitor browser console for errors:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Check for JavaScript errors
|
||||||
|
messages = list_console_messages(types=["error"])
|
||||||
|
assert len(messages) == 0, f"Console errors found: {messages}"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Network Request Verification
|
||||||
|
|
||||||
|
Verify key API calls are made:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# List XHR requests
|
||||||
|
requests = list_network_requests(resourceTypes=["xhr", "fetch"])
|
||||||
|
|
||||||
|
# Look for specific endpoints
|
||||||
|
lora_list_requests = [r for r in requests if "/api/list" in r.get("url", "")]
|
||||||
|
assert len(lora_list_requests) > 0, "LoRA list API not called"
|
||||||
|
```
|
||||||
193
.agents/skills/lora-manager-e2e/scripts/example_e2e_test.py
Executable file
193
.agents/skills/lora-manager-e2e/scripts/example_e2e_test.py
Executable file
@@ -0,0 +1,193 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Example E2E test demonstrating LoRa Manager testing workflow.
|
||||||
|
|
||||||
|
This script shows how to:
|
||||||
|
1. Start the standalone server
|
||||||
|
2. Use Chrome DevTools MCP to interact with the UI
|
||||||
|
3. Verify functionality end-to-end
|
||||||
|
|
||||||
|
Note: This is a template. Actual execution requires Chrome DevTools MCP.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
def run_test():
|
||||||
|
"""Run example E2E test flow."""
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("LoRa Manager E2E Test Example")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Step 1: Start server
|
||||||
|
print("\n[1/5] Starting LoRa Manager standalone server...")
|
||||||
|
result = subprocess.run(
|
||||||
|
[sys.executable, "start_server.py", "--port", "8188", "--wait", "--timeout", "30"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
print(f"Failed to start server: {result.stderr}")
|
||||||
|
return 1
|
||||||
|
print("Server ready!")
|
||||||
|
|
||||||
|
# Step 2: Open Chrome (manual step - show command)
|
||||||
|
print("\n[2/5] Open Chrome with debug mode:")
|
||||||
|
print("google-chrome --remote-debugging-port=9222 --user-data-dir=/tmp/chrome-lora-manager http://127.0.0.1:8188/loras")
|
||||||
|
print("(In actual test, this would be automated via MCP)")
|
||||||
|
|
||||||
|
# Step 3: Navigate and verify page load
|
||||||
|
print("\n[3/5] Page Load Verification:")
|
||||||
|
print("""
|
||||||
|
MCP Commands to execute:
|
||||||
|
1. navigate_page(type="url", url="http://127.0.0.1:8188/loras")
|
||||||
|
2. wait_for(text="LoRAs", timeout=10000)
|
||||||
|
3. snapshot = take_snapshot()
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Step 4: Test search functionality
|
||||||
|
print("\n[4/5] Search Functionality Test:")
|
||||||
|
print("""
|
||||||
|
MCP Commands to execute:
|
||||||
|
1. fill(uid="search-input", value="test")
|
||||||
|
2. press_key(key="Enter")
|
||||||
|
3. wait_for(text="Results", timeout=5000)
|
||||||
|
4. result = evaluate_script(function="""
|
||||||
|
() => {
|
||||||
|
const cards = document.querySelectorAll('.lora-card');
|
||||||
|
return { count: cards.length };
|
||||||
|
}
|
||||||
|
""")
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Step 5: Verify API
|
||||||
|
print("\n[5/5] API Verification:")
|
||||||
|
print("""
|
||||||
|
MCP Commands to execute:
|
||||||
|
1. api_result = evaluate_script(function="""
|
||||||
|
async () => {
|
||||||
|
const response = await fetch('/loras/api/list');
|
||||||
|
const data = await response.json();
|
||||||
|
return { count: data.length, status: response.status };
|
||||||
|
}
|
||||||
|
""")
|
||||||
|
2. Verify api_result['status'] == 200
|
||||||
|
""")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Test flow completed!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def example_restart_flow():
|
||||||
|
"""Example: Testing configuration change that requires restart."""
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Example: Server Restart Flow")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
print("""
|
||||||
|
Scenario: Change setting and verify after restart
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1. Navigate to settings page
|
||||||
|
- navigate_page(type="url", url="http://127.0.0.1:8188/settings")
|
||||||
|
|
||||||
|
2. Change a setting (e.g., theme)
|
||||||
|
- fill(uid="theme-select", value="dark")
|
||||||
|
- click(uid="save-settings-button")
|
||||||
|
|
||||||
|
3. Restart server
|
||||||
|
- subprocess.run([python, "start_server.py", "--restart", "--wait"])
|
||||||
|
|
||||||
|
4. Refresh browser
|
||||||
|
- navigate_page(type="reload", ignoreCache=True)
|
||||||
|
- wait_for(text="LoRAs", timeout=15000)
|
||||||
|
|
||||||
|
5. Verify setting persisted
|
||||||
|
- navigate_page(type="url", url="http://127.0.0.1:8188/settings")
|
||||||
|
- theme = evaluate_script(function="() => document.querySelector('#theme-select').value")
|
||||||
|
- assert theme == "dark"
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def example_modal_interaction():
|
||||||
|
"""Example: Testing modal dialog interaction."""
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Example: Modal Dialog Interaction")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
print("""
|
||||||
|
Scenario: Add new LoRA via modal
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1. Open modal
|
||||||
|
- click(uid="add-lora-button")
|
||||||
|
- wait_for(text="Add LoRA", timeout=3000)
|
||||||
|
|
||||||
|
2. Fill form
|
||||||
|
- fill_form(elements=[
|
||||||
|
{"uid": "lora-name", "value": "Test Character"},
|
||||||
|
{"uid": "lora-path", "value": "/models/test.safetensors"},
|
||||||
|
])
|
||||||
|
|
||||||
|
3. Submit
|
||||||
|
- click(uid="modal-submit-button")
|
||||||
|
|
||||||
|
4. Verify success
|
||||||
|
- wait_for(text="Successfully added", timeout=5000)
|
||||||
|
- snapshot = take_snapshot()
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def example_network_monitoring():
|
||||||
|
"""Example: Network request monitoring."""
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Example: Network Request Monitoring")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
print("""
|
||||||
|
Scenario: Verify API calls during user interaction
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1. Clear network log (implicit on navigation)
|
||||||
|
- navigate_page(type="url", url="http://127.0.0.1:8188/loras")
|
||||||
|
|
||||||
|
2. Perform action that triggers API call
|
||||||
|
- fill(uid="search-input", value="character")
|
||||||
|
- press_key(key="Enter")
|
||||||
|
|
||||||
|
3. List network requests
|
||||||
|
- requests = list_network_requests(resourceTypes=["xhr", "fetch"])
|
||||||
|
|
||||||
|
4. Find search API call
|
||||||
|
- search_requests = [r for r in requests if "/api/search" in r.get("url", "")]
|
||||||
|
- assert len(search_requests) > 0, "Search API was not called"
|
||||||
|
|
||||||
|
5. Get request details
|
||||||
|
- if search_requests:
|
||||||
|
details = get_network_request(reqid=search_requests[0]["reqid"])
|
||||||
|
- Verify request method, response status, etc.
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("LoRa Manager E2E Test Examples\n")
|
||||||
|
print("This script demonstrates E2E testing patterns.\n")
|
||||||
|
print("Note: Actual execution requires Chrome DevTools MCP connection.\n")
|
||||||
|
|
||||||
|
run_test()
|
||||||
|
example_restart_flow()
|
||||||
|
example_modal_interaction()
|
||||||
|
example_network_monitoring()
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("All examples shown!")
|
||||||
|
print("=" * 60)
|
||||||
169
.agents/skills/lora-manager-e2e/scripts/start_server.py
Executable file
169
.agents/skills/lora-manager-e2e/scripts/start_server.py
Executable file
@@ -0,0 +1,169 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Start or restart LoRa Manager standalone server for E2E testing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import socket
|
||||||
|
import signal
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def find_server_process(port: int) -> list[int]:
|
||||||
|
"""Find PIDs of processes listening on the given port."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["lsof", "-ti", f":{port}"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=False
|
||||||
|
)
|
||||||
|
if result.returncode == 0 and result.stdout.strip():
|
||||||
|
return [int(pid) for pid in result.stdout.strip().split("\n") if pid]
|
||||||
|
except FileNotFoundError:
|
||||||
|
# lsof not available, try netstat
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["netstat", "-tlnp"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=False
|
||||||
|
)
|
||||||
|
pids = []
|
||||||
|
for line in result.stdout.split("\n"):
|
||||||
|
if f":{port}" in line:
|
||||||
|
parts = line.split()
|
||||||
|
for part in parts:
|
||||||
|
if "/" in part:
|
||||||
|
try:
|
||||||
|
pid = int(part.split("/")[0])
|
||||||
|
pids.append(pid)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return pids
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def kill_server(port: int) -> None:
|
||||||
|
"""Kill processes using the specified port."""
|
||||||
|
pids = find_server_process(port)
|
||||||
|
for pid in pids:
|
||||||
|
try:
|
||||||
|
os.kill(pid, signal.SIGTERM)
|
||||||
|
print(f"Sent SIGTERM to process {pid}")
|
||||||
|
except ProcessLookupError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Wait for processes to terminate
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# Force kill if still running
|
||||||
|
pids = find_server_process(port)
|
||||||
|
for pid in pids:
|
||||||
|
try:
|
||||||
|
os.kill(pid, signal.SIGKILL)
|
||||||
|
print(f"Sent SIGKILL to process {pid}")
|
||||||
|
except ProcessLookupError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def is_server_ready(port: int, timeout: float = 0.5) -> bool:
|
||||||
|
"""Check if server is accepting connections."""
|
||||||
|
try:
|
||||||
|
with socket.create_connection(("127.0.0.1", port), timeout=timeout):
|
||||||
|
return True
|
||||||
|
except (socket.timeout, ConnectionRefusedError, OSError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_server(port: int, timeout: int = 30) -> bool:
|
||||||
|
"""Wait for server to become ready."""
|
||||||
|
start = time.time()
|
||||||
|
while time.time() - start < timeout:
|
||||||
|
if is_server_ready(port):
|
||||||
|
return True
|
||||||
|
time.sleep(0.5)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Start LoRa Manager standalone server for E2E testing"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port",
|
||||||
|
type=int,
|
||||||
|
default=8188,
|
||||||
|
help="Server port (default: 8188)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--restart",
|
||||||
|
action="store_true",
|
||||||
|
help="Kill existing server before starting"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--wait",
|
||||||
|
action="store_true",
|
||||||
|
help="Wait for server to be ready before exiting"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--timeout",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="Timeout for waiting (default: 30)"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Get project root (parent of .agents directory)
|
||||||
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
skill_dir = os.path.dirname(script_dir)
|
||||||
|
project_root = os.path.dirname(os.path.dirname(os.path.dirname(skill_dir)))
|
||||||
|
|
||||||
|
# Restart if requested
|
||||||
|
if args.restart:
|
||||||
|
print(f"Killing existing server on port {args.port}...")
|
||||||
|
kill_server(args.port)
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# Check if already running
|
||||||
|
if is_server_ready(args.port):
|
||||||
|
print(f"Server already running on port {args.port}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Start server
|
||||||
|
print(f"Starting LoRa Manager standalone server on port {args.port}...")
|
||||||
|
cmd = [sys.executable, "standalone.py", "--port", str(args.port)]
|
||||||
|
|
||||||
|
# Start in background
|
||||||
|
process = subprocess.Popen(
|
||||||
|
cmd,
|
||||||
|
cwd=project_root,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
start_new_session=True
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Server process started with PID {process.pid}")
|
||||||
|
|
||||||
|
# Wait for ready if requested
|
||||||
|
if args.wait:
|
||||||
|
print(f"Waiting for server to be ready (timeout: {args.timeout}s)...")
|
||||||
|
if wait_for_server(args.port, args.timeout):
|
||||||
|
print(f"Server ready at http://127.0.0.1:{args.port}/loras")
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
print(f"Timeout waiting for server")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
print(f"Server starting at http://127.0.0.1:{args.port}/loras")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
61
.agents/skills/lora-manager-e2e/scripts/wait_for_server.py
Executable file
61
.agents/skills/lora-manager-e2e/scripts/wait_for_server.py
Executable file
@@ -0,0 +1,61 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Wait for LoRa Manager server to become ready.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import socket
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
def is_server_ready(port: int, timeout: float = 0.5) -> bool:
|
||||||
|
"""Check if server is accepting connections."""
|
||||||
|
try:
|
||||||
|
with socket.create_connection(("127.0.0.1", port), timeout=timeout):
|
||||||
|
return True
|
||||||
|
except (socket.timeout, ConnectionRefusedError, OSError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_server(port: int, timeout: int = 30) -> bool:
|
||||||
|
"""Wait for server to become ready."""
|
||||||
|
start = time.time()
|
||||||
|
while time.time() - start < timeout:
|
||||||
|
if is_server_ready(port):
|
||||||
|
return True
|
||||||
|
time.sleep(0.5)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Wait for LoRa Manager server to become ready"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port",
|
||||||
|
type=int,
|
||||||
|
default=8188,
|
||||||
|
help="Server port (default: 8188)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--timeout",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="Timeout in seconds (default: 30)"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(f"Waiting for server on port {args.port} (timeout: {args.timeout}s)...")
|
||||||
|
|
||||||
|
if wait_for_server(args.port, args.timeout):
|
||||||
|
print(f"Server ready at http://127.0.0.1:{args.port}/loras")
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
print(f"Timeout: Server not ready after {args.timeout}s")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
@@ -34,6 +34,11 @@ Enhance your Civitai browsing experience with our companion browser extension! S
|
|||||||
|
|
||||||
## Release Notes
|
## Release Notes
|
||||||
|
|
||||||
|
### v0.9.14
|
||||||
|
* **LoRA Cycler Node** - Introduced a new LoRA Cycler node that enables iteration through specified LoRAs with support for repeat count and pause iteration functionality. Refer to the new "Lora Cycler" template workflow for concrete example.
|
||||||
|
* **Enhanced Prompt Node with Tag Autocomplete** - Enhanced the Prompt node with comprehensive tag autocomplete based on merged Danbooru + e621 tags. Supports tag search and autocomplete functionality. Implemented a command system with shortcuts like `/char` or `/artist` for category-specific tag searching. Added `/ac` or `/noac` commands to quickly enable or disable autocomplete. Refer to the "Lora Manager Basic" template workflow in ComfyUI -> Templates -> ComfyUI-Lora-Manager for detailed tips.
|
||||||
|
* **Bug Fixes & Stability** - Addressed multiple bugs and improved overall stability.
|
||||||
|
|
||||||
### v0.9.12
|
### v0.9.12
|
||||||
* **LoRA Randomizer System** - Introduced a comprehensive LoRA randomization system featuring LoRA Pool and LoRA Randomizer nodes for flexible and dynamic generation workflows.
|
* **LoRA Randomizer System** - Introduced a comprehensive LoRA randomization system featuring LoRA Pool and LoRA Randomizer nodes for flexible and dynamic generation workflows.
|
||||||
* **LoRA Randomizer Template** - Refer to the new "LoRA Randomizer" template workflow for detailed examples of flexible randomization modes, lock & reuse options, and other features.
|
* **LoRA Randomizer Template** - Refer to the new "LoRA Randomizer" template workflow for detailed examples of flexible randomization modes, lock & reuse options, and other features.
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -9,9 +9,9 @@
|
|||||||
"back": "Zurück",
|
"back": "Zurück",
|
||||||
"next": "Weiter",
|
"next": "Weiter",
|
||||||
"backToTop": "Nach oben",
|
"backToTop": "Nach oben",
|
||||||
"add": "Hinzufügen",
|
|
||||||
"settings": "Einstellungen",
|
"settings": "Einstellungen",
|
||||||
"help": "Hilfe"
|
"help": "Hilfe",
|
||||||
|
"add": "Hinzufügen"
|
||||||
},
|
},
|
||||||
"status": {
|
"status": {
|
||||||
"loading": "Wird geladen...",
|
"loading": "Wird geladen...",
|
||||||
@@ -1572,6 +1572,20 @@
|
|||||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||||
"supportCta": "Support on Ko-fi",
|
"supportCta": "Support on Ko-fi",
|
||||||
"learnMore": "LM Civitai Extension Tutorial"
|
"learnMore": "LM Civitai Extension Tutorial"
|
||||||
|
},
|
||||||
|
"cacheHealth": {
|
||||||
|
"corrupted": {
|
||||||
|
"title": "Cache-Korruption erkannt"
|
||||||
|
},
|
||||||
|
"degraded": {
|
||||||
|
"title": "Cache-Probleme erkannt"
|
||||||
|
},
|
||||||
|
"content": "{invalid} von {total} Cache-Einträgen sind ungültig ({rate}). Dies kann zu fehlenden Modellen oder Fehlern führen. Ein Neuaufbau des Caches wird empfohlen.",
|
||||||
|
"rebuildCache": "Cache neu aufbauen",
|
||||||
|
"dismiss": "Verwerfen",
|
||||||
|
"rebuilding": "Cache wird neu aufgebaut...",
|
||||||
|
"rebuildFailed": "Fehler beim Neuaufbau des Caches: {error}",
|
||||||
|
"retry": "Wiederholen"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1572,6 +1572,20 @@
|
|||||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||||
"supportCta": "Support on Ko-fi",
|
"supportCta": "Support on Ko-fi",
|
||||||
"learnMore": "LM Civitai Extension Tutorial"
|
"learnMore": "LM Civitai Extension Tutorial"
|
||||||
|
},
|
||||||
|
"cacheHealth": {
|
||||||
|
"corrupted": {
|
||||||
|
"title": "Cache Corruption Detected"
|
||||||
|
},
|
||||||
|
"degraded": {
|
||||||
|
"title": "Cache Issues Detected"
|
||||||
|
},
|
||||||
|
"content": "{invalid} of {total} cache entries are invalid ({rate}). This may cause missing models or errors. Rebuilding the cache is recommended.",
|
||||||
|
"rebuildCache": "Rebuild Cache",
|
||||||
|
"dismiss": "Dismiss",
|
||||||
|
"rebuilding": "Rebuilding cache...",
|
||||||
|
"rebuildFailed": "Failed to rebuild cache: {error}",
|
||||||
|
"retry": "Retry"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1572,6 +1572,20 @@
|
|||||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||||
"supportCta": "Support on Ko-fi",
|
"supportCta": "Support on Ko-fi",
|
||||||
"learnMore": "LM Civitai Extension Tutorial"
|
"learnMore": "LM Civitai Extension Tutorial"
|
||||||
|
},
|
||||||
|
"cacheHealth": {
|
||||||
|
"corrupted": {
|
||||||
|
"title": "Corrupción de caché detectada"
|
||||||
|
},
|
||||||
|
"degraded": {
|
||||||
|
"title": "Problemas de caché detectados"
|
||||||
|
},
|
||||||
|
"content": "{invalid} de {total} entradas de caché son inválidas ({rate}). Esto puede causar modelos faltantes o errores. Se recomienda reconstruir la caché.",
|
||||||
|
"rebuildCache": "Reconstruir caché",
|
||||||
|
"dismiss": "Descartar",
|
||||||
|
"rebuilding": "Reconstruyendo caché...",
|
||||||
|
"rebuildFailed": "Error al reconstruir la caché: {error}",
|
||||||
|
"retry": "Reintentar"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1572,6 +1572,20 @@
|
|||||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||||
"supportCta": "Support on Ko-fi",
|
"supportCta": "Support on Ko-fi",
|
||||||
"learnMore": "LM Civitai Extension Tutorial"
|
"learnMore": "LM Civitai Extension Tutorial"
|
||||||
|
},
|
||||||
|
"cacheHealth": {
|
||||||
|
"corrupted": {
|
||||||
|
"title": "Corruption du cache détectée"
|
||||||
|
},
|
||||||
|
"degraded": {
|
||||||
|
"title": "Problèmes de cache détectés"
|
||||||
|
},
|
||||||
|
"content": "{invalid} des {total} entrées de cache sont invalides ({rate}). Cela peut provoquer des modèles manquants ou des erreurs. Il est recommandé de reconstruire le cache.",
|
||||||
|
"rebuildCache": "Reconstruire le cache",
|
||||||
|
"dismiss": "Ignorer",
|
||||||
|
"rebuilding": "Reconstruction du cache...",
|
||||||
|
"rebuildFailed": "Échec de la reconstruction du cache : {error}",
|
||||||
|
"retry": "Réessayer"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,9 +9,9 @@
|
|||||||
"back": "חזור",
|
"back": "חזור",
|
||||||
"next": "הבא",
|
"next": "הבא",
|
||||||
"backToTop": "חזור למעלה",
|
"backToTop": "חזור למעלה",
|
||||||
"add": "הוסף",
|
|
||||||
"settings": "הגדרות",
|
"settings": "הגדרות",
|
||||||
"help": "עזרה"
|
"help": "עזרה",
|
||||||
|
"add": "הוסף"
|
||||||
},
|
},
|
||||||
"status": {
|
"status": {
|
||||||
"loading": "טוען...",
|
"loading": "טוען...",
|
||||||
@@ -1572,6 +1572,20 @@
|
|||||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||||
"supportCta": "Support on Ko-fi",
|
"supportCta": "Support on Ko-fi",
|
||||||
"learnMore": "LM Civitai Extension Tutorial"
|
"learnMore": "LM Civitai Extension Tutorial"
|
||||||
|
},
|
||||||
|
"cacheHealth": {
|
||||||
|
"corrupted": {
|
||||||
|
"title": "זוהתה שחיתות במטמון"
|
||||||
|
},
|
||||||
|
"degraded": {
|
||||||
|
"title": "זוהו בעיות במטמון"
|
||||||
|
},
|
||||||
|
"content": "{invalid} מתוך {total} רשומות מטמון אינן תקינות ({rate}). זה עלול לגרום לדגמים חסרים או לשגיאות. מומלץ לבנות מחדש את המטמון.",
|
||||||
|
"rebuildCache": "בניית מטמון מחדש",
|
||||||
|
"dismiss": "ביטול",
|
||||||
|
"rebuilding": "בונה מחדש את המטמון...",
|
||||||
|
"rebuildFailed": "נכשלה בניית המטמון מחדש: {error}",
|
||||||
|
"retry": "נסה שוב"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1572,6 +1572,20 @@
|
|||||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||||
"supportCta": "Support on Ko-fi",
|
"supportCta": "Support on Ko-fi",
|
||||||
"learnMore": "LM Civitai Extension Tutorial"
|
"learnMore": "LM Civitai Extension Tutorial"
|
||||||
|
},
|
||||||
|
"cacheHealth": {
|
||||||
|
"corrupted": {
|
||||||
|
"title": "キャッシュの破損が検出されました"
|
||||||
|
},
|
||||||
|
"degraded": {
|
||||||
|
"title": "キャッシュの問題が検出されました"
|
||||||
|
},
|
||||||
|
"content": "{total}個のキャッシュエントリのうち{invalid}個が無効です({rate})。モデルが見つからない原因になったり、エラーが発生する可能性があります。キャッシュの再構築を推奨します。",
|
||||||
|
"rebuildCache": "キャッシュを再構築",
|
||||||
|
"dismiss": "閉じる",
|
||||||
|
"rebuilding": "キャッシュを再構築中...",
|
||||||
|
"rebuildFailed": "キャッシュの再構築に失敗しました: {error}",
|
||||||
|
"retry": "再試行"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1572,6 +1572,20 @@
|
|||||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||||
"supportCta": "Support on Ko-fi",
|
"supportCta": "Support on Ko-fi",
|
||||||
"learnMore": "LM Civitai Extension Tutorial"
|
"learnMore": "LM Civitai Extension Tutorial"
|
||||||
|
},
|
||||||
|
"cacheHealth": {
|
||||||
|
"corrupted": {
|
||||||
|
"title": "캐시 손상이 감지되었습니다"
|
||||||
|
},
|
||||||
|
"degraded": {
|
||||||
|
"title": "캐시 문제가 감지되었습니다"
|
||||||
|
},
|
||||||
|
"content": "{total}개의 캐시 항목 중 {invalid}개가 유효하지 않습니다 ({rate}). 모델 누락이나 오류가 발생할 수 있습니다. 캐시를 재구축하는 것이 좋습니다.",
|
||||||
|
"rebuildCache": "캐시 재구축",
|
||||||
|
"dismiss": "무시",
|
||||||
|
"rebuilding": "캐시 재구축 중...",
|
||||||
|
"rebuildFailed": "캐시 재구축 실패: {error}",
|
||||||
|
"retry": "다시 시도"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1572,6 +1572,20 @@
|
|||||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||||
"supportCta": "Support on Ko-fi",
|
"supportCta": "Support on Ko-fi",
|
||||||
"learnMore": "LM Civitai Extension Tutorial"
|
"learnMore": "LM Civitai Extension Tutorial"
|
||||||
|
},
|
||||||
|
"cacheHealth": {
|
||||||
|
"corrupted": {
|
||||||
|
"title": "Обнаружено повреждение кэша"
|
||||||
|
},
|
||||||
|
"degraded": {
|
||||||
|
"title": "Обнаружены проблемы с кэшем"
|
||||||
|
},
|
||||||
|
"content": "{invalid} из {total} записей кэша недействительны ({rate}). Это может привести к отсутствию моделей или ошибкам. Рекомендуется перестроить кэш.",
|
||||||
|
"rebuildCache": "Перестроить кэш",
|
||||||
|
"dismiss": "Отклонить",
|
||||||
|
"rebuilding": "Перестроение кэша...",
|
||||||
|
"rebuildFailed": "Не удалось перестроить кэш: {error}",
|
||||||
|
"retry": "Повторить"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1572,6 +1572,20 @@
|
|||||||
"content": "来爱发电为Lora Manager项目发电,支持项目持续开发的同时,获取浏览器插件验证码,按季支付更优惠!支付宝/微信方便支付。感谢支持!🚀",
|
"content": "来爱发电为Lora Manager项目发电,支持项目持续开发的同时,获取浏览器插件验证码,按季支付更优惠!支付宝/微信方便支付。感谢支持!🚀",
|
||||||
"supportCta": "为LM发电",
|
"supportCta": "为LM发电",
|
||||||
"learnMore": "浏览器插件教程"
|
"learnMore": "浏览器插件教程"
|
||||||
|
},
|
||||||
|
"cacheHealth": {
|
||||||
|
"corrupted": {
|
||||||
|
"title": "检测到缓存损坏"
|
||||||
|
},
|
||||||
|
"degraded": {
|
||||||
|
"title": "检测到缓存问题"
|
||||||
|
},
|
||||||
|
"content": "{total} 个缓存条目中有 {invalid} 个无效({rate})。这可能导致模型丢失或错误。建议重建缓存。",
|
||||||
|
"rebuildCache": "重建缓存",
|
||||||
|
"dismiss": "忽略",
|
||||||
|
"rebuilding": "正在重建缓存...",
|
||||||
|
"rebuildFailed": "重建缓存失败:{error}",
|
||||||
|
"retry": "重试"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1572,6 +1572,20 @@
|
|||||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||||
"supportCta": "Support on Ko-fi",
|
"supportCta": "Support on Ko-fi",
|
||||||
"learnMore": "LM Civitai Extension Tutorial"
|
"learnMore": "LM Civitai Extension Tutorial"
|
||||||
|
},
|
||||||
|
"cacheHealth": {
|
||||||
|
"corrupted": {
|
||||||
|
"title": "檢測到快取損壞"
|
||||||
|
},
|
||||||
|
"degraded": {
|
||||||
|
"title": "檢測到快取問題"
|
||||||
|
},
|
||||||
|
"content": "{total} 個快取項目中有 {invalid} 個無效({rate})。這可能會導致模型遺失或錯誤。建議重建快取。",
|
||||||
|
"rebuildCache": "重建快取",
|
||||||
|
"dismiss": "關閉",
|
||||||
|
"rebuilding": "重建快取中...",
|
||||||
|
"rebuildFailed": "重建快取失敗:{error}",
|
||||||
|
"retry": "重試"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,9 @@
|
|||||||
"private": true,
|
"private": true,
|
||||||
"type": "module",
|
"type": "module",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"test": "vitest run",
|
"test": "npm run test:js && npm run test:vue",
|
||||||
|
"test:js": "vitest run",
|
||||||
|
"test:vue": "cd vue-widgets && npx vitest run",
|
||||||
"test:watch": "vitest",
|
"test:watch": "vitest",
|
||||||
"test:coverage": "node scripts/run_frontend_coverage.js"
|
"test:coverage": "node scripts/run_frontend_coverage.js"
|
||||||
},
|
},
|
||||||
|
|||||||
93
py/config.py
93
py/config.py
@@ -441,82 +441,53 @@ class Config:
|
|||||||
logger.info("Failed to write symlink cache %s: %s", cache_path, exc)
|
logger.info("Failed to write symlink cache %s: %s", cache_path, exc)
|
||||||
|
|
||||||
def _scan_symbolic_links(self):
|
def _scan_symbolic_links(self):
|
||||||
"""Scan all symbolic links in LoRA, Checkpoint, and Embedding root directories"""
|
"""Scan symbolic links in LoRA, Checkpoint, and Embedding root directories.
|
||||||
|
|
||||||
|
Only scans the first level of each root directory to avoid performance
|
||||||
|
issues with large file systems. Detects symlinks and Windows junctions
|
||||||
|
at the root level only (not nested symlinks in subdirectories).
|
||||||
|
"""
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
|
|
||||||
# Reset mappings before rescanning to avoid stale entries
|
# Reset mappings before rescanning to avoid stale entries
|
||||||
self._path_mappings.clear()
|
self._path_mappings.clear()
|
||||||
self._seed_root_symlink_mappings()
|
self._seed_root_symlink_mappings()
|
||||||
visited_dirs: Set[str] = set()
|
|
||||||
for root in self._symlink_roots():
|
for root in self._symlink_roots():
|
||||||
self._scan_directory_links(root, visited_dirs)
|
self._scan_first_level_symlinks(root)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Symlink scan finished in %.2f ms with %d mappings",
|
"Symlink scan finished in %.2f ms with %d mappings",
|
||||||
(time.perf_counter() - start) * 1000,
|
(time.perf_counter() - start) * 1000,
|
||||||
len(self._path_mappings),
|
len(self._path_mappings),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _scan_directory_links(self, root: str, visited_dirs: Set[str]):
|
def _scan_first_level_symlinks(self, root: str):
|
||||||
"""Iteratively scan directory symlinks to avoid deep recursion."""
|
"""Scan only the first level of a directory for symlinks.
|
||||||
|
|
||||||
|
This avoids traversing the entire directory tree which can be extremely
|
||||||
|
slow for large model collections. Only symlinks directly under the root
|
||||||
|
are detected.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# Note: We only use realpath for the initial root if it's not already resolved
|
with os.scandir(root) as it:
|
||||||
# to ensure we have a valid entry point.
|
for entry in it:
|
||||||
root_real = self._normalize_path(os.path.realpath(root))
|
try:
|
||||||
except OSError:
|
# Only detect symlinks including Windows junctions
|
||||||
root_real = self._normalize_path(root)
|
# Skip normal directories to avoid deep traversal
|
||||||
|
if not self._entry_is_symlink(entry):
|
||||||
|
continue
|
||||||
|
|
||||||
if root_real in visited_dirs:
|
# Resolve the symlink target
|
||||||
return
|
target_path = os.path.realpath(entry.path)
|
||||||
|
if not os.path.isdir(target_path):
|
||||||
|
continue
|
||||||
|
|
||||||
visited_dirs.add(root_real)
|
self.add_path_mapping(entry.path, target_path)
|
||||||
# Stack entries: (display_path, real_resolved_path)
|
except Exception as inner_exc:
|
||||||
stack: List[Tuple[str, str]] = [(root, root_real)]
|
logger.debug(
|
||||||
|
"Error processing directory entry %s: %s", entry.path, inner_exc
|
||||||
while stack:
|
)
|
||||||
current_display, current_real = stack.pop()
|
except Exception as e:
|
||||||
try:
|
logger.error(f"Error scanning links in {root}: {e}")
|
||||||
with os.scandir(current_display) as it:
|
|
||||||
for entry in it:
|
|
||||||
try:
|
|
||||||
# 1. Detect symlinks including Windows junctions
|
|
||||||
is_link = self._entry_is_symlink(entry)
|
|
||||||
|
|
||||||
if is_link:
|
|
||||||
# Only resolve realpath when we actually find a link
|
|
||||||
target_path = os.path.realpath(entry.path)
|
|
||||||
if not os.path.isdir(target_path):
|
|
||||||
continue
|
|
||||||
|
|
||||||
normalized_target = self._normalize_path(target_path)
|
|
||||||
self.add_path_mapping(entry.path, target_path)
|
|
||||||
|
|
||||||
if normalized_target in visited_dirs:
|
|
||||||
continue
|
|
||||||
|
|
||||||
visited_dirs.add(normalized_target)
|
|
||||||
stack.append((target_path, normalized_target))
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 2. Process normal directories
|
|
||||||
if not entry.is_dir(follow_symlinks=False):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# For normal directories, we avoid realpath() call by
|
|
||||||
# incrementally building the real path relative to current_real.
|
|
||||||
# This is safe because 'entry' is NOT a symlink.
|
|
||||||
entry_real = self._normalize_path(os.path.join(current_real, entry.name))
|
|
||||||
|
|
||||||
if entry_real in visited_dirs:
|
|
||||||
continue
|
|
||||||
|
|
||||||
visited_dirs.add(entry_real)
|
|
||||||
stack.append((entry.path, entry_real))
|
|
||||||
except Exception as inner_exc:
|
|
||||||
logger.debug(
|
|
||||||
"Error processing directory entry %s: %s", entry.path, inner_exc
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error scanning links in {current_display}: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Check if running in standalone mode
|
# Check if running in standalone mode
|
||||||
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||||
@@ -14,7 +17,7 @@ if not standalone_mode:
|
|||||||
# Initialize registry
|
# Initialize registry
|
||||||
registry = MetadataRegistry()
|
registry = MetadataRegistry()
|
||||||
|
|
||||||
print("ComfyUI Metadata Collector initialized")
|
logger.info("ComfyUI Metadata Collector initialized")
|
||||||
|
|
||||||
def get_metadata(prompt_id=None):
|
def get_metadata(prompt_id=None):
|
||||||
"""Helper function to get metadata from the registry"""
|
"""Helper function to get metadata from the registry"""
|
||||||
@@ -23,7 +26,7 @@ if not standalone_mode:
|
|||||||
else:
|
else:
|
||||||
# Standalone mode - provide dummy implementations
|
# Standalone mode - provide dummy implementations
|
||||||
def init():
|
def init():
|
||||||
print("ComfyUI Metadata Collector disabled in standalone mode")
|
logger.info("ComfyUI Metadata Collector disabled in standalone mode")
|
||||||
|
|
||||||
def get_metadata(prompt_id=None):
|
def get_metadata(prompt_id=None):
|
||||||
"""Dummy implementation for standalone mode"""
|
"""Dummy implementation for standalone mode"""
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
import sys
|
import sys
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
from .metadata_registry import MetadataRegistry
|
from .metadata_registry import MetadataRegistry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class MetadataHook:
|
class MetadataHook:
|
||||||
"""Install hooks for metadata collection"""
|
"""Install hooks for metadata collection"""
|
||||||
|
|
||||||
@@ -23,7 +26,7 @@ class MetadataHook:
|
|||||||
|
|
||||||
# If we can't find the execution module, we can't install hooks
|
# If we can't find the execution module, we can't install hooks
|
||||||
if execution is None:
|
if execution is None:
|
||||||
print("Could not locate ComfyUI execution module, metadata collection disabled")
|
logger.warning("Could not locate ComfyUI execution module, metadata collection disabled")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Detect whether we're using the new async version of ComfyUI
|
# Detect whether we're using the new async version of ComfyUI
|
||||||
@@ -37,16 +40,16 @@ class MetadataHook:
|
|||||||
is_async = inspect.iscoroutinefunction(execution._map_node_over_list)
|
is_async = inspect.iscoroutinefunction(execution._map_node_over_list)
|
||||||
|
|
||||||
if is_async:
|
if is_async:
|
||||||
print("Detected async ComfyUI execution, installing async metadata hooks")
|
logger.info("Detected async ComfyUI execution, installing async metadata hooks")
|
||||||
MetadataHook._install_async_hooks(execution, map_node_func_name)
|
MetadataHook._install_async_hooks(execution, map_node_func_name)
|
||||||
else:
|
else:
|
||||||
print("Detected sync ComfyUI execution, installing sync metadata hooks")
|
logger.info("Detected sync ComfyUI execution, installing sync metadata hooks")
|
||||||
MetadataHook._install_sync_hooks(execution)
|
MetadataHook._install_sync_hooks(execution)
|
||||||
|
|
||||||
print("Metadata collection hooks installed for runtime values")
|
logger.info("Metadata collection hooks installed for runtime values")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error installing metadata hooks: {str(e)}")
|
logger.error(f"Error installing metadata hooks: {str(e)}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _install_sync_hooks(execution):
|
def _install_sync_hooks(execution):
|
||||||
@@ -82,7 +85,7 @@ class MetadataHook:
|
|||||||
if node_id is not None:
|
if node_id is not None:
|
||||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
logger.error(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||||
|
|
||||||
# Execute the original function
|
# Execute the original function
|
||||||
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
|
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
|
||||||
@@ -113,7 +116,7 @@ class MetadataHook:
|
|||||||
if node_id is not None:
|
if node_id is not None:
|
||||||
registry.update_node_execution(node_id, class_type, results)
|
registry.update_node_execution(node_id, class_type, results)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
logger.error(f"Error collecting metadata (post-execution): {str(e)}")
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -159,7 +162,7 @@ class MetadataHook:
|
|||||||
if node_id is not None:
|
if node_id is not None:
|
||||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
logger.error(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||||
|
|
||||||
# Call original function with all args/kwargs
|
# Call original function with all args/kwargs
|
||||||
results = await original_map_node_over_list(
|
results = await original_map_node_over_list(
|
||||||
@@ -176,7 +179,7 @@ class MetadataHook:
|
|||||||
if node_id is not None:
|
if node_id is not None:
|
||||||
registry.update_node_execution(node_id, class_type, results)
|
registry.update_node_execution(node_id, class_type, results)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
logger.error(f"Error collecting metadata (post-execution): {str(e)}")
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
@@ -126,9 +126,7 @@ class LoraCyclerLM:
|
|||||||
"current_index": [clamped_index],
|
"current_index": [clamped_index],
|
||||||
"next_index": [next_index],
|
"next_index": [next_index],
|
||||||
"total_count": [total_count],
|
"total_count": [total_count],
|
||||||
"current_lora_name": [
|
"current_lora_name": [current_lora["file_name"]],
|
||||||
current_lora.get("model_name", current_lora["file_name"])
|
|
||||||
],
|
|
||||||
"current_lora_filename": [current_lora["file_name"]],
|
"current_lora_filename": [current_lora["file_name"]],
|
||||||
"next_lora_name": [next_display_name],
|
"next_lora_name": [next_display_name],
|
||||||
"next_lora_filename": [next_lora["file_name"]],
|
"next_lora_filename": [next_lora["file_name"]],
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ from ..metadata_collector.metadata_processor import MetadataProcessor
|
|||||||
from ..metadata_collector import get_metadata
|
from ..metadata_collector import get_metadata
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
import piexif
|
import piexif
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class SaveImageLM:
|
class SaveImageLM:
|
||||||
NAME = "Save Image (LoraManager)"
|
NAME = "Save Image (LoraManager)"
|
||||||
@@ -385,7 +388,7 @@ class SaveImageLM:
|
|||||||
exif_bytes = piexif.dump(exif_dict)
|
exif_bytes = piexif.dump(exif_dict)
|
||||||
save_kwargs["exif"] = exif_bytes
|
save_kwargs["exif"] = exif_bytes
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error adding EXIF data: {e}")
|
logger.error(f"Error adding EXIF data: {e}")
|
||||||
img.save(file_path, format="JPEG", **save_kwargs)
|
img.save(file_path, format="JPEG", **save_kwargs)
|
||||||
elif file_format == "webp":
|
elif file_format == "webp":
|
||||||
try:
|
try:
|
||||||
@@ -403,7 +406,7 @@ class SaveImageLM:
|
|||||||
exif_bytes = piexif.dump(exif_dict)
|
exif_bytes = piexif.dump(exif_dict)
|
||||||
save_kwargs["exif"] = exif_bytes
|
save_kwargs["exif"] = exif_bytes
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error adding EXIF data: {e}")
|
logger.error(f"Error adding EXIF data: {e}")
|
||||||
|
|
||||||
img.save(file_path, format="WEBP", **save_kwargs)
|
img.save(file_path, format="WEBP", **save_kwargs)
|
||||||
|
|
||||||
@@ -414,7 +417,7 @@ class SaveImageLM:
|
|||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error saving image: {e}")
|
logger.error(f"Error saving image: {e}")
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
|||||||
RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"),
|
RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"),
|
||||||
RouteDefinition("POST", "/api/lm/cleanup-example-image-folders", "cleanup_example_image_folders"),
|
RouteDefinition("POST", "/api/lm/cleanup-example-image-folders", "cleanup_example_image_folders"),
|
||||||
RouteDefinition("POST", "/api/lm/example-images/set-nsfw-level", "set_example_image_nsfw_level"),
|
RouteDefinition("POST", "/api/lm/example-images/set-nsfw-level", "set_example_image_nsfw_level"),
|
||||||
|
RouteDefinition("POST", "/api/lm/check-example-images-needed", "check_example_images_needed"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -92,6 +92,19 @@ class ExampleImagesDownloadHandler:
|
|||||||
except ExampleImagesDownloadError as exc:
|
except ExampleImagesDownloadError as exc:
|
||||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||||
|
|
||||||
|
async def check_example_images_needed(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
"""Lightweight check to see if any models need example images downloaded."""
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
model_types = payload.get('model_types', ['lora', 'checkpoint', 'embedding'])
|
||||||
|
result = await self._download_manager.check_pending_models(model_types)
|
||||||
|
return web.json_response(result)
|
||||||
|
except Exception as exc:
|
||||||
|
return web.json_response(
|
||||||
|
{'success': False, 'error': str(exc)},
|
||||||
|
status=500
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExampleImagesManagementHandler:
|
class ExampleImagesManagementHandler:
|
||||||
"""HTTP adapters for import/delete endpoints."""
|
"""HTTP adapters for import/delete endpoints."""
|
||||||
@@ -161,6 +174,7 @@ class ExampleImagesHandlerSet:
|
|||||||
"resume_example_images": self.download.resume_example_images,
|
"resume_example_images": self.download.resume_example_images,
|
||||||
"stop_example_images": self.download.stop_example_images,
|
"stop_example_images": self.download.stop_example_images,
|
||||||
"force_download_example_images": self.download.force_download_example_images,
|
"force_download_example_images": self.download.force_download_example_images,
|
||||||
|
"check_example_images_needed": self.download.check_example_images_needed,
|
||||||
"import_example_images": self.management.import_example_images,
|
"import_example_images": self.management.import_example_images,
|
||||||
"delete_example_image": self.management.delete_example_image,
|
"delete_example_image": self.management.delete_example_image,
|
||||||
"set_example_image_nsfw_level": self.management.set_example_image_nsfw_level,
|
"set_example_image_nsfw_level": self.management.set_example_image_nsfw_level,
|
||||||
|
|||||||
259
py/services/cache_entry_validator.py
Normal file
259
py/services/cache_entry_validator.py
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
"""
|
||||||
|
Cache Entry Validator
|
||||||
|
|
||||||
|
Validates and repairs cache entries to prevent runtime errors from
|
||||||
|
missing or invalid critical fields.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ValidationResult:
|
||||||
|
"""Result of validating a single cache entry."""
|
||||||
|
is_valid: bool
|
||||||
|
repaired: bool
|
||||||
|
errors: List[str] = field(default_factory=list)
|
||||||
|
entry: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CacheEntryValidator:
|
||||||
|
"""
|
||||||
|
Validates and repairs cache entry core fields.
|
||||||
|
|
||||||
|
Critical fields that cause runtime errors when missing:
|
||||||
|
- file_path: KeyError in multiple locations
|
||||||
|
- sha256: KeyError/AttributeError in hash operations
|
||||||
|
|
||||||
|
Medium severity fields that may cause sorting/display issues:
|
||||||
|
- size: KeyError during sorting
|
||||||
|
- modified: KeyError during sorting
|
||||||
|
- model_name: AttributeError on .lower() calls
|
||||||
|
|
||||||
|
Low severity fields:
|
||||||
|
- tags: KeyError/TypeError in recipe operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Field definitions: (default_value, is_required)
|
||||||
|
CORE_FIELDS: Dict[str, Tuple[Any, bool]] = {
|
||||||
|
'file_path': ('', True),
|
||||||
|
'sha256': ('', True),
|
||||||
|
'file_name': ('', False),
|
||||||
|
'model_name': ('', False),
|
||||||
|
'folder': ('', False),
|
||||||
|
'size': (0, False),
|
||||||
|
'modified': (0.0, False),
|
||||||
|
'tags': ([], False),
|
||||||
|
'preview_url': ('', False),
|
||||||
|
'base_model': ('', False),
|
||||||
|
'from_civitai': (True, False),
|
||||||
|
'favorite': (False, False),
|
||||||
|
'exclude': (False, False),
|
||||||
|
'db_checked': (False, False),
|
||||||
|
'preview_nsfw_level': (0, False),
|
||||||
|
'notes': ('', False),
|
||||||
|
'usage_tips': ('', False),
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, entry: Dict[str, Any], *, auto_repair: bool = True) -> ValidationResult:
|
||||||
|
"""
|
||||||
|
Validate a single cache entry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entry: The cache entry dictionary to validate
|
||||||
|
auto_repair: If True, attempt to repair missing/invalid fields
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ValidationResult with validation status and optionally repaired entry
|
||||||
|
"""
|
||||||
|
if entry is None:
|
||||||
|
return ValidationResult(
|
||||||
|
is_valid=False,
|
||||||
|
repaired=False,
|
||||||
|
errors=['Entry is None'],
|
||||||
|
entry=None
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
return ValidationResult(
|
||||||
|
is_valid=False,
|
||||||
|
repaired=False,
|
||||||
|
errors=[f'Entry is not a dict: {type(entry).__name__}'],
|
||||||
|
entry=None
|
||||||
|
)
|
||||||
|
|
||||||
|
errors: List[str] = []
|
||||||
|
repaired = False
|
||||||
|
working_entry = dict(entry) if auto_repair else entry
|
||||||
|
|
||||||
|
for field_name, (default_value, is_required) in cls.CORE_FIELDS.items():
|
||||||
|
value = working_entry.get(field_name)
|
||||||
|
|
||||||
|
# Check if field is missing or None
|
||||||
|
if value is None:
|
||||||
|
if is_required:
|
||||||
|
errors.append(f"Required field '{field_name}' is missing or None")
|
||||||
|
if auto_repair:
|
||||||
|
working_entry[field_name] = cls._get_default_copy(default_value)
|
||||||
|
repaired = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Validate field type and value
|
||||||
|
field_error = cls._validate_field(field_name, value, default_value)
|
||||||
|
if field_error:
|
||||||
|
errors.append(field_error)
|
||||||
|
if auto_repair:
|
||||||
|
working_entry[field_name] = cls._get_default_copy(default_value)
|
||||||
|
repaired = True
|
||||||
|
|
||||||
|
# Special validation: file_path must not be empty for required field
|
||||||
|
file_path = working_entry.get('file_path', '')
|
||||||
|
if not file_path or (isinstance(file_path, str) and not file_path.strip()):
|
||||||
|
errors.append("Required field 'file_path' is empty")
|
||||||
|
# Cannot repair empty file_path - entry is invalid
|
||||||
|
return ValidationResult(
|
||||||
|
is_valid=False,
|
||||||
|
repaired=repaired,
|
||||||
|
errors=errors,
|
||||||
|
entry=working_entry if auto_repair else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Special validation: sha256 must not be empty for required field
|
||||||
|
sha256 = working_entry.get('sha256', '')
|
||||||
|
if not sha256 or (isinstance(sha256, str) and not sha256.strip()):
|
||||||
|
errors.append("Required field 'sha256' is empty")
|
||||||
|
# Cannot repair empty sha256 - entry is invalid
|
||||||
|
return ValidationResult(
|
||||||
|
is_valid=False,
|
||||||
|
repaired=repaired,
|
||||||
|
errors=errors,
|
||||||
|
entry=working_entry if auto_repair else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normalize sha256 to lowercase if needed
|
||||||
|
if isinstance(sha256, str):
|
||||||
|
normalized_sha = sha256.lower().strip()
|
||||||
|
if normalized_sha != sha256:
|
||||||
|
working_entry['sha256'] = normalized_sha
|
||||||
|
repaired = True
|
||||||
|
|
||||||
|
# Determine if entry is valid
|
||||||
|
# Entry is valid if no critical required field errors remain after repair
|
||||||
|
# Critical fields are file_path and sha256
|
||||||
|
CRITICAL_REQUIRED_FIELDS = {'file_path', 'sha256'}
|
||||||
|
has_critical_errors = any(
|
||||||
|
"Required field" in error and
|
||||||
|
any(f"'{field}'" in error for field in CRITICAL_REQUIRED_FIELDS)
|
||||||
|
for error in errors
|
||||||
|
)
|
||||||
|
|
||||||
|
is_valid = not has_critical_errors
|
||||||
|
|
||||||
|
return ValidationResult(
|
||||||
|
is_valid=is_valid,
|
||||||
|
repaired=repaired,
|
||||||
|
errors=errors,
|
||||||
|
entry=working_entry if auto_repair else entry
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_batch(
|
||||||
|
cls,
|
||||||
|
entries: List[Dict[str, Any]],
|
||||||
|
*,
|
||||||
|
auto_repair: bool = True
|
||||||
|
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Validate a batch of cache entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entries: List of cache entry dictionaries to validate
|
||||||
|
auto_repair: If True, attempt to repair missing/invalid fields
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (valid_entries, invalid_entries)
|
||||||
|
"""
|
||||||
|
if not entries:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
valid_entries: List[Dict[str, Any]] = []
|
||||||
|
invalid_entries: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
result = cls.validate(entry, auto_repair=auto_repair)
|
||||||
|
|
||||||
|
if result.is_valid:
|
||||||
|
# Use repaired entry if available, otherwise original
|
||||||
|
valid_entries.append(result.entry if result.entry else entry)
|
||||||
|
else:
|
||||||
|
invalid_entries.append(entry)
|
||||||
|
# Log invalid entries for debugging
|
||||||
|
file_path = entry.get('file_path', '<unknown>') if isinstance(entry, dict) else '<not a dict>'
|
||||||
|
logger.warning(
|
||||||
|
f"Invalid cache entry for '{file_path}': {', '.join(result.errors)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return valid_entries, invalid_entries
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_field(cls, field_name: str, value: Any, default_value: Any) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Validate a specific field value.
|
||||||
|
|
||||||
|
Returns an error message if invalid, None if valid.
|
||||||
|
"""
|
||||||
|
expected_type = type(default_value)
|
||||||
|
|
||||||
|
# Special handling for numeric types
|
||||||
|
if expected_type == int:
|
||||||
|
if not isinstance(value, (int, float)):
|
||||||
|
return f"Field '{field_name}' should be numeric, got {type(value).__name__}"
|
||||||
|
elif expected_type == float:
|
||||||
|
if not isinstance(value, (int, float)):
|
||||||
|
return f"Field '{field_name}' should be numeric, got {type(value).__name__}"
|
||||||
|
elif expected_type == bool:
|
||||||
|
# Be lenient with boolean fields - accept truthy/falsy values
|
||||||
|
pass
|
||||||
|
elif expected_type == str:
|
||||||
|
if not isinstance(value, str):
|
||||||
|
return f"Field '{field_name}' should be string, got {type(value).__name__}"
|
||||||
|
elif expected_type == list:
|
||||||
|
if not isinstance(value, (list, tuple)):
|
||||||
|
return f"Field '{field_name}' should be list, got {type(value).__name__}"
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_default_copy(cls, default_value: Any) -> Any:
|
||||||
|
"""Get a copy of the default value to avoid shared mutable state."""
|
||||||
|
if isinstance(default_value, list):
|
||||||
|
return list(default_value)
|
||||||
|
if isinstance(default_value, dict):
|
||||||
|
return dict(default_value)
|
||||||
|
return default_value
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_file_path_safe(cls, entry: Dict[str, Any], default: str = '') -> str:
|
||||||
|
"""Safely get file_path from an entry."""
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
return default
|
||||||
|
value = entry.get('file_path')
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value
|
||||||
|
return default
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_sha256_safe(cls, entry: Dict[str, Any], default: str = '') -> str:
|
||||||
|
"""Safely get sha256 from an entry."""
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
return default
|
||||||
|
value = entry.get('sha256')
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value.lower()
|
||||||
|
return default
|
||||||
201
py/services/cache_health_monitor.py
Normal file
201
py/services/cache_health_monitor.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""
|
||||||
|
Cache Health Monitor
|
||||||
|
|
||||||
|
Monitors cache health status and determines when user intervention is needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from .cache_entry_validator import CacheEntryValidator, ValidationResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CacheHealthStatus(Enum):
|
||||||
|
"""Health status of the cache."""
|
||||||
|
HEALTHY = "healthy"
|
||||||
|
DEGRADED = "degraded"
|
||||||
|
CORRUPTED = "corrupted"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HealthReport:
|
||||||
|
"""Report of cache health check."""
|
||||||
|
status: CacheHealthStatus
|
||||||
|
total_entries: int
|
||||||
|
valid_entries: int
|
||||||
|
invalid_entries: int
|
||||||
|
repaired_entries: int
|
||||||
|
invalid_paths: List[str] = field(default_factory=list)
|
||||||
|
message: str = ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def corruption_rate(self) -> float:
|
||||||
|
"""Calculate the percentage of invalid entries."""
|
||||||
|
if self.total_entries <= 0:
|
||||||
|
return 0.0
|
||||||
|
return self.invalid_entries / self.total_entries
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for JSON serialization."""
|
||||||
|
return {
|
||||||
|
'status': self.status.value,
|
||||||
|
'total_entries': self.total_entries,
|
||||||
|
'valid_entries': self.valid_entries,
|
||||||
|
'invalid_entries': self.invalid_entries,
|
||||||
|
'repaired_entries': self.repaired_entries,
|
||||||
|
'corruption_rate': f"{self.corruption_rate:.1%}",
|
||||||
|
'invalid_paths': self.invalid_paths[:10], # Limit to first 10
|
||||||
|
'message': self.message,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class CacheHealthMonitor:
|
||||||
|
"""
|
||||||
|
Monitors cache health and determines appropriate status.
|
||||||
|
|
||||||
|
Thresholds:
|
||||||
|
- HEALTHY: 0% invalid entries
|
||||||
|
- DEGRADED: 0-5% invalid entries (auto-repaired, user should rebuild)
|
||||||
|
- CORRUPTED: >5% invalid entries (significant data loss likely)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Threshold percentages
|
||||||
|
DEGRADED_THRESHOLD = 0.01 # 1% - show warning
|
||||||
|
CORRUPTED_THRESHOLD = 0.05 # 5% - critical warning
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
degraded_threshold: float = DEGRADED_THRESHOLD,
|
||||||
|
corrupted_threshold: float = CORRUPTED_THRESHOLD
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the health monitor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
degraded_threshold: Corruption rate threshold for DEGRADED status
|
||||||
|
corrupted_threshold: Corruption rate threshold for CORRUPTED status
|
||||||
|
"""
|
||||||
|
self.degraded_threshold = degraded_threshold
|
||||||
|
self.corrupted_threshold = corrupted_threshold
|
||||||
|
|
||||||
|
def check_health(
|
||||||
|
self,
|
||||||
|
entries: List[Dict[str, Any]],
|
||||||
|
*,
|
||||||
|
auto_repair: bool = True
|
||||||
|
) -> HealthReport:
|
||||||
|
"""
|
||||||
|
Check the health of cache entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entries: List of cache entry dictionaries to check
|
||||||
|
auto_repair: If True, attempt to repair entries during validation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HealthReport with status and statistics
|
||||||
|
"""
|
||||||
|
if not entries:
|
||||||
|
return HealthReport(
|
||||||
|
status=CacheHealthStatus.HEALTHY,
|
||||||
|
total_entries=0,
|
||||||
|
valid_entries=0,
|
||||||
|
invalid_entries=0,
|
||||||
|
repaired_entries=0,
|
||||||
|
message="Cache is empty"
|
||||||
|
)
|
||||||
|
|
||||||
|
total_entries = len(entries)
|
||||||
|
valid_entries: List[Dict[str, Any]] = []
|
||||||
|
invalid_entries: List[Dict[str, Any]] = []
|
||||||
|
repaired_count = 0
|
||||||
|
invalid_paths: List[str] = []
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
result = CacheEntryValidator.validate(entry, auto_repair=auto_repair)
|
||||||
|
|
||||||
|
if result.is_valid:
|
||||||
|
valid_entries.append(result.entry if result.entry else entry)
|
||||||
|
if result.repaired:
|
||||||
|
repaired_count += 1
|
||||||
|
else:
|
||||||
|
invalid_entries.append(entry)
|
||||||
|
# Extract file path for reporting
|
||||||
|
file_path = CacheEntryValidator.get_file_path_safe(entry, '<unknown>')
|
||||||
|
invalid_paths.append(file_path)
|
||||||
|
|
||||||
|
invalid_count = len(invalid_entries)
|
||||||
|
valid_count = len(valid_entries)
|
||||||
|
|
||||||
|
# Determine status based on corruption rate
|
||||||
|
corruption_rate = invalid_count / total_entries if total_entries > 0 else 0.0
|
||||||
|
|
||||||
|
if invalid_count == 0:
|
||||||
|
status = CacheHealthStatus.HEALTHY
|
||||||
|
message = "Cache is healthy"
|
||||||
|
elif corruption_rate >= self.corrupted_threshold:
|
||||||
|
status = CacheHealthStatus.CORRUPTED
|
||||||
|
message = (
|
||||||
|
f"Cache is corrupted: {invalid_count} invalid entries "
|
||||||
|
f"({corruption_rate:.1%}). Rebuild recommended."
|
||||||
|
)
|
||||||
|
elif corruption_rate >= self.degraded_threshold or invalid_count > 0:
|
||||||
|
status = CacheHealthStatus.DEGRADED
|
||||||
|
message = (
|
||||||
|
f"Cache has {invalid_count} invalid entries "
|
||||||
|
f"({corruption_rate:.1%}). Consider rebuilding cache."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# This shouldn't happen, but handle gracefully
|
||||||
|
status = CacheHealthStatus.HEALTHY
|
||||||
|
message = "Cache is healthy"
|
||||||
|
|
||||||
|
# Log the health check result
|
||||||
|
if status != CacheHealthStatus.HEALTHY:
|
||||||
|
logger.warning(
|
||||||
|
f"Cache health check: {status.value} - "
|
||||||
|
f"{invalid_count}/{total_entries} invalid, "
|
||||||
|
f"{repaired_count} repaired"
|
||||||
|
)
|
||||||
|
if invalid_paths:
|
||||||
|
logger.debug(f"Invalid entry paths: {invalid_paths[:5]}")
|
||||||
|
|
||||||
|
return HealthReport(
|
||||||
|
status=status,
|
||||||
|
total_entries=total_entries,
|
||||||
|
valid_entries=valid_count,
|
||||||
|
invalid_entries=invalid_count,
|
||||||
|
repaired_entries=repaired_count,
|
||||||
|
invalid_paths=invalid_paths,
|
||||||
|
message=message
|
||||||
|
)
|
||||||
|
|
||||||
|
def should_notify_user(self, report: HealthReport) -> bool:
|
||||||
|
"""
|
||||||
|
Determine if the user should be notified about cache health.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
report: The health report to evaluate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if user should be notified
|
||||||
|
"""
|
||||||
|
return report.status != CacheHealthStatus.HEALTHY
|
||||||
|
|
||||||
|
def get_notification_severity(self, report: HealthReport) -> str:
|
||||||
|
"""
|
||||||
|
Get the severity level for user notification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
report: The health report to evaluate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Severity string: 'warning' or 'error'
|
||||||
|
"""
|
||||||
|
if report.status == CacheHealthStatus.CORRUPTED:
|
||||||
|
return 'error'
|
||||||
|
return 'warning'
|
||||||
@@ -30,36 +30,36 @@ class LoraScanner(ModelScanner):
|
|||||||
|
|
||||||
async def diagnose_hash_index(self):
|
async def diagnose_hash_index(self):
|
||||||
"""Diagnostic method to verify hash index functionality"""
|
"""Diagnostic method to verify hash index functionality"""
|
||||||
print("\n\n*** DIAGNOSING LORA HASH INDEX ***\n\n", file=sys.stderr)
|
logger.debug("\n\n*** DIAGNOSING LORA HASH INDEX ***\n\n")
|
||||||
|
|
||||||
# First check if the hash index has any entries
|
# First check if the hash index has any entries
|
||||||
if hasattr(self, '_hash_index'):
|
if hasattr(self, '_hash_index'):
|
||||||
index_entries = len(self._hash_index._hash_to_path)
|
index_entries = len(self._hash_index._hash_to_path)
|
||||||
print(f"Hash index has {index_entries} entries", file=sys.stderr)
|
logger.debug(f"Hash index has {index_entries} entries")
|
||||||
|
|
||||||
# Print a few example entries if available
|
# Print a few example entries if available
|
||||||
if index_entries > 0:
|
if index_entries > 0:
|
||||||
print("\nSample hash index entries:", file=sys.stderr)
|
logger.debug("\nSample hash index entries:")
|
||||||
count = 0
|
count = 0
|
||||||
for hash_val, path in self._hash_index._hash_to_path.items():
|
for hash_val, path in self._hash_index._hash_to_path.items():
|
||||||
if count < 5: # Just show the first 5
|
if count < 5: # Just show the first 5
|
||||||
print(f"Hash: {hash_val[:8]}... -> Path: {path}", file=sys.stderr)
|
logger.debug(f"Hash: {hash_val[:8]}... -> Path: {path}")
|
||||||
count += 1
|
count += 1
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
print("Hash index not initialized", file=sys.stderr)
|
logger.debug("Hash index not initialized")
|
||||||
|
|
||||||
# Try looking up by a known hash for testing
|
# Try looking up by a known hash for testing
|
||||||
if not hasattr(self, '_hash_index') or not self._hash_index._hash_to_path:
|
if not hasattr(self, '_hash_index') or not self._hash_index._hash_to_path:
|
||||||
print("No hash entries to test lookup with", file=sys.stderr)
|
logger.debug("No hash entries to test lookup with")
|
||||||
return
|
return
|
||||||
|
|
||||||
test_hash = next(iter(self._hash_index._hash_to_path.keys()))
|
test_hash = next(iter(self._hash_index._hash_to_path.keys()))
|
||||||
test_path = self._hash_index.get_path(test_hash)
|
test_path = self._hash_index.get_path(test_hash)
|
||||||
print(f"\nTest lookup by hash: {test_hash[:8]}... -> {test_path}", file=sys.stderr)
|
logger.debug(f"\nTest lookup by hash: {test_hash[:8]}... -> {test_path}")
|
||||||
|
|
||||||
# Also test reverse lookup
|
# Also test reverse lookup
|
||||||
test_hash_result = self._hash_index.get_hash(test_path)
|
test_hash_result = self._hash_index.get_hash(test_path)
|
||||||
print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)
|
logger.debug(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n")
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from operator import itemgetter
|
|
||||||
from natsort import natsorted
|
from natsort import natsorted
|
||||||
|
|
||||||
# Supported sort modes: (sort_key, order)
|
# Supported sort modes: (sort_key, order)
|
||||||
@@ -229,17 +228,17 @@ class ModelCache:
|
|||||||
reverse=reverse
|
reverse=reverse
|
||||||
)
|
)
|
||||||
elif sort_key == 'date':
|
elif sort_key == 'date':
|
||||||
# Sort by modified timestamp
|
# Sort by modified timestamp (use .get() with default to handle missing fields)
|
||||||
result = sorted(
|
result = sorted(
|
||||||
data,
|
data,
|
||||||
key=itemgetter('modified'),
|
key=lambda x: x.get('modified', 0.0),
|
||||||
reverse=reverse
|
reverse=reverse
|
||||||
)
|
)
|
||||||
elif sort_key == 'size':
|
elif sort_key == 'size':
|
||||||
# Sort by file size
|
# Sort by file size (use .get() with default to handle missing fields)
|
||||||
result = sorted(
|
result = sorted(
|
||||||
data,
|
data,
|
||||||
key=itemgetter('size'),
|
key=lambda x: x.get('size', 0),
|
||||||
reverse=reverse
|
reverse=reverse
|
||||||
)
|
)
|
||||||
elif sort_key == 'usage':
|
elif sort_key == 'usage':
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ from .service_registry import ServiceRegistry
|
|||||||
from .websocket_manager import ws_manager
|
from .websocket_manager import ws_manager
|
||||||
from .persistent_model_cache import get_persistent_cache
|
from .persistent_model_cache import get_persistent_cache
|
||||||
from .settings_manager import get_settings_manager
|
from .settings_manager import get_settings_manager
|
||||||
|
from .cache_entry_validator import CacheEntryValidator
|
||||||
|
from .cache_health_monitor import CacheHealthMonitor, CacheHealthStatus
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -468,6 +470,39 @@ class ModelScanner:
|
|||||||
for tag in adjusted_item.get('tags') or []:
|
for tag in adjusted_item.get('tags') or []:
|
||||||
tags_count[tag] = tags_count.get(tag, 0) + 1
|
tags_count[tag] = tags_count.get(tag, 0) + 1
|
||||||
|
|
||||||
|
# Validate cache entries and check health
|
||||||
|
valid_entries, invalid_entries = CacheEntryValidator.validate_batch(
|
||||||
|
adjusted_raw_data, auto_repair=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if invalid_entries:
|
||||||
|
monitor = CacheHealthMonitor()
|
||||||
|
report = monitor.check_health(adjusted_raw_data, auto_repair=True)
|
||||||
|
|
||||||
|
if report.status != CacheHealthStatus.HEALTHY:
|
||||||
|
# Broadcast health warning to frontend
|
||||||
|
await ws_manager.broadcast_cache_health_warning(report, page_type)
|
||||||
|
logger.warning(
|
||||||
|
f"{self.model_type.capitalize()} Scanner: Cache health issue detected - "
|
||||||
|
f"{report.invalid_entries} invalid entries, {report.repaired_entries} repaired"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use only valid entries
|
||||||
|
adjusted_raw_data = valid_entries
|
||||||
|
|
||||||
|
# Rebuild tags count from valid entries only
|
||||||
|
tags_count = {}
|
||||||
|
for item in adjusted_raw_data:
|
||||||
|
for tag in item.get('tags') or []:
|
||||||
|
tags_count[tag] = tags_count.get(tag, 0) + 1
|
||||||
|
|
||||||
|
# Remove invalid entries from hash index
|
||||||
|
for invalid_entry in invalid_entries:
|
||||||
|
file_path = CacheEntryValidator.get_file_path_safe(invalid_entry)
|
||||||
|
sha256 = CacheEntryValidator.get_sha256_safe(invalid_entry)
|
||||||
|
if file_path:
|
||||||
|
hash_index.remove_by_path(file_path, sha256)
|
||||||
|
|
||||||
scan_result = CacheBuildResult(
|
scan_result = CacheBuildResult(
|
||||||
raw_data=adjusted_raw_data,
|
raw_data=adjusted_raw_data,
|
||||||
hash_index=hash_index,
|
hash_index=hash_index,
|
||||||
@@ -651,7 +686,6 @@ class ModelScanner:
|
|||||||
|
|
||||||
async def _initialize_cache(self) -> None:
|
async def _initialize_cache(self) -> None:
|
||||||
"""Initialize or refresh the cache"""
|
"""Initialize or refresh the cache"""
|
||||||
print("init start", flush=True)
|
|
||||||
self._is_initializing = True # Set flag
|
self._is_initializing = True # Set flag
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -665,7 +699,6 @@ class ModelScanner:
|
|||||||
scan_result = await self._gather_model_data()
|
scan_result = await self._gather_model_data()
|
||||||
await self._apply_scan_result(scan_result)
|
await self._apply_scan_result(scan_result)
|
||||||
await self._save_persistent_cache(scan_result)
|
await self._save_persistent_cache(scan_result)
|
||||||
print("init end", flush=True)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
|
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
|
||||||
@@ -776,6 +809,18 @@ class ModelScanner:
|
|||||||
model_data = self.adjust_cached_entry(dict(model_data))
|
model_data = self.adjust_cached_entry(dict(model_data))
|
||||||
if not model_data:
|
if not model_data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Validate the new entry before adding
|
||||||
|
validation_result = CacheEntryValidator.validate(
|
||||||
|
model_data, auto_repair=True
|
||||||
|
)
|
||||||
|
if not validation_result.is_valid:
|
||||||
|
logger.warning(
|
||||||
|
f"Skipping invalid entry during reconcile: {path}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
model_data = validation_result.entry
|
||||||
|
|
||||||
self._ensure_license_flags(model_data)
|
self._ensure_license_flags(model_data)
|
||||||
# Add to cache
|
# Add to cache
|
||||||
self._cache.raw_data.append(model_data)
|
self._cache.raw_data.append(model_data)
|
||||||
@@ -1090,6 +1135,17 @@ class ModelScanner:
|
|||||||
processed_files += 1
|
processed_files += 1
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
|
# Validate the entry before adding
|
||||||
|
validation_result = CacheEntryValidator.validate(
|
||||||
|
result, auto_repair=True
|
||||||
|
)
|
||||||
|
if not validation_result.is_valid:
|
||||||
|
logger.warning(
|
||||||
|
f"Skipping invalid scan result: {file_path}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
result = validation_result.entry
|
||||||
|
|
||||||
self._ensure_license_flags(result)
|
self._ensure_license_flags(result)
|
||||||
raw_data.append(result)
|
raw_data.append(result)
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple
|
|||||||
from ..config import config
|
from ..config import config
|
||||||
from .recipe_cache import RecipeCache
|
from .recipe_cache import RecipeCache
|
||||||
from .recipe_fts_index import RecipeFTSIndex
|
from .recipe_fts_index import RecipeFTSIndex
|
||||||
from .persistent_recipe_cache import PersistentRecipeCache, get_persistent_recipe_cache
|
from .persistent_recipe_cache import PersistentRecipeCache, get_persistent_recipe_cache, PersistedRecipeData
|
||||||
from .service_registry import ServiceRegistry
|
from .service_registry import ServiceRegistry
|
||||||
from .lora_scanner import LoraScanner
|
from .lora_scanner import LoraScanner
|
||||||
from .metadata_service import get_default_metadata_provider
|
from .metadata_service import get_default_metadata_provider
|
||||||
@@ -431,6 +431,16 @@ class RecipeScanner:
|
|||||||
4. Persist results for next startup
|
4. Persist results for next startup
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Ensure cache exists to avoid None reference errors
|
||||||
|
if self._cache is None:
|
||||||
|
self._cache = RecipeCache(
|
||||||
|
raw_data=[],
|
||||||
|
sorted_by_name=[],
|
||||||
|
sorted_by_date=[],
|
||||||
|
folders=[],
|
||||||
|
folder_tree={},
|
||||||
|
)
|
||||||
|
|
||||||
# Create a new event loop for this thread
|
# Create a new event loop for this thread
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
@@ -492,7 +502,7 @@ class RecipeScanner:
|
|||||||
|
|
||||||
def _reconcile_recipe_cache(
|
def _reconcile_recipe_cache(
|
||||||
self,
|
self,
|
||||||
persisted: "PersistedRecipeData",
|
persisted: PersistedRecipeData,
|
||||||
recipes_dir: str,
|
recipes_dir: str,
|
||||||
) -> Tuple[List[Dict], bool, Dict[str, str]]:
|
) -> Tuple[List[Dict], bool, Dict[str, str]]:
|
||||||
"""Reconcile persisted cache with current filesystem state.
|
"""Reconcile persisted cache with current filesystem state.
|
||||||
@@ -504,8 +514,6 @@ class RecipeScanner:
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (recipes list, changed flag, json_paths dict).
|
Tuple of (recipes list, changed flag, json_paths dict).
|
||||||
"""
|
"""
|
||||||
from .persistent_recipe_cache import PersistedRecipeData
|
|
||||||
|
|
||||||
recipes: List[Dict] = []
|
recipes: List[Dict] = []
|
||||||
json_paths: Dict[str, str] = {}
|
json_paths: Dict[str, str] = {}
|
||||||
changed = False
|
changed = False
|
||||||
@@ -522,32 +530,37 @@ class RecipeScanner:
|
|||||||
except OSError:
|
except OSError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Build lookup of persisted recipes by json_path
|
# Build recipe_id -> recipe lookup (O(n) instead of O(n²))
|
||||||
persisted_by_path: Dict[str, Dict] = {}
|
recipe_by_id: Dict[str, Dict] = {
|
||||||
for recipe in persisted.raw_data:
|
|
||||||
recipe_id = str(recipe.get('id', ''))
|
|
||||||
if recipe_id:
|
|
||||||
# Find the json_path from file_stats
|
|
||||||
for json_path, (mtime, size) in persisted.file_stats.items():
|
|
||||||
if os.path.basename(json_path).startswith(recipe_id):
|
|
||||||
persisted_by_path[json_path] = recipe
|
|
||||||
break
|
|
||||||
|
|
||||||
# Also index by recipe ID for faster lookups
|
|
||||||
persisted_by_id: Dict[str, Dict] = {
|
|
||||||
str(r.get('id', '')): r for r in persisted.raw_data if r.get('id')
|
str(r.get('id', '')): r for r in persisted.raw_data if r.get('id')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Build json_path -> recipe lookup from file_stats (O(m))
|
||||||
|
persisted_by_path: Dict[str, Dict] = {}
|
||||||
|
for json_path in persisted.file_stats.keys():
|
||||||
|
basename = os.path.basename(json_path)
|
||||||
|
if basename.lower().endswith('.recipe.json'):
|
||||||
|
recipe_id = basename[:-len('.recipe.json')]
|
||||||
|
if recipe_id in recipe_by_id:
|
||||||
|
persisted_by_path[json_path] = recipe_by_id[recipe_id]
|
||||||
|
|
||||||
# Process current files
|
# Process current files
|
||||||
for file_path, (current_mtime, current_size) in current_files.items():
|
for file_path, (current_mtime, current_size) in current_files.items():
|
||||||
cached_stats = persisted.file_stats.get(file_path)
|
cached_stats = persisted.file_stats.get(file_path)
|
||||||
|
|
||||||
|
# Extract recipe_id from current file for fallback lookup
|
||||||
|
basename = os.path.basename(file_path)
|
||||||
|
recipe_id_from_file = basename[:-len('.recipe.json')] if basename.lower().endswith('.recipe.json') else None
|
||||||
|
|
||||||
if cached_stats:
|
if cached_stats:
|
||||||
cached_mtime, cached_size = cached_stats
|
cached_mtime, cached_size = cached_stats
|
||||||
# Check if file is unchanged
|
# Check if file is unchanged
|
||||||
if abs(current_mtime - cached_mtime) < 1.0 and current_size == cached_size:
|
if abs(current_mtime - cached_mtime) < 1.0 and current_size == cached_size:
|
||||||
# Use cached data
|
# Try direct path lookup first
|
||||||
cached_recipe = persisted_by_path.get(file_path)
|
cached_recipe = persisted_by_path.get(file_path)
|
||||||
|
# Fallback to recipe_id lookup if path lookup fails
|
||||||
|
if not cached_recipe and recipe_id_from_file:
|
||||||
|
cached_recipe = recipe_by_id.get(recipe_id_from_file)
|
||||||
if cached_recipe:
|
if cached_recipe:
|
||||||
recipe_id = str(cached_recipe.get('id', ''))
|
recipe_id = str(cached_recipe.get('id', ''))
|
||||||
# Track folder from file path
|
# Track folder from file path
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ DEFAULT_SETTINGS: Dict[str, Any] = {
|
|||||||
"compact_mode": False,
|
"compact_mode": False,
|
||||||
"priority_tags": DEFAULT_PRIORITY_TAG_CONFIG.copy(),
|
"priority_tags": DEFAULT_PRIORITY_TAG_CONFIG.copy(),
|
||||||
"model_name_display": "model_name",
|
"model_name_display": "model_name",
|
||||||
"model_card_footer_action": "example_images",
|
"model_card_footer_action": "replace_preview",
|
||||||
"update_flag_strategy": "same_base",
|
"update_flag_strategy": "same_base",
|
||||||
"auto_organize_exclusions": [],
|
"auto_organize_exclusions": [],
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,9 +48,14 @@ class BulkMetadataRefreshUseCase:
|
|||||||
for model in cache.raw_data
|
for model in cache.raw_data
|
||||||
if model.get("sha256")
|
if model.get("sha256")
|
||||||
and (not model.get("civitai") or not model["civitai"].get("id"))
|
and (not model.get("civitai") or not model["civitai"].get("id"))
|
||||||
and (
|
and not (
|
||||||
(enable_metadata_archive_db and not model.get("db_checked", False))
|
# Skip models confirmed not on CivitAI when no need to retry
|
||||||
or (not enable_metadata_archive_db and model.get("from_civitai") is True)
|
model.get("from_civitai") is False
|
||||||
|
and model.get("civitai_deleted") is True
|
||||||
|
and (
|
||||||
|
not enable_metadata_archive_db
|
||||||
|
or model.get("db_checked", False)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -255,6 +255,42 @@ class WebSocketManager:
|
|||||||
self._download_progress.pop(download_id, None)
|
self._download_progress.pop(download_id, None)
|
||||||
logger.debug(f"Cleaned up old download progress for {download_id}")
|
logger.debug(f"Cleaned up old download progress for {download_id}")
|
||||||
|
|
||||||
|
async def broadcast_cache_health_warning(self, report: 'HealthReport', page_type: str = None):
|
||||||
|
"""
|
||||||
|
Broadcast cache health warning to frontend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
report: HealthReport instance from CacheHealthMonitor
|
||||||
|
page_type: The page type (loras, checkpoints, embeddings)
|
||||||
|
"""
|
||||||
|
from .cache_health_monitor import CacheHealthStatus
|
||||||
|
|
||||||
|
# Only broadcast if there are issues
|
||||||
|
if report.status == CacheHealthStatus.HEALTHY:
|
||||||
|
return
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
'type': 'cache_health_warning',
|
||||||
|
'status': report.status.value,
|
||||||
|
'message': report.message,
|
||||||
|
'pageType': page_type,
|
||||||
|
'details': {
|
||||||
|
'total': report.total_entries,
|
||||||
|
'valid': report.valid_entries,
|
||||||
|
'invalid': report.invalid_entries,
|
||||||
|
'repaired': report.repaired_entries,
|
||||||
|
'corruption_rate': f"{report.corruption_rate:.1%}",
|
||||||
|
'invalid_paths': report.invalid_paths[:5], # Limit to first 5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Broadcasting cache health warning: {report.status.value} "
|
||||||
|
f"({report.invalid_entries} invalid entries)"
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.broadcast(payload)
|
||||||
|
|
||||||
def get_connected_clients_count(self) -> int:
|
def get_connected_clients_count(self) -> int:
|
||||||
"""Get number of connected clients"""
|
"""Get number of connected clients"""
|
||||||
return len(self._websockets)
|
return len(self._websockets)
|
||||||
|
|||||||
@@ -216,6 +216,11 @@ class DownloadManager:
|
|||||||
self._progress["failed_models"] = set()
|
self._progress["failed_models"] = set()
|
||||||
|
|
||||||
self._is_downloading = True
|
self._is_downloading = True
|
||||||
|
snapshot = self._progress.snapshot()
|
||||||
|
|
||||||
|
# Create the download task without awaiting it
|
||||||
|
# This ensures the HTTP response is returned immediately
|
||||||
|
# while the actual processing happens in the background
|
||||||
self._download_task = asyncio.create_task(
|
self._download_task = asyncio.create_task(
|
||||||
self._download_all_example_images(
|
self._download_all_example_images(
|
||||||
output_dir,
|
output_dir,
|
||||||
@@ -227,7 +232,10 @@ class DownloadManager:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
snapshot = self._progress.snapshot()
|
# Add a callback to handle task completion/errors
|
||||||
|
self._download_task.add_done_callback(
|
||||||
|
lambda t: self._handle_download_task_done(t, output_dir)
|
||||||
|
)
|
||||||
except ExampleImagesDownloadError:
|
except ExampleImagesDownloadError:
|
||||||
# Re-raise our own exception types without wrapping
|
# Re-raise our own exception types without wrapping
|
||||||
self._is_downloading = False
|
self._is_downloading = False
|
||||||
@@ -241,10 +249,25 @@ class DownloadManager:
|
|||||||
)
|
)
|
||||||
raise ExampleImagesDownloadError(str(e)) from e
|
raise ExampleImagesDownloadError(str(e)) from e
|
||||||
|
|
||||||
await self._broadcast_progress(status="running")
|
# Broadcast progress in the background without blocking the response
|
||||||
|
# This ensures the HTTP response is returned immediately
|
||||||
|
asyncio.create_task(self._broadcast_progress(status="running"))
|
||||||
|
|
||||||
return {"success": True, "message": "Download started", "status": snapshot}
|
return {"success": True, "message": "Download started", "status": snapshot}
|
||||||
|
|
||||||
|
def _handle_download_task_done(self, task: asyncio.Task, output_dir: str) -> None:
|
||||||
|
"""Handle download task completion, including saving progress on error."""
|
||||||
|
try:
|
||||||
|
# This will re-raise any exception from the task
|
||||||
|
task.result()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Download task failed with error: {e}", exc_info=True)
|
||||||
|
# Ensure progress is saved even on failure
|
||||||
|
try:
|
||||||
|
self._save_progress(output_dir)
|
||||||
|
except Exception as save_error:
|
||||||
|
logger.error(f"Failed to save progress after task failure: {save_error}")
|
||||||
|
|
||||||
async def get_status(self, request):
|
async def get_status(self, request):
|
||||||
"""Get the current status of example images download."""
|
"""Get the current status of example images download."""
|
||||||
|
|
||||||
@@ -254,6 +277,130 @@ class DownloadManager:
|
|||||||
"status": self._progress.snapshot(),
|
"status": self._progress.snapshot(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def check_pending_models(self, model_types: list[str]) -> dict:
|
||||||
|
"""Quickly check how many models need example images downloaded.
|
||||||
|
|
||||||
|
This is a lightweight check that avoids the overhead of starting
|
||||||
|
a full download task when no work is needed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with keys:
|
||||||
|
- total_models: Total number of models across specified types
|
||||||
|
- pending_count: Number of models needing example images
|
||||||
|
- processed_count: Number of already processed models
|
||||||
|
- failed_count: Number of models marked as failed
|
||||||
|
- needs_download: True if there are pending models to process
|
||||||
|
"""
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
|
||||||
|
if self._is_downloading:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"is_downloading": True,
|
||||||
|
"total_models": 0,
|
||||||
|
"pending_count": 0,
|
||||||
|
"processed_count": 0,
|
||||||
|
"failed_count": 0,
|
||||||
|
"needs_download": False,
|
||||||
|
"message": "Download already in progress",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get scanners
|
||||||
|
scanners = []
|
||||||
|
if "lora" in model_types:
|
||||||
|
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
|
scanners.append(("lora", lora_scanner))
|
||||||
|
|
||||||
|
if "checkpoint" in model_types:
|
||||||
|
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
|
scanners.append(("checkpoint", checkpoint_scanner))
|
||||||
|
|
||||||
|
if "embedding" in model_types:
|
||||||
|
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||||
|
scanners.append(("embedding", embedding_scanner))
|
||||||
|
|
||||||
|
# Load progress file to check processed models
|
||||||
|
settings_manager = get_settings_manager()
|
||||||
|
active_library = settings_manager.get_active_library_name()
|
||||||
|
output_dir = self._resolve_output_dir(active_library)
|
||||||
|
|
||||||
|
processed_models: set[str] = set()
|
||||||
|
failed_models: set[str] = set()
|
||||||
|
|
||||||
|
if output_dir:
|
||||||
|
progress_file = os.path.join(output_dir, ".download_progress.json")
|
||||||
|
if os.path.exists(progress_file):
|
||||||
|
try:
|
||||||
|
with open(progress_file, "r", encoding="utf-8") as f:
|
||||||
|
saved_progress = json.load(f)
|
||||||
|
processed_models = set(saved_progress.get("processed_models", []))
|
||||||
|
failed_models = set(saved_progress.get("failed_models", []))
|
||||||
|
except Exception:
|
||||||
|
pass # Ignore progress file errors for quick check
|
||||||
|
|
||||||
|
# Count models
|
||||||
|
total_models = 0
|
||||||
|
models_with_hash = 0
|
||||||
|
|
||||||
|
for scanner_type, scanner in scanners:
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
if cache and cache.raw_data:
|
||||||
|
for model in cache.raw_data:
|
||||||
|
total_models += 1
|
||||||
|
if model.get("sha256"):
|
||||||
|
models_with_hash += 1
|
||||||
|
|
||||||
|
# Calculate pending count
|
||||||
|
# A model is pending if it has a hash and is not in processed_models
|
||||||
|
# We also exclude failed_models unless force mode would be used
|
||||||
|
pending_count = models_with_hash - len(processed_models.intersection(
|
||||||
|
{m.get("sha256", "").lower() for scanner_type, scanner in scanners
|
||||||
|
for m in (await scanner.get_cached_data()).raw_data if m.get("sha256")}
|
||||||
|
))
|
||||||
|
|
||||||
|
# More accurate pending count: check which models actually need processing
|
||||||
|
pending_hashes = set()
|
||||||
|
for scanner_type, scanner in scanners:
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
if cache and cache.raw_data:
|
||||||
|
for model in cache.raw_data:
|
||||||
|
raw_hash = model.get("sha256")
|
||||||
|
if not raw_hash:
|
||||||
|
continue
|
||||||
|
model_hash = raw_hash.lower()
|
||||||
|
if model_hash not in processed_models:
|
||||||
|
# Check if model folder exists with files
|
||||||
|
model_dir = ExampleImagePathResolver.get_model_folder(
|
||||||
|
model_hash, active_library
|
||||||
|
)
|
||||||
|
if not _model_directory_has_files(model_dir):
|
||||||
|
pending_hashes.add(model_hash)
|
||||||
|
|
||||||
|
pending_count = len(pending_hashes)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"is_downloading": False,
|
||||||
|
"total_models": total_models,
|
||||||
|
"pending_count": pending_count,
|
||||||
|
"processed_count": len(processed_models),
|
||||||
|
"failed_count": len(failed_models),
|
||||||
|
"needs_download": pending_count > 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking pending models: {e}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"total_models": 0,
|
||||||
|
"pending_count": 0,
|
||||||
|
"processed_count": 0,
|
||||||
|
"failed_count": 0,
|
||||||
|
"needs_download": False,
|
||||||
|
}
|
||||||
|
|
||||||
async def pause_download(self, request):
|
async def pause_download(self, request):
|
||||||
"""Pause the example images download."""
|
"""Pause the example images download."""
|
||||||
|
|
||||||
|
|||||||
@@ -43,8 +43,15 @@ class ExampleImagesProcessor:
|
|||||||
return media_url
|
return media_url
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_file_extension_from_content_or_headers(content, headers, fallback_url=None):
|
def _get_file_extension_from_content_or_headers(content, headers, fallback_url=None, media_type_hint=None):
|
||||||
"""Determine file extension from content magic bytes or headers"""
|
"""Determine file extension from content magic bytes or headers
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: File content bytes
|
||||||
|
headers: HTTP response headers
|
||||||
|
fallback_url: Original URL for extension extraction
|
||||||
|
media_type_hint: Optional media type hint from metadata (e.g., "video" or "image")
|
||||||
|
"""
|
||||||
# Check magic bytes for common formats
|
# Check magic bytes for common formats
|
||||||
if content:
|
if content:
|
||||||
if content.startswith(b'\xFF\xD8\xFF'):
|
if content.startswith(b'\xFF\xD8\xFF'):
|
||||||
@@ -82,6 +89,10 @@ class ExampleImagesProcessor:
|
|||||||
if ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or ext in SUPPORTED_MEDIA_EXTENSIONS['videos']:
|
if ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or ext in SUPPORTED_MEDIA_EXTENSIONS['videos']:
|
||||||
return ext
|
return ext
|
||||||
|
|
||||||
|
# Use media type hint from metadata if available
|
||||||
|
if media_type_hint == "video":
|
||||||
|
return '.mp4'
|
||||||
|
|
||||||
# Default fallback
|
# Default fallback
|
||||||
return '.jpg'
|
return '.jpg'
|
||||||
|
|
||||||
@@ -136,7 +147,7 @@ class ExampleImagesProcessor:
|
|||||||
if success:
|
if success:
|
||||||
# Determine file extension from content or headers
|
# Determine file extension from content or headers
|
||||||
media_ext = ExampleImagesProcessor._get_file_extension_from_content_or_headers(
|
media_ext = ExampleImagesProcessor._get_file_extension_from_content_or_headers(
|
||||||
content, headers, original_url
|
content, headers, original_url, image.get("type")
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if the detected file type is supported
|
# Check if the detected file type is supported
|
||||||
@@ -219,7 +230,7 @@ class ExampleImagesProcessor:
|
|||||||
if success:
|
if success:
|
||||||
# Determine file extension from content or headers
|
# Determine file extension from content or headers
|
||||||
media_ext = ExampleImagesProcessor._get_file_extension_from_content_or_headers(
|
media_ext = ExampleImagesProcessor._get_file_extension_from_content_or_headers(
|
||||||
content, headers, original_url
|
content, headers, original_url, image.get("type")
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if the detected file type is supported
|
# Check if the detected file type is supported
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ async def extract_lora_metadata(file_path: str) -> Dict:
|
|||||||
base_model = determine_base_model(metadata.get("ss_base_model_version"))
|
base_model = determine_base_model(metadata.get("ss_base_model_version"))
|
||||||
return {"base_model": base_model}
|
return {"base_model": base_model}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error reading metadata from {file_path}: {str(e)}")
|
logger.error(f"Error reading metadata from {file_path}: {str(e)}")
|
||||||
return {"base_model": "Unknown"}
|
return {"base_model": "Unknown"}
|
||||||
|
|
||||||
async def extract_checkpoint_metadata(file_path: str) -> dict:
|
async def extract_checkpoint_metadata(file_path: str) -> dict:
|
||||||
|
|||||||
@@ -223,7 +223,7 @@ class MetadataManager:
|
|||||||
preview_url=normalize_path(preview_url),
|
preview_url=normalize_path(preview_url),
|
||||||
tags=[],
|
tags=[],
|
||||||
modelDescription="",
|
modelDescription="",
|
||||||
model_type="checkpoint",
|
sub_type="checkpoint",
|
||||||
from_civitai=True
|
from_civitai=True
|
||||||
)
|
)
|
||||||
elif model_class.__name__ == "EmbeddingMetadata":
|
elif model_class.__name__ == "EmbeddingMetadata":
|
||||||
@@ -238,6 +238,7 @@ class MetadataManager:
|
|||||||
preview_url=normalize_path(preview_url),
|
preview_url=normalize_path(preview_url),
|
||||||
tags=[],
|
tags=[],
|
||||||
modelDescription="",
|
modelDescription="",
|
||||||
|
sub_type="embedding",
|
||||||
from_civitai=True
|
from_civitai=True
|
||||||
)
|
)
|
||||||
else: # Default to LoraMetadata
|
else: # Default to LoraMetadata
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "comfyui-lora-manager"
|
name = "comfyui-lora-manager"
|
||||||
description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
|
description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
|
||||||
version = "0.9.13"
|
version = "0.9.14"
|
||||||
license = {file = "LICENSE"}
|
license = {file = "LICENSE"}
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
|
|||||||
0
scripts/sync_translation_keys.py
Normal file → Executable file
0
scripts/sync_translation_keys.py
Normal file → Executable file
@@ -113,6 +113,12 @@
|
|||||||
max-width: 110px;
|
max-width: 110px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Compact mode: hide sub-type to save space */
|
||||||
|
.compact-density .model-sub-type,
|
||||||
|
.compact-density .model-separator {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
.compact-density .card-actions i {
|
.compact-density .card-actions i {
|
||||||
font-size: 0.95em;
|
font-size: 0.95em;
|
||||||
padding: 3px;
|
padding: 3px;
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ class RecipeCard {
|
|||||||
card.dataset.nsfwLevel = this.recipe.preview_nsfw_level || 0;
|
card.dataset.nsfwLevel = this.recipe.preview_nsfw_level || 0;
|
||||||
card.dataset.created = this.recipe.created_date;
|
card.dataset.created = this.recipe.created_date;
|
||||||
card.dataset.id = this.recipe.id || '';
|
card.dataset.id = this.recipe.id || '';
|
||||||
|
card.dataset.folder = this.recipe.folder || '';
|
||||||
|
|
||||||
// Get base model with fallback
|
// Get base model with fallback
|
||||||
const baseModelLabel = (this.recipe.base_model || '').trim() || 'Unknown';
|
const baseModelLabel = (this.recipe.base_model || '').trim() || 'Unknown';
|
||||||
|
|||||||
@@ -199,6 +199,12 @@ class InitializationManager {
|
|||||||
if (!data) return;
|
if (!data) return;
|
||||||
console.log('Received progress update:', data);
|
console.log('Received progress update:', data);
|
||||||
|
|
||||||
|
// Handle cache health warning messages
|
||||||
|
if (data.type === 'cache_health_warning') {
|
||||||
|
this.handleCacheHealthWarning(data);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Check if this update is for our page type
|
// Check if this update is for our page type
|
||||||
if (data.pageType && data.pageType !== this.pageType) {
|
if (data.pageType && data.pageType !== this.pageType) {
|
||||||
console.log(`Ignoring update for ${data.pageType}, we're on ${this.pageType}`);
|
console.log(`Ignoring update for ${data.pageType}, we're on ${this.pageType}`);
|
||||||
@@ -466,6 +472,29 @@ class InitializationManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle cache health warning messages from WebSocket
|
||||||
|
*/
|
||||||
|
handleCacheHealthWarning(data) {
|
||||||
|
console.log('Cache health warning received:', data);
|
||||||
|
|
||||||
|
// Import bannerService dynamically to avoid circular dependencies
|
||||||
|
import('../managers/BannerService.js').then(({ bannerService }) => {
|
||||||
|
// Initialize banner service if not already done
|
||||||
|
if (!bannerService.initialized) {
|
||||||
|
bannerService.initialize().then(() => {
|
||||||
|
bannerService.registerCacheHealthBanner(data);
|
||||||
|
}).catch(err => {
|
||||||
|
console.error('Failed to initialize banner service:', err);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
bannerService.registerCacheHealthBanner(data);
|
||||||
|
}
|
||||||
|
}).catch(err => {
|
||||||
|
console.error('Failed to load banner service:', err);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Clean up resources when the component is destroyed
|
* Clean up resources when the component is destroyed
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -4,9 +4,11 @@ import {
|
|||||||
removeStorageItem
|
removeStorageItem
|
||||||
} from '../utils/storageHelpers.js';
|
} from '../utils/storageHelpers.js';
|
||||||
import { translate } from '../utils/i18nHelpers.js';
|
import { translate } from '../utils/i18nHelpers.js';
|
||||||
import { state } from '../state/index.js'
|
import { state } from '../state/index.js';
|
||||||
|
import { getModelApiClient } from '../api/modelApiFactory.js';
|
||||||
|
|
||||||
const COMMUNITY_SUPPORT_BANNER_ID = 'community-support';
|
const COMMUNITY_SUPPORT_BANNER_ID = 'community-support';
|
||||||
|
const CACHE_HEALTH_BANNER_ID = 'cache-health-warning';
|
||||||
const COMMUNITY_SUPPORT_BANNER_DELAY_MS = 5 * 24 * 60 * 60 * 1000; // 5 days
|
const COMMUNITY_SUPPORT_BANNER_DELAY_MS = 5 * 24 * 60 * 60 * 1000; // 5 days
|
||||||
const COMMUNITY_SUPPORT_FIRST_SEEN_AT_KEY = 'community_support_banner_first_seen_at';
|
const COMMUNITY_SUPPORT_FIRST_SEEN_AT_KEY = 'community_support_banner_first_seen_at';
|
||||||
const COMMUNITY_SUPPORT_VERSION_KEY = 'community_support_banner_state_version';
|
const COMMUNITY_SUPPORT_VERSION_KEY = 'community_support_banner_state_version';
|
||||||
@@ -293,6 +295,177 @@ class BannerService {
|
|||||||
location.reload();
|
location.reload();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register a cache health warning banner
|
||||||
|
* @param {Object} healthData - Health data from WebSocket
|
||||||
|
*/
|
||||||
|
registerCacheHealthBanner(healthData) {
|
||||||
|
if (!healthData || healthData.status === 'healthy') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove existing cache health banner if any
|
||||||
|
this.removeBannerElement(CACHE_HEALTH_BANNER_ID);
|
||||||
|
|
||||||
|
const isCorrupted = healthData.status === 'corrupted';
|
||||||
|
const titleKey = isCorrupted
|
||||||
|
? 'banners.cacheHealth.corrupted.title'
|
||||||
|
: 'banners.cacheHealth.degraded.title';
|
||||||
|
const defaultTitle = isCorrupted
|
||||||
|
? 'Cache Corruption Detected'
|
||||||
|
: 'Cache Issues Detected';
|
||||||
|
|
||||||
|
const title = translate(titleKey, {}, defaultTitle);
|
||||||
|
|
||||||
|
const contentKey = 'banners.cacheHealth.content';
|
||||||
|
const defaultContent = 'Found {invalid} of {total} cache entries are invalid ({rate}). This may cause missing models or errors. Rebuilding the cache is recommended.';
|
||||||
|
const content = translate(contentKey, {
|
||||||
|
invalid: healthData.details?.invalid || 0,
|
||||||
|
total: healthData.details?.total || 0,
|
||||||
|
rate: healthData.details?.corruption_rate || '0%'
|
||||||
|
}, defaultContent);
|
||||||
|
|
||||||
|
this.registerBanner(CACHE_HEALTH_BANNER_ID, {
|
||||||
|
id: CACHE_HEALTH_BANNER_ID,
|
||||||
|
title: title,
|
||||||
|
content: content,
|
||||||
|
pageType: healthData.pageType,
|
||||||
|
actions: [
|
||||||
|
{
|
||||||
|
text: translate('banners.cacheHealth.rebuildCache', {}, 'Rebuild Cache'),
|
||||||
|
icon: 'fas fa-sync-alt',
|
||||||
|
action: 'rebuild-cache',
|
||||||
|
type: 'primary'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
text: translate('banners.cacheHealth.dismiss', {}, 'Dismiss'),
|
||||||
|
icon: 'fas fa-times',
|
||||||
|
action: 'dismiss',
|
||||||
|
type: 'secondary'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
dismissible: true,
|
||||||
|
priority: 10, // High priority
|
||||||
|
onRegister: (bannerElement) => {
|
||||||
|
// Attach click handlers for actions
|
||||||
|
const rebuildBtn = bannerElement.querySelector('[data-action="rebuild-cache"]');
|
||||||
|
const dismissBtn = bannerElement.querySelector('[data-action="dismiss"]');
|
||||||
|
|
||||||
|
if (rebuildBtn) {
|
||||||
|
rebuildBtn.addEventListener('click', (e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
this.handleRebuildCache(bannerElement, healthData.pageType);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dismissBtn) {
|
||||||
|
dismissBtn.addEventListener('click', (e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
this.dismissBanner(CACHE_HEALTH_BANNER_ID);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle rebuild cache action from banner
|
||||||
|
* @param {HTMLElement} bannerElement - The banner element
|
||||||
|
* @param {string} pageType - The page type (loras, checkpoints, embeddings)
|
||||||
|
*/
|
||||||
|
async handleRebuildCache(bannerElement, pageType) {
|
||||||
|
const currentPageType = pageType || this.getCurrentPageType();
|
||||||
|
|
||||||
|
try {
|
||||||
|
const apiClient = getModelApiClient(currentPageType);
|
||||||
|
|
||||||
|
// Update banner to show rebuilding status
|
||||||
|
const actionsContainer = bannerElement.querySelector('.banner-actions');
|
||||||
|
if (actionsContainer) {
|
||||||
|
actionsContainer.innerHTML = `
|
||||||
|
<span class="banner-loading">
|
||||||
|
<i class="fas fa-spinner fa-spin"></i>
|
||||||
|
<span>${translate('banners.cacheHealth.rebuilding', {}, 'Rebuilding cache...')}</span>
|
||||||
|
</span>
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
|
await apiClient.refreshModels(true);
|
||||||
|
|
||||||
|
// Remove banner on success without marking as dismissed
|
||||||
|
this.removeBannerElement(CACHE_HEALTH_BANNER_ID);
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Cache rebuild failed:', error);
|
||||||
|
|
||||||
|
const actionsContainer = bannerElement.querySelector('.banner-actions');
|
||||||
|
if (actionsContainer) {
|
||||||
|
actionsContainer.innerHTML = `
|
||||||
|
<span class="banner-error">
|
||||||
|
<i class="fas fa-exclamation-triangle"></i>
|
||||||
|
<span>${translate('banners.cacheHealth.rebuildFailed', {}, 'Rebuild failed. Please try again.')}</span>
|
||||||
|
</span>
|
||||||
|
<a href="#" class="banner-action banner-action-primary" data-action="rebuild-cache">
|
||||||
|
<i class="fas fa-sync-alt"></i>
|
||||||
|
<span>${translate('banners.cacheHealth.retry', {}, 'Retry')}</span>
|
||||||
|
</a>
|
||||||
|
`;
|
||||||
|
|
||||||
|
// Re-attach click handler
|
||||||
|
const retryBtn = actionsContainer.querySelector('[data-action="rebuild-cache"]');
|
||||||
|
if (retryBtn) {
|
||||||
|
retryBtn.addEventListener('click', (e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
this.handleRebuildCache(bannerElement, pageType);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the current page type from the URL
|
||||||
|
* @returns {string} Page type (loras, checkpoints, embeddings, recipes)
|
||||||
|
*/
|
||||||
|
getCurrentPageType() {
|
||||||
|
const path = window.location.pathname;
|
||||||
|
if (path.includes('/checkpoints')) return 'checkpoints';
|
||||||
|
if (path.includes('/embeddings')) return 'embeddings';
|
||||||
|
if (path.includes('/recipes')) return 'recipes';
|
||||||
|
return 'loras';
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the rebuild cache endpoint for the given page type
|
||||||
|
* @param {string} pageType - The page type
|
||||||
|
* @returns {string} The API endpoint URL
|
||||||
|
*/
|
||||||
|
getRebuildEndpoint(pageType) {
|
||||||
|
const endpoints = {
|
||||||
|
'loras': '/api/lm/loras/reload?rebuild=true',
|
||||||
|
'checkpoints': '/api/lm/checkpoints/reload?rebuild=true',
|
||||||
|
'embeddings': '/api/lm/embeddings/reload?rebuild=true'
|
||||||
|
};
|
||||||
|
return endpoints[pageType] || endpoints['loras'];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Remove a banner element from DOM without marking as dismissed
|
||||||
|
* @param {string} bannerId - Banner ID to remove
|
||||||
|
*/
|
||||||
|
removeBannerElement(bannerId) {
|
||||||
|
const bannerElement = document.querySelector(`[data-banner-id="${bannerId}"]`);
|
||||||
|
if (bannerElement) {
|
||||||
|
bannerElement.style.animation = 'banner-slide-up 0.3s ease-in-out forwards';
|
||||||
|
setTimeout(() => {
|
||||||
|
bannerElement.remove();
|
||||||
|
this.updateContainerVisibility();
|
||||||
|
}, 300);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also remove from banners map
|
||||||
|
this.banners.delete(bannerId);
|
||||||
|
}
|
||||||
|
|
||||||
prepareCommunitySupportBanner() {
|
prepareCommunitySupportBanner() {
|
||||||
if (this.isBannerDismissed(COMMUNITY_SUPPORT_BANNER_ID)) {
|
if (this.isBannerDismissed(COMMUNITY_SUPPORT_BANNER_ID)) {
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ export class ExampleImagesManager {
|
|||||||
// Auto download properties
|
// Auto download properties
|
||||||
this.autoDownloadInterval = null;
|
this.autoDownloadInterval = null;
|
||||||
this.lastAutoDownloadCheck = 0;
|
this.lastAutoDownloadCheck = 0;
|
||||||
this.autoDownloadCheckInterval = 10 * 60 * 1000; // 10 minutes in milliseconds
|
this.autoDownloadCheckInterval = 30 * 60 * 1000; // 30 minutes in milliseconds
|
||||||
this.pageInitTime = Date.now(); // Track when page was initialized
|
this.pageInitTime = Date.now(); // Track when page was initialized
|
||||||
|
|
||||||
// Initialize download path field and check download status
|
// Initialize download path field and check download status
|
||||||
@@ -808,19 +808,58 @@ export class ExampleImagesManager {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
this.lastAutoDownloadCheck = now;
|
|
||||||
|
|
||||||
if (!this.canAutoDownload()) {
|
if (!this.canAutoDownload()) {
|
||||||
console.log('Auto download conditions not met, skipping check');
|
console.log('Auto download conditions not met, skipping check');
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
console.log('Performing auto download check...');
|
console.log('Performing auto download pre-check...');
|
||||||
|
|
||||||
|
// Step 1: Lightweight pre-check to see if any work is needed
|
||||||
|
const checkResponse = await fetch('/api/lm/check-example-images-needed', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model_types: ['lora', 'checkpoint', 'embedding']
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!checkResponse.ok) {
|
||||||
|
console.warn('Auto download pre-check HTTP error:', checkResponse.status);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const checkData = await checkResponse.json();
|
||||||
|
|
||||||
|
if (!checkData.success) {
|
||||||
|
console.warn('Auto download pre-check failed:', checkData.error);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the check timestamp only after successful pre-check
|
||||||
|
this.lastAutoDownloadCheck = now;
|
||||||
|
|
||||||
|
// If download already in progress, skip
|
||||||
|
if (checkData.is_downloading) {
|
||||||
|
console.log('Download already in progress, skipping auto check');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no models need downloading, skip
|
||||||
|
if (!checkData.needs_download || checkData.pending_count === 0) {
|
||||||
|
console.log(`Auto download pre-check complete: ${checkData.processed_count}/${checkData.total_models} models already processed, no work needed`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(`Auto download pre-check: ${checkData.pending_count} models need processing, starting download...`);
|
||||||
|
|
||||||
|
// Step 2: Start the actual download (fire-and-forget)
|
||||||
const optimize = state.global.settings.optimize_example_images;
|
const optimize = state.global.settings.optimize_example_images;
|
||||||
|
|
||||||
const response = await fetch('/api/lm/download-example-images', {
|
fetch('/api/lm/download-example-images', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
@@ -830,18 +869,29 @@ export class ExampleImagesManager {
|
|||||||
model_types: ['lora', 'checkpoint', 'embedding'],
|
model_types: ['lora', 'checkpoint', 'embedding'],
|
||||||
auto_mode: true // Flag to indicate this is an automatic download
|
auto_mode: true // Flag to indicate this is an automatic download
|
||||||
})
|
})
|
||||||
|
}).then(response => {
|
||||||
|
if (!response.ok) {
|
||||||
|
console.warn('Auto download start HTTP error:', response.status);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return response.json();
|
||||||
|
}).then(data => {
|
||||||
|
if (data && !data.success) {
|
||||||
|
console.warn('Auto download start failed:', data.error);
|
||||||
|
// If already in progress, push back the next check to avoid hammering the API
|
||||||
|
if (data.error && data.error.includes('already in progress')) {
|
||||||
|
console.log('Download already in progress, backing off next check');
|
||||||
|
this.lastAutoDownloadCheck = now + (5 * 60 * 1000); // Back off for 5 extra minutes
|
||||||
|
}
|
||||||
|
} else if (data && data.success) {
|
||||||
|
console.log('Auto download started:', data.message || 'Download started');
|
||||||
|
}
|
||||||
|
}).catch(error => {
|
||||||
|
console.error('Auto download start error:', error);
|
||||||
});
|
});
|
||||||
|
|
||||||
const data = await response.json();
|
// Immediately return without waiting for the download fetch to complete
|
||||||
|
// This keeps the UI responsive
|
||||||
if (!data.success) {
|
|
||||||
console.warn('Auto download check failed:', data.error);
|
|
||||||
// If already in progress, push back the next check to avoid hammering the API
|
|
||||||
if (data.error && data.error.includes('already in progress')) {
|
|
||||||
console.log('Download already in progress, backing off next check');
|
|
||||||
this.lastAutoDownloadCheck = now + (5 * 60 * 1000); // Back off for 5 extra minutes
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Auto download check error:', error);
|
console.error('Auto download check error:', error);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,6 +27,10 @@ export const BASE_MODELS = {
|
|||||||
FLUX_1_KREA: "Flux.1 Krea",
|
FLUX_1_KREA: "Flux.1 Krea",
|
||||||
FLUX_1_KONTEXT: "Flux.1 Kontext",
|
FLUX_1_KONTEXT: "Flux.1 Kontext",
|
||||||
FLUX_2_D: "Flux.2 D",
|
FLUX_2_D: "Flux.2 D",
|
||||||
|
FLUX_2_KLEIN_9B: "Flux.2 Klein 9B",
|
||||||
|
FLUX_2_KLEIN_9B_BASE: "Flux.2 Klein 9B-base",
|
||||||
|
FLUX_2_KLEIN_4B: "Flux.2 Klein 4B",
|
||||||
|
FLUX_2_KLEIN_4B_BASE: "Flux.2 Klein 4B-base",
|
||||||
AURAFLOW: "AuraFlow",
|
AURAFLOW: "AuraFlow",
|
||||||
CHROMA: "Chroma",
|
CHROMA: "Chroma",
|
||||||
PIXART_A: "PixArt a",
|
PIXART_A: "PixArt a",
|
||||||
@@ -40,10 +44,12 @@ export const BASE_MODELS = {
|
|||||||
HIDREAM: "HiDream",
|
HIDREAM: "HiDream",
|
||||||
QWEN: "Qwen",
|
QWEN: "Qwen",
|
||||||
ZIMAGE_TURBO: "ZImageTurbo",
|
ZIMAGE_TURBO: "ZImageTurbo",
|
||||||
|
ZIMAGE_BASE: "ZImageBase",
|
||||||
|
|
||||||
// Video models
|
// Video models
|
||||||
SVD: "SVD",
|
SVD: "SVD",
|
||||||
LTXV: "LTXV",
|
LTXV: "LTXV",
|
||||||
|
LTXV2: "LTXV2",
|
||||||
WAN_VIDEO: "Wan Video",
|
WAN_VIDEO: "Wan Video",
|
||||||
WAN_VIDEO_1_3B_T2V: "Wan Video 1.3B t2v",
|
WAN_VIDEO_1_3B_T2V: "Wan Video 1.3B t2v",
|
||||||
WAN_VIDEO_14B_T2V: "Wan Video 14B t2v",
|
WAN_VIDEO_14B_T2V: "Wan Video 14B t2v",
|
||||||
@@ -120,6 +126,10 @@ export const BASE_MODEL_ABBREVIATIONS = {
|
|||||||
[BASE_MODELS.FLUX_1_KREA]: 'F1KR',
|
[BASE_MODELS.FLUX_1_KREA]: 'F1KR',
|
||||||
[BASE_MODELS.FLUX_1_KONTEXT]: 'F1KX',
|
[BASE_MODELS.FLUX_1_KONTEXT]: 'F1KX',
|
||||||
[BASE_MODELS.FLUX_2_D]: 'F2D',
|
[BASE_MODELS.FLUX_2_D]: 'F2D',
|
||||||
|
[BASE_MODELS.FLUX_2_KLEIN_9B]: 'FK9',
|
||||||
|
[BASE_MODELS.FLUX_2_KLEIN_9B_BASE]: 'FK9B',
|
||||||
|
[BASE_MODELS.FLUX_2_KLEIN_4B]: 'FK4',
|
||||||
|
[BASE_MODELS.FLUX_2_KLEIN_4B_BASE]: 'FK4B',
|
||||||
|
|
||||||
// Other diffusion models
|
// Other diffusion models
|
||||||
[BASE_MODELS.AURAFLOW]: 'AF',
|
[BASE_MODELS.AURAFLOW]: 'AF',
|
||||||
@@ -135,10 +145,12 @@ export const BASE_MODEL_ABBREVIATIONS = {
|
|||||||
[BASE_MODELS.HIDREAM]: 'HID',
|
[BASE_MODELS.HIDREAM]: 'HID',
|
||||||
[BASE_MODELS.QWEN]: 'QWEN',
|
[BASE_MODELS.QWEN]: 'QWEN',
|
||||||
[BASE_MODELS.ZIMAGE_TURBO]: 'ZIT',
|
[BASE_MODELS.ZIMAGE_TURBO]: 'ZIT',
|
||||||
|
[BASE_MODELS.ZIMAGE_BASE]: 'ZIB',
|
||||||
|
|
||||||
// Video models
|
// Video models
|
||||||
[BASE_MODELS.SVD]: 'SVD',
|
[BASE_MODELS.SVD]: 'SVD',
|
||||||
[BASE_MODELS.LTXV]: 'LTXV',
|
[BASE_MODELS.LTXV]: 'LTXV',
|
||||||
|
[BASE_MODELS.LTXV2]: 'LTV2',
|
||||||
[BASE_MODELS.WAN_VIDEO]: 'WAN',
|
[BASE_MODELS.WAN_VIDEO]: 'WAN',
|
||||||
[BASE_MODELS.WAN_VIDEO_1_3B_T2V]: 'WAN',
|
[BASE_MODELS.WAN_VIDEO_1_3B_T2V]: 'WAN',
|
||||||
[BASE_MODELS.WAN_VIDEO_14B_T2V]: 'WAN',
|
[BASE_MODELS.WAN_VIDEO_14B_T2V]: 'WAN',
|
||||||
@@ -328,16 +340,16 @@ export const BASE_MODEL_CATEGORIES = {
|
|||||||
'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO],
|
'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO],
|
||||||
'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
|
'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
|
||||||
'Video Models': [
|
'Video Models': [
|
||||||
BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.HUNYUAN_VIDEO, BASE_MODELS.WAN_VIDEO,
|
BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.LTXV2, BASE_MODELS.HUNYUAN_VIDEO, BASE_MODELS.WAN_VIDEO,
|
||||||
BASE_MODELS.WAN_VIDEO_1_3B_T2V, BASE_MODELS.WAN_VIDEO_14B_T2V,
|
BASE_MODELS.WAN_VIDEO_1_3B_T2V, BASE_MODELS.WAN_VIDEO_14B_T2V,
|
||||||
BASE_MODELS.WAN_VIDEO_14B_I2V_480P, BASE_MODELS.WAN_VIDEO_14B_I2V_720P,
|
BASE_MODELS.WAN_VIDEO_14B_I2V_480P, BASE_MODELS.WAN_VIDEO_14B_I2V_720P,
|
||||||
BASE_MODELS.WAN_VIDEO_2_2_TI2V_5B, BASE_MODELS.WAN_VIDEO_2_2_T2V_A14B,
|
BASE_MODELS.WAN_VIDEO_2_2_TI2V_5B, BASE_MODELS.WAN_VIDEO_2_2_T2V_A14B,
|
||||||
BASE_MODELS.WAN_VIDEO_2_2_I2V_A14B
|
BASE_MODELS.WAN_VIDEO_2_2_I2V_A14B
|
||||||
],
|
],
|
||||||
'Flux Models': [BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.FLUX_1_KONTEXT, BASE_MODELS.FLUX_1_KREA, BASE_MODELS.FLUX_2_D],
|
'Flux Models': [BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.FLUX_1_KONTEXT, BASE_MODELS.FLUX_1_KREA, BASE_MODELS.FLUX_2_D, BASE_MODELS.FLUX_2_KLEIN_9B, BASE_MODELS.FLUX_2_KLEIN_9B_BASE, BASE_MODELS.FLUX_2_KLEIN_4B, BASE_MODELS.FLUX_2_KLEIN_4B_BASE],
|
||||||
'Other Models': [
|
'Other Models': [
|
||||||
BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.HIDREAM,
|
BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.HIDREAM,
|
||||||
BASE_MODELS.QWEN, BASE_MODELS.AURAFLOW, BASE_MODELS.CHROMA, BASE_MODELS.ZIMAGE_TURBO,
|
BASE_MODELS.QWEN, BASE_MODELS.AURAFLOW, BASE_MODELS.CHROMA, BASE_MODELS.ZIMAGE_TURBO, BASE_MODELS.ZIMAGE_BASE,
|
||||||
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
|
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
|
||||||
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI,
|
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI,
|
||||||
BASE_MODELS.UNKNOWN
|
BASE_MODELS.UNKNOWN
|
||||||
|
|||||||
@@ -230,8 +230,58 @@ def test_new_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
|||||||
assert normalized_external in second_cfg._path_mappings
|
assert normalized_external in second_cfg._path_mappings
|
||||||
|
|
||||||
|
|
||||||
def test_removed_deep_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
def test_removed_first_level_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||||
"""Removing a deep symlink should trigger cache invalidation."""
|
"""Removing a first-level symlink should trigger cache invalidation."""
|
||||||
|
loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path)
|
||||||
|
|
||||||
|
# Create first-level symlink (directly under loras root)
|
||||||
|
external_dir = tmp_path / "external"
|
||||||
|
external_dir.mkdir()
|
||||||
|
symlink = loras_dir / "external_models"
|
||||||
|
symlink.symlink_to(external_dir, target_is_directory=True)
|
||||||
|
|
||||||
|
# Initial scan finds the symlink
|
||||||
|
first_cfg = config_module.Config()
|
||||||
|
normalized_external = _normalize(str(external_dir))
|
||||||
|
assert normalized_external in first_cfg._path_mappings
|
||||||
|
|
||||||
|
# Remove the symlink
|
||||||
|
symlink.unlink()
|
||||||
|
|
||||||
|
# Second config should detect invalid cached mapping and rescan
|
||||||
|
second_cfg = config_module.Config()
|
||||||
|
assert normalized_external not in second_cfg._path_mappings
|
||||||
|
|
||||||
|
|
||||||
|
def test_retargeted_first_level_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||||
|
"""Changing a first-level symlink's target should trigger cache invalidation."""
|
||||||
|
loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path)
|
||||||
|
|
||||||
|
# Create first-level symlink
|
||||||
|
target_v1 = tmp_path / "external_v1"
|
||||||
|
target_v1.mkdir()
|
||||||
|
target_v2 = tmp_path / "external_v2"
|
||||||
|
target_v2.mkdir()
|
||||||
|
|
||||||
|
symlink = loras_dir / "external_models"
|
||||||
|
symlink.symlink_to(target_v1, target_is_directory=True)
|
||||||
|
|
||||||
|
# Initial scan
|
||||||
|
first_cfg = config_module.Config()
|
||||||
|
assert _normalize(str(target_v1)) in first_cfg._path_mappings
|
||||||
|
|
||||||
|
# Retarget the symlink
|
||||||
|
symlink.unlink()
|
||||||
|
symlink.symlink_to(target_v2, target_is_directory=True)
|
||||||
|
|
||||||
|
# Second config should detect changed target and rescan
|
||||||
|
second_cfg = config_module.Config()
|
||||||
|
assert _normalize(str(target_v2)) in second_cfg._path_mappings
|
||||||
|
assert _normalize(str(target_v1)) not in second_cfg._path_mappings
|
||||||
|
|
||||||
|
|
||||||
|
def test_deep_symlink_not_scanned(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||||
|
"""Deep symlinks (below first level) are not scanned to avoid performance issues."""
|
||||||
loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path)
|
loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path)
|
||||||
|
|
||||||
# Create nested structure with deep symlink
|
# Create nested structure with deep symlink
|
||||||
@@ -242,46 +292,12 @@ def test_removed_deep_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, t
|
|||||||
deep_symlink = subdir / "styles"
|
deep_symlink = subdir / "styles"
|
||||||
deep_symlink.symlink_to(external_dir, target_is_directory=True)
|
deep_symlink.symlink_to(external_dir, target_is_directory=True)
|
||||||
|
|
||||||
# Initial scan finds the deep symlink
|
# Config should not detect deep symlinks (only first-level)
|
||||||
first_cfg = config_module.Config()
|
cfg = config_module.Config()
|
||||||
normalized_external = _normalize(str(external_dir))
|
normalized_external = _normalize(str(external_dir))
|
||||||
assert normalized_external in first_cfg._path_mappings
|
assert normalized_external not in cfg._path_mappings
|
||||||
|
|
||||||
# Remove the deep symlink
|
|
||||||
deep_symlink.unlink()
|
|
||||||
|
|
||||||
# Second config should detect invalid cached mapping and rescan
|
|
||||||
second_cfg = config_module.Config()
|
|
||||||
assert normalized_external not in second_cfg._path_mappings
|
|
||||||
|
|
||||||
|
|
||||||
def test_retargeted_deep_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
|
||||||
"""Changing a deep symlink's target should trigger cache invalidation."""
|
|
||||||
loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path)
|
|
||||||
|
|
||||||
# Create nested structure
|
|
||||||
subdir = loras_dir / "anime"
|
|
||||||
subdir.mkdir()
|
|
||||||
target_v1 = tmp_path / "external_v1"
|
|
||||||
target_v1.mkdir()
|
|
||||||
target_v2 = tmp_path / "external_v2"
|
|
||||||
target_v2.mkdir()
|
|
||||||
|
|
||||||
deep_symlink = subdir / "styles"
|
|
||||||
deep_symlink.symlink_to(target_v1, target_is_directory=True)
|
|
||||||
|
|
||||||
# Initial scan
|
|
||||||
first_cfg = config_module.Config()
|
|
||||||
assert _normalize(str(target_v1)) in first_cfg._path_mappings
|
|
||||||
|
|
||||||
# Retarget the symlink
|
|
||||||
deep_symlink.unlink()
|
|
||||||
deep_symlink.symlink_to(target_v2, target_is_directory=True)
|
|
||||||
|
|
||||||
# Second config should detect changed target and rescan
|
|
||||||
second_cfg = config_module.Config()
|
|
||||||
assert _normalize(str(target_v2)) in second_cfg._path_mappings
|
|
||||||
assert _normalize(str(target_v1)) not in second_cfg._path_mappings
|
|
||||||
def test_legacy_symlink_cache_automatic_cleanup(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
def test_legacy_symlink_cache_automatic_cleanup(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||||
"""Test that legacy symlink cache is automatically cleaned up after migration."""
|
"""Test that legacy symlink cache is automatically cleaned up after migration."""
|
||||||
settings_dir = tmp_path / "settings"
|
settings_dir = tmp_path / "settings"
|
||||||
|
|||||||
@@ -47,6 +47,8 @@ class StubDownloadManager:
|
|||||||
self.resume_error: Exception | None = None
|
self.resume_error: Exception | None = None
|
||||||
self.stop_error: Exception | None = None
|
self.stop_error: Exception | None = None
|
||||||
self.force_error: Exception | None = None
|
self.force_error: Exception | None = None
|
||||||
|
self.check_pending_result: dict[str, Any] | None = None
|
||||||
|
self.check_pending_calls: list[list[str]] = []
|
||||||
|
|
||||||
async def get_status(self, request: web.Request) -> dict[str, Any]:
|
async def get_status(self, request: web.Request) -> dict[str, Any]:
|
||||||
return {"success": True, "status": "idle"}
|
return {"success": True, "status": "idle"}
|
||||||
@@ -75,6 +77,20 @@ class StubDownloadManager:
|
|||||||
raise self.force_error
|
raise self.force_error
|
||||||
return {"success": True, "payload": payload}
|
return {"success": True, "payload": payload}
|
||||||
|
|
||||||
|
async def check_pending_models(self, model_types: list[str]) -> dict[str, Any]:
|
||||||
|
self.check_pending_calls.append(model_types)
|
||||||
|
if self.check_pending_result is not None:
|
||||||
|
return self.check_pending_result
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"is_downloading": False,
|
||||||
|
"total_models": 100,
|
||||||
|
"pending_count": 10,
|
||||||
|
"processed_count": 90,
|
||||||
|
"failed_count": 0,
|
||||||
|
"needs_download": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class StubImportUseCase:
|
class StubImportUseCase:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
@@ -236,3 +252,123 @@ async def test_import_route_returns_validation_errors():
|
|||||||
assert response.status == 400
|
assert response.status == 400
|
||||||
body = await _json(response)
|
body = await _json(response)
|
||||||
assert body == {"success": False, "error": "bad payload"}
|
assert body == {"success": False, "error": "bad payload"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_example_images_needed_returns_pending_counts():
|
||||||
|
"""Test that check_example_images_needed endpoint returns pending model counts."""
|
||||||
|
async with registrar_app() as harness:
|
||||||
|
harness.download_manager.check_pending_result = {
|
||||||
|
"success": True,
|
||||||
|
"is_downloading": False,
|
||||||
|
"total_models": 5500,
|
||||||
|
"pending_count": 12,
|
||||||
|
"processed_count": 5488,
|
||||||
|
"failed_count": 45,
|
||||||
|
"needs_download": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await harness.client.post(
|
||||||
|
"/api/lm/check-example-images-needed",
|
||||||
|
json={"model_types": ["lora", "checkpoint"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
body = await _json(response)
|
||||||
|
assert body["success"] is True
|
||||||
|
assert body["total_models"] == 5500
|
||||||
|
assert body["pending_count"] == 12
|
||||||
|
assert body["processed_count"] == 5488
|
||||||
|
assert body["failed_count"] == 45
|
||||||
|
assert body["needs_download"] is True
|
||||||
|
assert body["is_downloading"] is False
|
||||||
|
|
||||||
|
# Verify the manager was called with correct model types
|
||||||
|
assert harness.download_manager.check_pending_calls == [["lora", "checkpoint"]]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_example_images_needed_handles_download_in_progress():
|
||||||
|
"""Test that check_example_images_needed returns correct status when download is running."""
|
||||||
|
async with registrar_app() as harness:
|
||||||
|
harness.download_manager.check_pending_result = {
|
||||||
|
"success": True,
|
||||||
|
"is_downloading": True,
|
||||||
|
"total_models": 0,
|
||||||
|
"pending_count": 0,
|
||||||
|
"processed_count": 0,
|
||||||
|
"failed_count": 0,
|
||||||
|
"needs_download": False,
|
||||||
|
"message": "Download already in progress",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await harness.client.post(
|
||||||
|
"/api/lm/check-example-images-needed",
|
||||||
|
json={"model_types": ["lora"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
body = await _json(response)
|
||||||
|
assert body["success"] is True
|
||||||
|
assert body["is_downloading"] is True
|
||||||
|
assert body["needs_download"] is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_example_images_needed_handles_no_pending_models():
|
||||||
|
"""Test that check_example_images_needed returns correct status when no work is needed."""
|
||||||
|
async with registrar_app() as harness:
|
||||||
|
harness.download_manager.check_pending_result = {
|
||||||
|
"success": True,
|
||||||
|
"is_downloading": False,
|
||||||
|
"total_models": 5500,
|
||||||
|
"pending_count": 0,
|
||||||
|
"processed_count": 5500,
|
||||||
|
"failed_count": 0,
|
||||||
|
"needs_download": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await harness.client.post(
|
||||||
|
"/api/lm/check-example-images-needed",
|
||||||
|
json={"model_types": ["lora", "checkpoint", "embedding"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
body = await _json(response)
|
||||||
|
assert body["success"] is True
|
||||||
|
assert body["pending_count"] == 0
|
||||||
|
assert body["needs_download"] is False
|
||||||
|
assert body["processed_count"] == 5500
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_example_images_needed_uses_default_model_types():
|
||||||
|
"""Test that check_example_images_needed uses default model types when not specified."""
|
||||||
|
async with registrar_app() as harness:
|
||||||
|
response = await harness.client.post(
|
||||||
|
"/api/lm/check-example-images-needed",
|
||||||
|
json={}, # No model_types specified
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
# Should use default model types
|
||||||
|
assert harness.download_manager.check_pending_calls == [["lora", "checkpoint", "embedding"]]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_example_images_needed_returns_error_on_exception():
|
||||||
|
"""Test that check_example_images_needed returns 500 on internal error."""
|
||||||
|
async with registrar_app() as harness:
|
||||||
|
# Simulate an error by setting result to an error state
|
||||||
|
# Actually, we need to make the method raise an exception
|
||||||
|
original_method = harness.download_manager.check_pending_models
|
||||||
|
|
||||||
|
async def failing_check(_model_types):
|
||||||
|
raise RuntimeError("Database connection failed")
|
||||||
|
|
||||||
|
harness.download_manager.check_pending_models = failing_check
|
||||||
|
|
||||||
|
response = await harness.client.post(
|
||||||
|
"/api/lm/check-example-images-needed",
|
||||||
|
json={"model_types": ["lora"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 500
|
||||||
|
body = await _json(response)
|
||||||
|
assert body["success"] is False
|
||||||
|
assert "Database connection failed" in body["error"]
|
||||||
|
|||||||
@@ -502,6 +502,7 @@ def test_handler_set_route_mapping_includes_all_handlers() -> None:
|
|||||||
"resume_example_images",
|
"resume_example_images",
|
||||||
"stop_example_images",
|
"stop_example_images",
|
||||||
"force_download_example_images",
|
"force_download_example_images",
|
||||||
|
"check_example_images_needed",
|
||||||
"import_example_images",
|
"import_example_images",
|
||||||
"delete_example_image",
|
"delete_example_image",
|
||||||
"set_example_image_nsfw_level",
|
"set_example_image_nsfw_level",
|
||||||
|
|||||||
283
tests/services/test_cache_entry_validator.py
Normal file
283
tests/services/test_cache_entry_validator.py
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for CacheEntryValidator
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services.cache_entry_validator import (
|
||||||
|
CacheEntryValidator,
|
||||||
|
ValidationResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheEntryValidator:
|
||||||
|
"""Tests for CacheEntryValidator class"""
|
||||||
|
|
||||||
|
def test_validate_valid_entry(self):
|
||||||
|
"""Test validation of a valid cache entry"""
|
||||||
|
entry = {
|
||||||
|
'file_path': '/models/test.safetensors',
|
||||||
|
'sha256': 'abc123def456',
|
||||||
|
'file_name': 'test.safetensors',
|
||||||
|
'model_name': 'Test Model',
|
||||||
|
'size': 1024,
|
||||||
|
'modified': 1234567890.0,
|
||||||
|
'tags': ['tag1', 'tag2'],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||||
|
|
||||||
|
assert result.is_valid is True
|
||||||
|
assert result.repaired is False
|
||||||
|
assert len(result.errors) == 0
|
||||||
|
assert result.entry == entry
|
||||||
|
|
||||||
|
def test_validate_missing_required_field_sha256(self):
|
||||||
|
"""Test validation fails when required sha256 field is missing"""
|
||||||
|
entry = {
|
||||||
|
'file_path': '/models/test.safetensors',
|
||||||
|
# sha256 missing
|
||||||
|
'file_name': 'test.safetensors',
|
||||||
|
}
|
||||||
|
|
||||||
|
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||||
|
|
||||||
|
assert result.is_valid is False
|
||||||
|
assert result.repaired is False
|
||||||
|
assert any('sha256' in error for error in result.errors)
|
||||||
|
|
||||||
|
def test_validate_missing_required_field_file_path(self):
|
||||||
|
"""Test validation fails when required file_path field is missing"""
|
||||||
|
entry = {
|
||||||
|
# file_path missing
|
||||||
|
'sha256': 'abc123def456',
|
||||||
|
'file_name': 'test.safetensors',
|
||||||
|
}
|
||||||
|
|
||||||
|
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||||
|
|
||||||
|
assert result.is_valid is False
|
||||||
|
assert result.repaired is False
|
||||||
|
assert any('file_path' in error for error in result.errors)
|
||||||
|
|
||||||
|
def test_validate_empty_required_field_sha256(self):
|
||||||
|
"""Test validation fails when sha256 is empty string"""
|
||||||
|
entry = {
|
||||||
|
'file_path': '/models/test.safetensors',
|
||||||
|
'sha256': '', # Empty string
|
||||||
|
}
|
||||||
|
|
||||||
|
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||||
|
|
||||||
|
assert result.is_valid is False
|
||||||
|
assert result.repaired is False
|
||||||
|
assert any('sha256' in error for error in result.errors)
|
||||||
|
|
||||||
|
def test_validate_empty_required_field_file_path(self):
|
||||||
|
"""Test validation fails when file_path is empty string"""
|
||||||
|
entry = {
|
||||||
|
'file_path': '', # Empty string
|
||||||
|
'sha256': 'abc123def456',
|
||||||
|
}
|
||||||
|
|
||||||
|
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||||
|
|
||||||
|
assert result.is_valid is False
|
||||||
|
assert result.repaired is False
|
||||||
|
assert any('file_path' in error for error in result.errors)
|
||||||
|
|
||||||
|
def test_validate_none_required_field(self):
|
||||||
|
"""Test validation fails when required field is None"""
|
||||||
|
entry = {
|
||||||
|
'file_path': None,
|
||||||
|
'sha256': 'abc123def456',
|
||||||
|
}
|
||||||
|
|
||||||
|
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||||
|
|
||||||
|
assert result.is_valid is False
|
||||||
|
assert result.repaired is False
|
||||||
|
assert any('file_path' in error for error in result.errors)
|
||||||
|
|
||||||
|
def test_validate_none_entry(self):
|
||||||
|
"""Test validation handles None entry"""
|
||||||
|
result = CacheEntryValidator.validate(None, auto_repair=False)
|
||||||
|
|
||||||
|
assert result.is_valid is False
|
||||||
|
assert result.repaired is False
|
||||||
|
assert any('None' in error for error in result.errors)
|
||||||
|
assert result.entry is None
|
||||||
|
|
||||||
|
def test_validate_non_dict_entry(self):
|
||||||
|
"""Test validation handles non-dict entry"""
|
||||||
|
result = CacheEntryValidator.validate("not a dict", auto_repair=False)
|
||||||
|
|
||||||
|
assert result.is_valid is False
|
||||||
|
assert result.repaired is False
|
||||||
|
assert any('not a dict' in error for error in result.errors)
|
||||||
|
assert result.entry is None
|
||||||
|
|
||||||
|
def test_auto_repair_missing_non_required_field(self):
|
||||||
|
"""Test auto-repair adds missing non-required fields"""
|
||||||
|
entry = {
|
||||||
|
'file_path': '/models/test.safetensors',
|
||||||
|
'sha256': 'abc123def456',
|
||||||
|
# file_name, model_name, tags missing
|
||||||
|
}
|
||||||
|
|
||||||
|
result = CacheEntryValidator.validate(entry, auto_repair=True)
|
||||||
|
|
||||||
|
assert result.is_valid is True
|
||||||
|
assert result.repaired is True
|
||||||
|
assert result.entry['file_name'] == ''
|
||||||
|
assert result.entry['model_name'] == ''
|
||||||
|
assert result.entry['tags'] == []
|
||||||
|
|
||||||
|
def test_auto_repair_wrong_type_field(self):
|
||||||
|
"""Test auto-repair fixes fields with wrong type"""
|
||||||
|
entry = {
|
||||||
|
'file_path': '/models/test.safetensors',
|
||||||
|
'sha256': 'abc123def456',
|
||||||
|
'size': 'not a number', # Should be int
|
||||||
|
'tags': 'not a list', # Should be list
|
||||||
|
}
|
||||||
|
|
||||||
|
result = CacheEntryValidator.validate(entry, auto_repair=True)
|
||||||
|
|
||||||
|
assert result.is_valid is True
|
||||||
|
assert result.repaired is True
|
||||||
|
assert result.entry['size'] == 0 # Default value
|
||||||
|
assert result.entry['tags'] == [] # Default value
|
||||||
|
|
||||||
|
def test_normalize_sha256_lowercase(self):
|
||||||
|
"""Test sha256 is normalized to lowercase"""
|
||||||
|
entry = {
|
||||||
|
'file_path': '/models/test.safetensors',
|
||||||
|
'sha256': 'ABC123DEF456', # Uppercase
|
||||||
|
}
|
||||||
|
|
||||||
|
result = CacheEntryValidator.validate(entry, auto_repair=True)
|
||||||
|
|
||||||
|
assert result.is_valid is True
|
||||||
|
assert result.entry['sha256'] == 'abc123def456'
|
||||||
|
|
||||||
|
def test_validate_batch_all_valid(self):
|
||||||
|
"""Test batch validation with all valid entries"""
|
||||||
|
entries = [
|
||||||
|
{
|
||||||
|
'file_path': '/models/test1.safetensors',
|
||||||
|
'sha256': 'abc123',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'file_path': '/models/test2.safetensors',
|
||||||
|
'sha256': 'def456',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
valid, invalid = CacheEntryValidator.validate_batch(entries, auto_repair=False)
|
||||||
|
|
||||||
|
assert len(valid) == 2
|
||||||
|
assert len(invalid) == 0
|
||||||
|
|
||||||
|
def test_validate_batch_mixed_validity(self):
|
||||||
|
"""Test batch validation with mixed valid/invalid entries"""
|
||||||
|
entries = [
|
||||||
|
{
|
||||||
|
'file_path': '/models/test1.safetensors',
|
||||||
|
'sha256': 'abc123',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'file_path': '/models/test2.safetensors',
|
||||||
|
# sha256 missing - invalid
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'file_path': '/models/test3.safetensors',
|
||||||
|
'sha256': 'def456',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
valid, invalid = CacheEntryValidator.validate_batch(entries, auto_repair=False)
|
||||||
|
|
||||||
|
assert len(valid) == 2
|
||||||
|
assert len(invalid) == 1
|
||||||
|
# invalid list contains the actual invalid entries (not by index)
|
||||||
|
assert invalid[0]['file_path'] == '/models/test2.safetensors'
|
||||||
|
|
||||||
|
def test_validate_batch_empty_list(self):
|
||||||
|
"""Test batch validation with empty list"""
|
||||||
|
valid, invalid = CacheEntryValidator.validate_batch([], auto_repair=False)
|
||||||
|
|
||||||
|
assert len(valid) == 0
|
||||||
|
assert len(invalid) == 0
|
||||||
|
|
||||||
|
def test_get_file_path_safe(self):
|
||||||
|
"""Test safe file_path extraction"""
|
||||||
|
entry = {'file_path': '/models/test.safetensors', 'sha256': 'abc123'}
|
||||||
|
assert CacheEntryValidator.get_file_path_safe(entry) == '/models/test.safetensors'
|
||||||
|
|
||||||
|
def test_get_file_path_safe_missing(self):
|
||||||
|
"""Test safe file_path extraction when missing"""
|
||||||
|
entry = {'sha256': 'abc123'}
|
||||||
|
assert CacheEntryValidator.get_file_path_safe(entry) == ''
|
||||||
|
|
||||||
|
def test_get_file_path_safe_not_dict(self):
|
||||||
|
"""Test safe file_path extraction from non-dict"""
|
||||||
|
assert CacheEntryValidator.get_file_path_safe(None) == ''
|
||||||
|
assert CacheEntryValidator.get_file_path_safe('string') == ''
|
||||||
|
|
||||||
|
def test_get_sha256_safe(self):
|
||||||
|
"""Test safe sha256 extraction"""
|
||||||
|
entry = {'file_path': '/models/test.safetensors', 'sha256': 'ABC123'}
|
||||||
|
assert CacheEntryValidator.get_sha256_safe(entry) == 'abc123'
|
||||||
|
|
||||||
|
def test_get_sha256_safe_missing(self):
|
||||||
|
"""Test safe sha256 extraction when missing"""
|
||||||
|
entry = {'file_path': '/models/test.safetensors'}
|
||||||
|
assert CacheEntryValidator.get_sha256_safe(entry) == ''
|
||||||
|
|
||||||
|
def test_get_sha256_safe_not_dict(self):
|
||||||
|
"""Test safe sha256 extraction from non-dict"""
|
||||||
|
assert CacheEntryValidator.get_sha256_safe(None) == ''
|
||||||
|
assert CacheEntryValidator.get_sha256_safe('string') == ''
|
||||||
|
|
||||||
|
def test_validate_with_all_optional_fields(self):
|
||||||
|
"""Test validation with all optional fields present"""
|
||||||
|
entry = {
|
||||||
|
'file_path': '/models/test.safetensors',
|
||||||
|
'sha256': 'abc123',
|
||||||
|
'file_name': 'test.safetensors',
|
||||||
|
'model_name': 'Test Model',
|
||||||
|
'folder': 'test_folder',
|
||||||
|
'size': 1024,
|
||||||
|
'modified': 1234567890.0,
|
||||||
|
'tags': ['tag1', 'tag2'],
|
||||||
|
'preview_url': 'http://example.com/preview.jpg',
|
||||||
|
'base_model': 'SD1.5',
|
||||||
|
'from_civitai': True,
|
||||||
|
'favorite': True,
|
||||||
|
'exclude': False,
|
||||||
|
'db_checked': True,
|
||||||
|
'preview_nsfw_level': 1,
|
||||||
|
'notes': 'Test notes',
|
||||||
|
'usage_tips': 'Test tips',
|
||||||
|
}
|
||||||
|
|
||||||
|
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||||
|
|
||||||
|
assert result.is_valid is True
|
||||||
|
assert result.repaired is False
|
||||||
|
assert result.entry == entry
|
||||||
|
|
||||||
|
def test_validate_numeric_field_accepts_float_for_int(self):
|
||||||
|
"""Test that numeric fields accept float for int type"""
|
||||||
|
entry = {
|
||||||
|
'file_path': '/models/test.safetensors',
|
||||||
|
'sha256': 'abc123',
|
||||||
|
'size': 1024.5, # Float for int field
|
||||||
|
'modified': 1234567890.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||||
|
|
||||||
|
assert result.is_valid is True
|
||||||
|
assert result.repaired is False
|
||||||
364
tests/services/test_cache_health_monitor.py
Normal file
364
tests/services/test_cache_health_monitor.py
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for CacheHealthMonitor
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services.cache_health_monitor import (
|
||||||
|
CacheHealthMonitor,
|
||||||
|
CacheHealthStatus,
|
||||||
|
HealthReport,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheHealthMonitor:
|
||||||
|
"""Tests for CacheHealthMonitor class"""
|
||||||
|
|
||||||
|
def test_check_health_all_valid_entries(self):
|
||||||
|
"""Test health check with 100% valid entries"""
|
||||||
|
monitor = CacheHealthMonitor()
|
||||||
|
|
||||||
|
entries = [
|
||||||
|
{
|
||||||
|
'file_path': f'/models/test{i}.safetensors',
|
||||||
|
'sha256': f'hash{i}',
|
||||||
|
}
|
||||||
|
for i in range(100)
|
||||||
|
]
|
||||||
|
|
||||||
|
report = monitor.check_health(entries, auto_repair=False)
|
||||||
|
|
||||||
|
assert report.status == CacheHealthStatus.HEALTHY
|
||||||
|
assert report.total_entries == 100
|
||||||
|
assert report.valid_entries == 100
|
||||||
|
assert report.invalid_entries == 0
|
||||||
|
assert report.repaired_entries == 0
|
||||||
|
assert report.corruption_rate == 0.0
|
||||||
|
assert report.message == "Cache is healthy"
|
||||||
|
|
||||||
|
def test_check_health_degraded_cache(self):
|
||||||
|
"""Test health check with 1-5% invalid entries (degraded)"""
|
||||||
|
monitor = CacheHealthMonitor()
|
||||||
|
|
||||||
|
# Create 100 entries, 2 invalid (2%)
|
||||||
|
entries = [
|
||||||
|
{
|
||||||
|
'file_path': f'/models/test{i}.safetensors',
|
||||||
|
'sha256': f'hash{i}',
|
||||||
|
}
|
||||||
|
for i in range(98)
|
||||||
|
]
|
||||||
|
# Add 2 invalid entries
|
||||||
|
entries.append({'file_path': '/models/invalid1.safetensors'}) # Missing sha256
|
||||||
|
entries.append({'file_path': '/models/invalid2.safetensors'}) # Missing sha256
|
||||||
|
|
||||||
|
report = monitor.check_health(entries, auto_repair=False)
|
||||||
|
|
||||||
|
assert report.status == CacheHealthStatus.DEGRADED
|
||||||
|
assert report.total_entries == 100
|
||||||
|
assert report.valid_entries == 98
|
||||||
|
assert report.invalid_entries == 2
|
||||||
|
assert report.corruption_rate == 0.02
|
||||||
|
# Message describes the issue without necessarily containing the word "degraded"
|
||||||
|
assert 'invalid entries' in report.message.lower()
|
||||||
|
|
||||||
|
def test_check_health_corrupted_cache(self):
|
||||||
|
"""Test health check with >5% invalid entries (corrupted)"""
|
||||||
|
monitor = CacheHealthMonitor()
|
||||||
|
|
||||||
|
# Create 100 entries, 10 invalid (10%)
|
||||||
|
entries = [
|
||||||
|
{
|
||||||
|
'file_path': f'/models/test{i}.safetensors',
|
||||||
|
'sha256': f'hash{i}',
|
||||||
|
}
|
||||||
|
for i in range(90)
|
||||||
|
]
|
||||||
|
# Add 10 invalid entries
|
||||||
|
for i in range(10):
|
||||||
|
entries.append({'file_path': f'/models/invalid{i}.safetensors'})
|
||||||
|
|
||||||
|
report = monitor.check_health(entries, auto_repair=False)
|
||||||
|
|
||||||
|
assert report.status == CacheHealthStatus.CORRUPTED
|
||||||
|
assert report.total_entries == 100
|
||||||
|
assert report.valid_entries == 90
|
||||||
|
assert report.invalid_entries == 10
|
||||||
|
assert report.corruption_rate == 0.10
|
||||||
|
assert 'corrupted' in report.message.lower()
|
||||||
|
|
||||||
|
def test_check_health_empty_cache(self):
|
||||||
|
"""Test health check with empty cache"""
|
||||||
|
monitor = CacheHealthMonitor()
|
||||||
|
|
||||||
|
report = monitor.check_health([], auto_repair=False)
|
||||||
|
|
||||||
|
assert report.status == CacheHealthStatus.HEALTHY
|
||||||
|
assert report.total_entries == 0
|
||||||
|
assert report.valid_entries == 0
|
||||||
|
assert report.invalid_entries == 0
|
||||||
|
assert report.corruption_rate == 0.0
|
||||||
|
assert report.message == "Cache is empty"
|
||||||
|
|
||||||
|
def test_check_health_single_invalid_entry(self):
|
||||||
|
"""Test health check with 1 invalid entry out of 1 (100% corruption)"""
|
||||||
|
monitor = CacheHealthMonitor()
|
||||||
|
|
||||||
|
entries = [{'file_path': '/models/invalid.safetensors'}]
|
||||||
|
|
||||||
|
report = monitor.check_health(entries, auto_repair=False)
|
||||||
|
|
||||||
|
assert report.status == CacheHealthStatus.CORRUPTED
|
||||||
|
assert report.total_entries == 1
|
||||||
|
assert report.valid_entries == 0
|
||||||
|
assert report.invalid_entries == 1
|
||||||
|
assert report.corruption_rate == 1.0
|
||||||
|
|
||||||
|
def test_check_health_boundary_degraded_threshold(self):
|
||||||
|
"""Test health check at degraded threshold (1%)"""
|
||||||
|
monitor = CacheHealthMonitor(degraded_threshold=0.01)
|
||||||
|
|
||||||
|
# 100 entries, 1 invalid (exactly 1%)
|
||||||
|
entries = [
|
||||||
|
{
|
||||||
|
'file_path': f'/models/test{i}.safetensors',
|
||||||
|
'sha256': f'hash{i}',
|
||||||
|
}
|
||||||
|
for i in range(99)
|
||||||
|
]
|
||||||
|
entries.append({'file_path': '/models/invalid.safetensors'})
|
||||||
|
|
||||||
|
report = monitor.check_health(entries, auto_repair=False)
|
||||||
|
|
||||||
|
assert report.status == CacheHealthStatus.DEGRADED
|
||||||
|
assert report.corruption_rate == 0.01
|
||||||
|
|
||||||
|
def test_check_health_boundary_corrupted_threshold(self):
|
||||||
|
"""Test health check at corrupted threshold (5%)"""
|
||||||
|
monitor = CacheHealthMonitor(corrupted_threshold=0.05)
|
||||||
|
|
||||||
|
# 100 entries, 5 invalid (exactly 5%)
|
||||||
|
entries = [
|
||||||
|
{
|
||||||
|
'file_path': f'/models/test{i}.safetensors',
|
||||||
|
'sha256': f'hash{i}',
|
||||||
|
}
|
||||||
|
for i in range(95)
|
||||||
|
]
|
||||||
|
for i in range(5):
|
||||||
|
entries.append({'file_path': f'/models/invalid{i}.safetensors'})
|
||||||
|
|
||||||
|
report = monitor.check_health(entries, auto_repair=False)
|
||||||
|
|
||||||
|
assert report.status == CacheHealthStatus.CORRUPTED
|
||||||
|
assert report.corruption_rate == 0.05
|
||||||
|
|
||||||
|
def test_check_health_below_degraded_threshold(self):
|
||||||
|
"""Test health check below degraded threshold (0%)"""
|
||||||
|
monitor = CacheHealthMonitor(degraded_threshold=0.01)
|
||||||
|
|
||||||
|
# All entries valid
|
||||||
|
entries = [
|
||||||
|
{
|
||||||
|
'file_path': f'/models/test{i}.safetensors',
|
||||||
|
'sha256': f'hash{i}',
|
||||||
|
}
|
||||||
|
for i in range(100)
|
||||||
|
]
|
||||||
|
|
||||||
|
report = monitor.check_health(entries, auto_repair=False)
|
||||||
|
|
||||||
|
assert report.status == CacheHealthStatus.HEALTHY
|
||||||
|
assert report.corruption_rate == 0.0
|
||||||
|
|
||||||
|
def test_check_health_auto_repair(self):
|
||||||
|
"""Test health check with auto_repair enabled"""
|
||||||
|
monitor = CacheHealthMonitor()
|
||||||
|
|
||||||
|
# 1 entry with all fields (won't be repaired), 1 entry with missing non-required fields (will be repaired)
|
||||||
|
complete_entry = {
|
||||||
|
'file_path': '/models/test1.safetensors',
|
||||||
|
'sha256': 'hash1',
|
||||||
|
'file_name': 'test1.safetensors',
|
||||||
|
'model_name': 'Model 1',
|
||||||
|
'folder': '',
|
||||||
|
'size': 0,
|
||||||
|
'modified': 0.0,
|
||||||
|
'tags': ['tag1'],
|
||||||
|
'preview_url': '',
|
||||||
|
'base_model': '',
|
||||||
|
'from_civitai': True,
|
||||||
|
'favorite': False,
|
||||||
|
'exclude': False,
|
||||||
|
'db_checked': False,
|
||||||
|
'preview_nsfw_level': 0,
|
||||||
|
'notes': '',
|
||||||
|
'usage_tips': '',
|
||||||
|
}
|
||||||
|
incomplete_entry = {
|
||||||
|
'file_path': '/models/test2.safetensors',
|
||||||
|
'sha256': 'hash2',
|
||||||
|
# Missing many optional fields (will be repaired)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries = [complete_entry, incomplete_entry]
|
||||||
|
|
||||||
|
report = monitor.check_health(entries, auto_repair=True)
|
||||||
|
|
||||||
|
assert report.status == CacheHealthStatus.HEALTHY
|
||||||
|
assert report.total_entries == 2
|
||||||
|
assert report.valid_entries == 2
|
||||||
|
assert report.invalid_entries == 0
|
||||||
|
assert report.repaired_entries == 1
|
||||||
|
|
||||||
|
def test_should_notify_user_healthy(self):
|
||||||
|
"""Test should_notify_user for healthy cache"""
|
||||||
|
monitor = CacheHealthMonitor()
|
||||||
|
|
||||||
|
report = HealthReport(
|
||||||
|
status=CacheHealthStatus.HEALTHY,
|
||||||
|
total_entries=100,
|
||||||
|
valid_entries=100,
|
||||||
|
invalid_entries=0,
|
||||||
|
repaired_entries=0,
|
||||||
|
message="Cache is healthy"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert monitor.should_notify_user(report) is False
|
||||||
|
|
||||||
|
def test_should_notify_user_degraded(self):
|
||||||
|
"""Test should_notify_user for degraded cache"""
|
||||||
|
monitor = CacheHealthMonitor()
|
||||||
|
|
||||||
|
report = HealthReport(
|
||||||
|
status=CacheHealthStatus.DEGRADED,
|
||||||
|
total_entries=100,
|
||||||
|
valid_entries=98,
|
||||||
|
invalid_entries=2,
|
||||||
|
repaired_entries=0,
|
||||||
|
message="Cache is degraded"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert monitor.should_notify_user(report) is True
|
||||||
|
|
||||||
|
def test_should_notify_user_corrupted(self):
|
||||||
|
"""Test should_notify_user for corrupted cache"""
|
||||||
|
monitor = CacheHealthMonitor()
|
||||||
|
|
||||||
|
report = HealthReport(
|
||||||
|
status=CacheHealthStatus.CORRUPTED,
|
||||||
|
total_entries=100,
|
||||||
|
valid_entries=90,
|
||||||
|
invalid_entries=10,
|
||||||
|
repaired_entries=0,
|
||||||
|
message="Cache is corrupted"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert monitor.should_notify_user(report) is True
|
||||||
|
|
||||||
|
def test_get_notification_severity_degraded(self):
|
||||||
|
"""Test get_notification_severity for degraded cache"""
|
||||||
|
monitor = CacheHealthMonitor()
|
||||||
|
|
||||||
|
report = HealthReport(
|
||||||
|
status=CacheHealthStatus.DEGRADED,
|
||||||
|
total_entries=100,
|
||||||
|
valid_entries=98,
|
||||||
|
invalid_entries=2,
|
||||||
|
repaired_entries=0,
|
||||||
|
message="Cache is degraded"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert monitor.get_notification_severity(report) == 'warning'
|
||||||
|
|
||||||
|
def test_get_notification_severity_corrupted(self):
|
||||||
|
"""Test get_notification_severity for corrupted cache"""
|
||||||
|
monitor = CacheHealthMonitor()
|
||||||
|
|
||||||
|
report = HealthReport(
|
||||||
|
status=CacheHealthStatus.CORRUPTED,
|
||||||
|
total_entries=100,
|
||||||
|
valid_entries=90,
|
||||||
|
invalid_entries=10,
|
||||||
|
repaired_entries=0,
|
||||||
|
message="Cache is corrupted"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert monitor.get_notification_severity(report) == 'error'
|
||||||
|
|
||||||
|
def test_report_to_dict(self):
|
||||||
|
"""Test HealthReport to_dict conversion"""
|
||||||
|
report = HealthReport(
|
||||||
|
status=CacheHealthStatus.DEGRADED,
|
||||||
|
total_entries=100,
|
||||||
|
valid_entries=98,
|
||||||
|
invalid_entries=2,
|
||||||
|
repaired_entries=1,
|
||||||
|
invalid_paths=['/path1', '/path2'],
|
||||||
|
message="Cache issues detected"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = report.to_dict()
|
||||||
|
|
||||||
|
assert result['status'] == 'degraded'
|
||||||
|
assert result['total_entries'] == 100
|
||||||
|
assert result['valid_entries'] == 98
|
||||||
|
assert result['invalid_entries'] == 2
|
||||||
|
assert result['repaired_entries'] == 1
|
||||||
|
assert result['corruption_rate'] == '2.0%'
|
||||||
|
assert len(result['invalid_paths']) == 2
|
||||||
|
assert result['message'] == "Cache issues detected"
|
||||||
|
|
||||||
|
def test_report_corruption_rate_zero_division(self):
|
||||||
|
"""Test corruption_rate calculation with zero entries"""
|
||||||
|
report = HealthReport(
|
||||||
|
status=CacheHealthStatus.HEALTHY,
|
||||||
|
total_entries=0,
|
||||||
|
valid_entries=0,
|
||||||
|
invalid_entries=0,
|
||||||
|
repaired_entries=0,
|
||||||
|
message="Cache is empty"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert report.corruption_rate == 0.0
|
||||||
|
|
||||||
|
def test_check_health_collects_invalid_paths(self):
|
||||||
|
"""Test health check collects invalid entry paths"""
|
||||||
|
monitor = CacheHealthMonitor()
|
||||||
|
|
||||||
|
entries = [
|
||||||
|
{
|
||||||
|
'file_path': '/models/valid.safetensors',
|
||||||
|
'sha256': 'hash1',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'file_path': '/models/invalid1.safetensors',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'file_path': '/models/invalid2.safetensors',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
report = monitor.check_health(entries, auto_repair=False)
|
||||||
|
|
||||||
|
assert len(report.invalid_paths) == 2
|
||||||
|
assert '/models/invalid1.safetensors' in report.invalid_paths
|
||||||
|
assert '/models/invalid2.safetensors' in report.invalid_paths
|
||||||
|
|
||||||
|
def test_report_to_dict_limits_invalid_paths(self):
|
||||||
|
"""Test that to_dict limits invalid_paths to first 10"""
|
||||||
|
report = HealthReport(
|
||||||
|
status=CacheHealthStatus.CORRUPTED,
|
||||||
|
total_entries=15,
|
||||||
|
valid_entries=0,
|
||||||
|
invalid_entries=15,
|
||||||
|
repaired_entries=0,
|
||||||
|
invalid_paths=[f'/path{i}' for i in range(15)],
|
||||||
|
message="Cache corrupted"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = report.to_dict()
|
||||||
|
|
||||||
|
assert len(result['invalid_paths']) == 10
|
||||||
|
assert result['invalid_paths'][0] == '/path0'
|
||||||
|
assert result['invalid_paths'][-1] == '/path9'
|
||||||
368
tests/services/test_check_pending_models.py
Normal file
368
tests/services/test_check_pending_models.py
Normal file
@@ -0,0 +1,368 @@
|
|||||||
|
"""Tests for the check_pending_models lightweight pre-check functionality."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services.settings_manager import get_settings_manager
|
||||||
|
from py.utils import example_images_download_manager as download_module
|
||||||
|
|
||||||
|
|
||||||
|
class StubScanner:
|
||||||
|
"""Scanner double returning predetermined cache contents."""
|
||||||
|
|
||||||
|
def __init__(self, models: list[dict]) -> None:
|
||||||
|
self._cache = SimpleNamespace(raw_data=models)
|
||||||
|
|
||||||
|
async def get_cached_data(self):
|
||||||
|
return self._cache
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_scanners(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
lora_scanner: StubScanner | None = None,
|
||||||
|
checkpoint_scanner: StubScanner | None = None,
|
||||||
|
embedding_scanner: StubScanner | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Patch ServiceRegistry to return stub scanners."""
|
||||||
|
|
||||||
|
async def _get_lora_scanner(cls):
|
||||||
|
return lora_scanner or StubScanner([])
|
||||||
|
|
||||||
|
async def _get_checkpoint_scanner(cls):
|
||||||
|
return checkpoint_scanner or StubScanner([])
|
||||||
|
|
||||||
|
async def _get_embedding_scanner(cls):
|
||||||
|
return embedding_scanner or StubScanner([])
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_module.ServiceRegistry,
|
||||||
|
"get_lora_scanner",
|
||||||
|
classmethod(_get_lora_scanner),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_module.ServiceRegistry,
|
||||||
|
"get_checkpoint_scanner",
|
||||||
|
classmethod(_get_checkpoint_scanner),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_module.ServiceRegistry,
|
||||||
|
"get_embedding_scanner",
|
||||||
|
classmethod(_get_embedding_scanner),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RecordingWebSocketManager:
|
||||||
|
"""Collects broadcast payloads for assertions."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.payloads: list[dict] = []
|
||||||
|
|
||||||
|
async def broadcast(self, payload: dict) -> None:
|
||||||
|
self.payloads.append(payload)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_returns_zero_when_all_processed(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that check_pending_models returns 0 pending when all models are processed."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
# Create processed models
|
||||||
|
processed_hashes = ["a" * 64, "b" * 64, "c" * 64]
|
||||||
|
models = [
|
||||||
|
{"sha256": h, "model_name": f"Model {i}"}
|
||||||
|
for i, h in enumerate(processed_hashes)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create progress file with all models processed
|
||||||
|
progress_file = tmp_path / ".download_progress.json"
|
||||||
|
progress_file.write_text(
|
||||||
|
json.dumps({"processed_models": processed_hashes, "failed_models": []}),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create model directories with files (simulating completed downloads)
|
||||||
|
for h in processed_hashes:
|
||||||
|
model_dir = tmp_path / h
|
||||||
|
model_dir.mkdir()
|
||||||
|
(model_dir / "image_0.png").write_text("data")
|
||||||
|
|
||||||
|
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||||
|
|
||||||
|
result = await manager.check_pending_models(["lora"])
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["is_downloading"] is False
|
||||||
|
assert result["total_models"] == 3
|
||||||
|
assert result["pending_count"] == 0
|
||||||
|
assert result["processed_count"] == 3
|
||||||
|
assert result["needs_download"] is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_finds_unprocessed_models(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that check_pending_models correctly identifies unprocessed models."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
# Create models - some processed, some not
|
||||||
|
processed_hash = "a" * 64
|
||||||
|
unprocessed_hash = "b" * 64
|
||||||
|
models = [
|
||||||
|
{"sha256": processed_hash, "model_name": "Processed Model"},
|
||||||
|
{"sha256": unprocessed_hash, "model_name": "Unprocessed Model"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create progress file with only one model processed
|
||||||
|
progress_file = tmp_path / ".download_progress.json"
|
||||||
|
progress_file.write_text(
|
||||||
|
json.dumps({"processed_models": [processed_hash], "failed_models": []}),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create directory only for processed model
|
||||||
|
processed_dir = tmp_path / processed_hash
|
||||||
|
processed_dir.mkdir()
|
||||||
|
(processed_dir / "image_0.png").write_text("data")
|
||||||
|
|
||||||
|
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||||
|
|
||||||
|
result = await manager.check_pending_models(["lora"])
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["total_models"] == 2
|
||||||
|
assert result["pending_count"] == 1
|
||||||
|
assert result["processed_count"] == 1
|
||||||
|
assert result["needs_download"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_skips_models_without_hash(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that models without sha256 are not counted as pending."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
# Models - one with hash, one without
|
||||||
|
models = [
|
||||||
|
{"sha256": "a" * 64, "model_name": "Hashed Model"},
|
||||||
|
{"sha256": None, "model_name": "No Hash Model"},
|
||||||
|
{"model_name": "Missing Hash Model"}, # No sha256 key at all
|
||||||
|
]
|
||||||
|
|
||||||
|
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||||
|
|
||||||
|
result = await manager.check_pending_models(["lora"])
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["total_models"] == 3
|
||||||
|
assert result["pending_count"] == 1 # Only the one with hash
|
||||||
|
assert result["needs_download"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_handles_multiple_model_types(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that check_pending_models aggregates counts across multiple model types."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
lora_models = [
|
||||||
|
{"sha256": "a" * 64, "model_name": "Lora 1"},
|
||||||
|
{"sha256": "b" * 64, "model_name": "Lora 2"},
|
||||||
|
]
|
||||||
|
checkpoint_models = [
|
||||||
|
{"sha256": "c" * 64, "model_name": "Checkpoint 1"},
|
||||||
|
]
|
||||||
|
embedding_models = [
|
||||||
|
{"sha256": "d" * 64, "model_name": "Embedding 1"},
|
||||||
|
{"sha256": "e" * 64, "model_name": "Embedding 2"},
|
||||||
|
{"sha256": "f" * 64, "model_name": "Embedding 3"},
|
||||||
|
]
|
||||||
|
|
||||||
|
_patch_scanners(
|
||||||
|
monkeypatch,
|
||||||
|
lora_scanner=StubScanner(lora_models),
|
||||||
|
checkpoint_scanner=StubScanner(checkpoint_models),
|
||||||
|
embedding_scanner=StubScanner(embedding_models),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.check_pending_models(["lora", "checkpoint", "embedding"])
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["total_models"] == 6 # 2 + 1 + 3
|
||||||
|
assert result["pending_count"] == 6 # All unprocessed
|
||||||
|
assert result["needs_download"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_returns_error_when_download_in_progress(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that check_pending_models returns special response when download is running."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
# Simulate download in progress
|
||||||
|
manager._is_downloading = True
|
||||||
|
|
||||||
|
result = await manager.check_pending_models(["lora"])
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["is_downloading"] is True
|
||||||
|
assert result["needs_download"] is False
|
||||||
|
assert result["pending_count"] == 0
|
||||||
|
assert "already in progress" in result["message"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_handles_empty_library(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that check_pending_models handles empty model library."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
_patch_scanners(monkeypatch, lora_scanner=StubScanner([]))
|
||||||
|
|
||||||
|
result = await manager.check_pending_models(["lora"])
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["total_models"] == 0
|
||||||
|
assert result["pending_count"] == 0
|
||||||
|
assert result["processed_count"] == 0
|
||||||
|
assert result["needs_download"] is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_reads_failed_models(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that check_pending_models correctly reports failed model count."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
models = [{"sha256": "a" * 64, "model_name": "Model"}]
|
||||||
|
|
||||||
|
# Create progress file with failed models
|
||||||
|
progress_file = tmp_path / ".download_progress.json"
|
||||||
|
progress_file.write_text(
|
||||||
|
json.dumps({"processed_models": [], "failed_models": ["a" * 64, "b" * 64]}),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||||
|
|
||||||
|
result = await manager.check_pending_models(["lora"])
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["failed_count"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_handles_missing_progress_file(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that check_pending_models works correctly when no progress file exists."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
models = [
|
||||||
|
{"sha256": "a" * 64, "model_name": "Model 1"},
|
||||||
|
{"sha256": "b" * 64, "model_name": "Model 2"},
|
||||||
|
]
|
||||||
|
|
||||||
|
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||||
|
|
||||||
|
# No progress file created
|
||||||
|
result = await manager.check_pending_models(["lora"])
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["total_models"] == 2
|
||||||
|
assert result["pending_count"] == 2 # All pending since no progress
|
||||||
|
assert result["processed_count"] == 0
|
||||||
|
assert result["failed_count"] == 0
|
||||||
|
assert result["needs_download"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_handles_corrupted_progress_file(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that check_pending_models handles corrupted progress file gracefully."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
models = [{"sha256": "a" * 64, "model_name": "Model"}]
|
||||||
|
|
||||||
|
# Create corrupted progress file
|
||||||
|
progress_file = tmp_path / ".download_progress.json"
|
||||||
|
progress_file.write_text("not valid json", encoding="utf-8")
|
||||||
|
|
||||||
|
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||||
|
|
||||||
|
result = await manager.check_pending_models(["lora"])
|
||||||
|
|
||||||
|
# Should still succeed, treating all as unprocessed
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["total_models"] == 1
|
||||||
|
assert result["pending_count"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def settings_manager():
|
||||||
|
return get_settings_manager()
|
||||||
167
tests/services/test_model_scanner_cache_validation.py
Normal file
167
tests/services/test_model_scanner_cache_validation.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
"""
|
||||||
|
Integration tests for cache validation in ModelScanner
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from py.services.model_scanner import ModelScanner
|
||||||
|
from py.services.cache_entry_validator import CacheEntryValidator
|
||||||
|
from py.services.cache_health_monitor import CacheHealthMonitor, CacheHealthStatus
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_scanner_validates_cache_entries(tmp_path_factory):
|
||||||
|
"""Test that ModelScanner validates cache entries during initialization"""
|
||||||
|
# Create temporary test data
|
||||||
|
tmp_dir = tmp_path_factory.mktemp("test_loras")
|
||||||
|
|
||||||
|
# Create test files
|
||||||
|
test_file = tmp_dir / "test_model.safetensors"
|
||||||
|
test_file.write_bytes(b"fake model data" * 100)
|
||||||
|
|
||||||
|
# Mock model scanner (we can't easily instantiate a full scanner in tests)
|
||||||
|
# Instead, test the validation logic directly
|
||||||
|
entries = [
|
||||||
|
{
|
||||||
|
'file_path': str(test_file),
|
||||||
|
'sha256': 'abc123def456',
|
||||||
|
'file_name': 'test_model.safetensors',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'file_path': str(tmp_dir / 'invalid.safetensors'),
|
||||||
|
# Missing sha256 - invalid
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
valid, invalid = CacheEntryValidator.validate_batch(entries, auto_repair=True)
|
||||||
|
|
||||||
|
assert len(valid) == 1
|
||||||
|
assert len(invalid) == 1
|
||||||
|
assert valid[0]['sha256'] == 'abc123def456'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_scanner_detects_degraded_cache():
|
||||||
|
"""Test that ModelScanner detects degraded cache health"""
|
||||||
|
# Create 100 entries with 2% corruption
|
||||||
|
entries = [
|
||||||
|
{
|
||||||
|
'file_path': f'/models/test{i}.safetensors',
|
||||||
|
'sha256': f'hash{i}',
|
||||||
|
}
|
||||||
|
for i in range(98)
|
||||||
|
]
|
||||||
|
# Add 2 invalid entries
|
||||||
|
entries.append({'file_path': '/models/invalid1.safetensors'})
|
||||||
|
entries.append({'file_path': '/models/invalid2.safetensors'})
|
||||||
|
|
||||||
|
monitor = CacheHealthMonitor()
|
||||||
|
report = monitor.check_health(entries, auto_repair=True)
|
||||||
|
|
||||||
|
assert report.status == CacheHealthStatus.DEGRADED
|
||||||
|
assert report.invalid_entries == 2
|
||||||
|
assert report.valid_entries == 98
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_scanner_detects_corrupted_cache():
|
||||||
|
"""Test that ModelScanner detects corrupted cache health"""
|
||||||
|
# Create 100 entries with 10% corruption
|
||||||
|
entries = [
|
||||||
|
{
|
||||||
|
'file_path': f'/models/test{i}.safetensors',
|
||||||
|
'sha256': f'hash{i}',
|
||||||
|
}
|
||||||
|
for i in range(90)
|
||||||
|
]
|
||||||
|
# Add 10 invalid entries
|
||||||
|
for i in range(10):
|
||||||
|
entries.append({'file_path': f'/models/invalid{i}.safetensors'})
|
||||||
|
|
||||||
|
monitor = CacheHealthMonitor()
|
||||||
|
report = monitor.check_health(entries, auto_repair=True)
|
||||||
|
|
||||||
|
assert report.status == CacheHealthStatus.CORRUPTED
|
||||||
|
assert report.invalid_entries == 10
|
||||||
|
assert report.valid_entries == 90
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_scanner_removes_invalid_from_hash_index():
|
||||||
|
"""Test that ModelScanner removes invalid entries from hash index"""
|
||||||
|
from py.services.model_hash_index import ModelHashIndex
|
||||||
|
|
||||||
|
# Create a hash index with some entries
|
||||||
|
hash_index = ModelHashIndex()
|
||||||
|
valid_entry = {
|
||||||
|
'file_path': '/models/valid.safetensors',
|
||||||
|
'sha256': 'abc123',
|
||||||
|
}
|
||||||
|
invalid_entry = {
|
||||||
|
'file_path': '/models/invalid.safetensors',
|
||||||
|
'sha256': '', # Empty sha256
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add entries to hash index
|
||||||
|
hash_index.add_entry(valid_entry['sha256'], valid_entry['file_path'])
|
||||||
|
hash_index.add_entry(invalid_entry['sha256'], invalid_entry['file_path'])
|
||||||
|
|
||||||
|
# Verify both entries are in the index (using get_hash method)
|
||||||
|
assert hash_index.get_hash(valid_entry['file_path']) == valid_entry['sha256']
|
||||||
|
# Invalid entry won't be added due to empty sha256
|
||||||
|
assert hash_index.get_hash(invalid_entry['file_path']) is None
|
||||||
|
|
||||||
|
# Simulate removing invalid entry (it's not actually there, but let's test the method)
|
||||||
|
hash_index.remove_by_path(
|
||||||
|
CacheEntryValidator.get_file_path_safe(invalid_entry),
|
||||||
|
CacheEntryValidator.get_sha256_safe(invalid_entry)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify valid entry remains
|
||||||
|
assert hash_index.get_hash(valid_entry['file_path']) == valid_entry['sha256']
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_entry_validator_handles_various_field_types():
|
||||||
|
"""Test that validator handles various field types correctly"""
|
||||||
|
# Test with different field types
|
||||||
|
entry = {
|
||||||
|
'file_path': '/models/test.safetensors',
|
||||||
|
'sha256': 'abc123',
|
||||||
|
'size': 1024, # int
|
||||||
|
'modified': 1234567890.0, # float
|
||||||
|
'favorite': True, # bool
|
||||||
|
'tags': ['tag1', 'tag2'], # list
|
||||||
|
'exclude': False, # bool
|
||||||
|
}
|
||||||
|
|
||||||
|
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||||
|
|
||||||
|
assert result.is_valid is True
|
||||||
|
assert result.repaired is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_health_report_serialization():
|
||||||
|
"""Test that HealthReport can be serialized to dict"""
|
||||||
|
from py.services.cache_health_monitor import HealthReport
|
||||||
|
|
||||||
|
report = HealthReport(
|
||||||
|
status=CacheHealthStatus.DEGRADED,
|
||||||
|
total_entries=100,
|
||||||
|
valid_entries=98,
|
||||||
|
invalid_entries=2,
|
||||||
|
repaired_entries=1,
|
||||||
|
invalid_paths=['/path1', '/path2'],
|
||||||
|
message="Cache issues detected"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = report.to_dict()
|
||||||
|
|
||||||
|
assert result['status'] == 'degraded'
|
||||||
|
assert result['total_entries'] == 100
|
||||||
|
assert result['valid_entries'] == 98
|
||||||
|
assert result['invalid_entries'] == 2
|
||||||
|
assert result['repaired_entries'] == 1
|
||||||
|
assert result['corruption_rate'] == '2.0%'
|
||||||
|
assert len(result['invalid_paths']) == 2
|
||||||
|
assert result['message'] == "Cache issues detected"
|
||||||
@@ -242,6 +242,148 @@ async def test_bulk_metadata_refresh_reports_errors() -> None:
|
|||||||
assert progress.events[-1]["error"] == "boom"
|
assert progress.events[-1]["error"] == "boom"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_bulk_metadata_refresh_skips_confirmed_not_found_models(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""Models marked as from_civitai=False and civitai_deleted=True should be skipped."""
|
||||||
|
scanner = MockScanner()
|
||||||
|
scanner._cache.raw_data = [
|
||||||
|
{
|
||||||
|
"file_path": "model1.safetensors",
|
||||||
|
"sha256": "hash1",
|
||||||
|
"from_civitai": False,
|
||||||
|
"civitai_deleted": True,
|
||||||
|
"model_name": "NotOnCivitAI",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"file_path": "model2.safetensors",
|
||||||
|
"sha256": "hash2",
|
||||||
|
"from_civitai": True,
|
||||||
|
"model_name": "OnCivitAI",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
service = MockModelService(scanner)
|
||||||
|
metadata_sync = StubMetadataSync()
|
||||||
|
settings = StubSettings(enable_metadata_archive_db=False)
|
||||||
|
progress = ProgressCollector()
|
||||||
|
|
||||||
|
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
# Preserve the original data (simulating no metadata file on disk)
|
||||||
|
return model_data
|
||||||
|
|
||||||
|
monkeypatch.setattr(MetadataManager, "hydrate_model_data", staticmethod(fake_hydrate))
|
||||||
|
|
||||||
|
use_case = BulkMetadataRefreshUseCase(
|
||||||
|
service=service,
|
||||||
|
metadata_sync=metadata_sync,
|
||||||
|
settings_service=settings,
|
||||||
|
logger=logging.getLogger("test"),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await use_case.execute_with_error_handling(progress_callback=progress)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
# Only model2 should be processed (model1 is skipped)
|
||||||
|
assert result["processed"] == 1
|
||||||
|
assert result["updated"] == 1
|
||||||
|
assert len(metadata_sync.calls) == 1
|
||||||
|
assert metadata_sync.calls[0]["file_path"] == "model2.safetensors"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_bulk_metadata_refresh_skips_when_archive_checked(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""Models with db_checked=True should be skipped even if archive DB is enabled."""
|
||||||
|
scanner = MockScanner()
|
||||||
|
scanner._cache.raw_data = [
|
||||||
|
{
|
||||||
|
"file_path": "model1.safetensors",
|
||||||
|
"sha256": "hash1",
|
||||||
|
"from_civitai": False,
|
||||||
|
"civitai_deleted": True,
|
||||||
|
"db_checked": True,
|
||||||
|
"model_name": "ArchiveChecked",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"file_path": "model2.safetensors",
|
||||||
|
"sha256": "hash2",
|
||||||
|
"from_civitai": False,
|
||||||
|
"civitai_deleted": True,
|
||||||
|
"db_checked": False,
|
||||||
|
"model_name": "ArchiveNotChecked",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
service = MockModelService(scanner)
|
||||||
|
metadata_sync = StubMetadataSync()
|
||||||
|
settings = StubSettings(enable_metadata_archive_db=True)
|
||||||
|
progress = ProgressCollector()
|
||||||
|
|
||||||
|
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
return model_data
|
||||||
|
|
||||||
|
monkeypatch.setattr(MetadataManager, "hydrate_model_data", staticmethod(fake_hydrate))
|
||||||
|
|
||||||
|
use_case = BulkMetadataRefreshUseCase(
|
||||||
|
service=service,
|
||||||
|
metadata_sync=metadata_sync,
|
||||||
|
settings_service=settings,
|
||||||
|
logger=logging.getLogger("test"),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await use_case.execute_with_error_handling(progress_callback=progress)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
# Only model2 should be processed (model1 has db_checked=True)
|
||||||
|
assert result["processed"] == 1
|
||||||
|
assert result["updated"] == 1
|
||||||
|
assert len(metadata_sync.calls) == 1
|
||||||
|
assert metadata_sync.calls[0]["file_path"] == "model2.safetensors"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_bulk_metadata_refresh_processes_never_fetched_models(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""Models that have never been fetched (from_civitai=None) should be processed."""
|
||||||
|
scanner = MockScanner()
|
||||||
|
scanner._cache.raw_data = [
|
||||||
|
{
|
||||||
|
"file_path": "model1.safetensors",
|
||||||
|
"sha256": "hash1",
|
||||||
|
"from_civitai": None,
|
||||||
|
"model_name": "NeverFetched",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"file_path": "model2.safetensors",
|
||||||
|
"sha256": "hash2",
|
||||||
|
"model_name": "NoFromCivitaiField",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
service = MockModelService(scanner)
|
||||||
|
metadata_sync = StubMetadataSync()
|
||||||
|
settings = StubSettings(enable_metadata_archive_db=False)
|
||||||
|
progress = ProgressCollector()
|
||||||
|
|
||||||
|
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
return model_data
|
||||||
|
|
||||||
|
monkeypatch.setattr(MetadataManager, "hydrate_model_data", staticmethod(fake_hydrate))
|
||||||
|
|
||||||
|
use_case = BulkMetadataRefreshUseCase(
|
||||||
|
service=service,
|
||||||
|
metadata_sync=metadata_sync,
|
||||||
|
settings_service=settings,
|
||||||
|
logger=logging.getLogger("test"),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await use_case.execute_with_error_handling(progress_callback=progress)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
# Both models should be processed
|
||||||
|
assert result["processed"] == 2
|
||||||
|
assert result["updated"] == 2
|
||||||
|
assert len(metadata_sync.calls) == 2
|
||||||
|
|
||||||
|
|
||||||
async def test_download_model_use_case_raises_validation_error() -> None:
|
async def test_download_model_use_case_raises_validation_error() -> None:
|
||||||
coordinator = StubDownloadCoordinator(error="validation")
|
coordinator = StubDownloadCoordinator(error="validation")
|
||||||
use_case = DownloadModelUseCase(download_coordinator=coordinator)
|
use_case = DownloadModelUseCase(download_coordinator=coordinator)
|
||||||
|
|||||||
@@ -75,6 +75,31 @@ def test_get_file_extension_defaults_to_jpg() -> None:
|
|||||||
assert ext == ".jpg"
|
assert ext == ".jpg"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_file_extension_from_media_type_hint_video() -> None:
|
||||||
|
"""Test that media_type_hint='video' returns .mp4 when other methods fail"""
|
||||||
|
ext = processor_module.ExampleImagesProcessor._get_file_extension_from_content_or_headers(
|
||||||
|
b"", {}, "https://c.genur.art/536be3c9-e506-4365-b078-bfbc5df9ceec", "video"
|
||||||
|
)
|
||||||
|
assert ext == ".mp4"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_file_extension_from_media_type_hint_image() -> None:
|
||||||
|
"""Test that media_type_hint='image' falls back to .jpg"""
|
||||||
|
ext = processor_module.ExampleImagesProcessor._get_file_extension_from_content_or_headers(
|
||||||
|
b"", {}, "https://example.com/no-extension", "image"
|
||||||
|
)
|
||||||
|
assert ext == ".jpg"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_file_extension_media_type_hint_low_priority() -> None:
|
||||||
|
"""Test that media_type_hint is only used as last resort (after URL extension)"""
|
||||||
|
# URL has extension, should use that instead of media_type_hint
|
||||||
|
ext = processor_module.ExampleImagesProcessor._get_file_extension_from_content_or_headers(
|
||||||
|
b"", {}, "https://example.com/video.mp4", "image"
|
||||||
|
)
|
||||||
|
assert ext == ".mp4"
|
||||||
|
|
||||||
|
|
||||||
class StubScanner:
|
class StubScanner:
|
||||||
def __init__(self, models: list[Dict[str, Any]]) -> None:
|
def __init__(self, models: list[Dict[str, Any]]) -> None:
|
||||||
self._cache = SimpleNamespace(raw_data=models)
|
self._cache = SimpleNamespace(raw_data=models)
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ export default defineConfig({
|
|||||||
globals: true,
|
globals: true,
|
||||||
setupFiles: ['tests/frontend/setup.js'],
|
setupFiles: ['tests/frontend/setup.js'],
|
||||||
include: [
|
include: [
|
||||||
'tests/frontend/**/*.test.js'
|
'tests/frontend/**/*.test.js',
|
||||||
|
'tests/frontend/**/*.test.ts'
|
||||||
],
|
],
|
||||||
coverage: {
|
coverage: {
|
||||||
enabled: process.env.VITEST_COVERAGE === 'true',
|
enabled: process.env.VITEST_COVERAGE === 'true',
|
||||||
|
|||||||
1865
vue-widgets/package-lock.json
generated
1865
vue-widgets/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -12,9 +12,13 @@
|
|||||||
"@comfyorg/comfyui-frontend-types": "^1.35.4",
|
"@comfyorg/comfyui-frontend-types": "^1.35.4",
|
||||||
"@types/node": "^22.10.1",
|
"@types/node": "^22.10.1",
|
||||||
"@vitejs/plugin-vue": "^5.2.3",
|
"@vitejs/plugin-vue": "^5.2.3",
|
||||||
|
"@vitest/coverage-v8": "^3.2.4",
|
||||||
|
"@vue/test-utils": "^2.4.6",
|
||||||
|
"jsdom": "^26.0.0",
|
||||||
"typescript": "^5.7.2",
|
"typescript": "^5.7.2",
|
||||||
"vite": "^6.3.5",
|
"vite": "^6.3.5",
|
||||||
"vite-plugin-css-injected-by-js": "^3.5.2",
|
"vite-plugin-css-injected-by-js": "^3.5.2",
|
||||||
|
"vitest": "^3.0.0",
|
||||||
"vue-tsc": "^2.1.10"
|
"vue-tsc": "^2.1.10"
|
||||||
},
|
},
|
||||||
"scripts": {
|
"scripts": {
|
||||||
@@ -24,6 +28,9 @@
|
|||||||
"typecheck": "vue-tsc --noEmit",
|
"typecheck": "vue-tsc --noEmit",
|
||||||
"clean": "rm -rf ../web/comfyui/vue-widgets",
|
"clean": "rm -rf ../web/comfyui/vue-widgets",
|
||||||
"rebuild": "npm run clean && npm run build",
|
"rebuild": "npm run clean && npm run build",
|
||||||
"prepare": "npm run build"
|
"prepare": "npm run build",
|
||||||
|
"test": "vitest run",
|
||||||
|
"test:watch": "vitest",
|
||||||
|
"test:coverage": "vitest run --coverage"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,11 +10,28 @@
|
|||||||
:use-custom-clip-range="state.useCustomClipRange.value"
|
:use-custom-clip-range="state.useCustomClipRange.value"
|
||||||
:is-clip-strength-disabled="state.isClipStrengthDisabled.value"
|
:is-clip-strength-disabled="state.isClipStrengthDisabled.value"
|
||||||
:is-loading="state.isLoading.value"
|
:is-loading="state.isLoading.value"
|
||||||
|
:repeat-count="state.repeatCount.value"
|
||||||
|
:repeat-used="state.displayRepeatUsed.value"
|
||||||
|
:is-paused="state.isPaused.value"
|
||||||
|
:is-pause-disabled="hasQueuedPrompts"
|
||||||
|
:is-workflow-executing="state.isWorkflowExecuting.value"
|
||||||
|
:executing-repeat-step="state.executingRepeatStep.value"
|
||||||
@update:current-index="handleIndexUpdate"
|
@update:current-index="handleIndexUpdate"
|
||||||
@update:model-strength="state.modelStrength.value = $event"
|
@update:model-strength="state.modelStrength.value = $event"
|
||||||
@update:clip-strength="state.clipStrength.value = $event"
|
@update:clip-strength="state.clipStrength.value = $event"
|
||||||
@update:use-custom-clip-range="handleUseCustomClipRangeChange"
|
@update:use-custom-clip-range="handleUseCustomClipRangeChange"
|
||||||
@refresh="handleRefresh"
|
@update:repeat-count="handleRepeatCountChange"
|
||||||
|
@toggle-pause="handleTogglePause"
|
||||||
|
@reset-index="handleResetIndex"
|
||||||
|
@open-lora-selector="isModalOpen = true"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<LoraListModal
|
||||||
|
:visible="isModalOpen"
|
||||||
|
:lora-list="cachedLoraList"
|
||||||
|
:current-index="state.currentIndex.value"
|
||||||
|
@close="isModalOpen = false"
|
||||||
|
@select="handleModalSelect"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
@@ -22,8 +39,9 @@
|
|||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { onMounted, ref } from 'vue'
|
import { onMounted, ref } from 'vue'
|
||||||
import LoraCyclerSettingsView from './lora-cycler/LoraCyclerSettingsView.vue'
|
import LoraCyclerSettingsView from './lora-cycler/LoraCyclerSettingsView.vue'
|
||||||
|
import LoraListModal from './lora-cycler/LoraListModal.vue'
|
||||||
import { useLoraCyclerState } from '../composables/useLoraCyclerState'
|
import { useLoraCyclerState } from '../composables/useLoraCyclerState'
|
||||||
import type { ComponentWidget, CyclerConfig, LoraPoolConfig } from '../composables/types'
|
import type { ComponentWidget, CyclerConfig, LoraPoolConfig, LoraItem } from '../composables/types'
|
||||||
|
|
||||||
type CyclerWidget = ComponentWidget<CyclerConfig>
|
type CyclerWidget = ComponentWidget<CyclerConfig>
|
||||||
|
|
||||||
@@ -31,6 +49,7 @@ type CyclerWidget = ComponentWidget<CyclerConfig>
|
|||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
widget: CyclerWidget
|
widget: CyclerWidget
|
||||||
node: { id: number; inputs?: any[]; widgets?: any[]; graph?: any }
|
node: { id: number; inputs?: any[]; widgets?: any[]; graph?: any }
|
||||||
|
api?: any // ComfyUI API for execution events
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
// State management
|
// State management
|
||||||
@@ -39,12 +58,50 @@ const state = useLoraCyclerState(props.widget)
|
|||||||
// Symbol to track if the widget has been executed at least once
|
// Symbol to track if the widget has been executed at least once
|
||||||
const HAS_EXECUTED = Symbol('HAS_EXECUTED')
|
const HAS_EXECUTED = Symbol('HAS_EXECUTED')
|
||||||
|
|
||||||
|
// Execution context queue for batch queue synchronization
|
||||||
|
// In batch queue mode, all beforeQueued calls happen BEFORE any onExecuted calls,
|
||||||
|
// so we need to snapshot the state at queue time and replay it during execution
|
||||||
|
interface ExecutionContext {
|
||||||
|
isPaused: boolean
|
||||||
|
repeatUsed: number
|
||||||
|
repeatCount: number
|
||||||
|
shouldAdvanceDisplay: boolean
|
||||||
|
displayRepeatUsed: number // Value to show in UI after completion
|
||||||
|
}
|
||||||
|
const executionQueue: ExecutionContext[] = []
|
||||||
|
|
||||||
|
// Reactive flag to track if there are queued prompts (for disabling pause button)
|
||||||
|
const hasQueuedPrompts = ref(false)
|
||||||
|
|
||||||
|
// Track pending executions for batch queue support (deferred UI updates)
|
||||||
|
// Uses FIFO order since executions are processed in the order they were queued
|
||||||
|
interface PendingExecution {
|
||||||
|
repeatUsed: number
|
||||||
|
repeatCount: number
|
||||||
|
shouldAdvanceDisplay: boolean
|
||||||
|
displayRepeatUsed: number // Value to show in UI after completion
|
||||||
|
output?: {
|
||||||
|
nextIndex: number
|
||||||
|
nextLoraName: string
|
||||||
|
nextLoraFilename: string
|
||||||
|
currentLoraName: string
|
||||||
|
currentLoraFilename: string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const pendingExecutions: PendingExecution[] = []
|
||||||
|
|
||||||
// Track last known pool config hash
|
// Track last known pool config hash
|
||||||
const lastPoolConfigHash = ref('')
|
const lastPoolConfigHash = ref('')
|
||||||
|
|
||||||
// Track if component is mounted
|
// Track if component is mounted
|
||||||
const isMounted = ref(false)
|
const isMounted = ref(false)
|
||||||
|
|
||||||
|
// Modal state
|
||||||
|
const isModalOpen = ref(false)
|
||||||
|
|
||||||
|
// Cache for LoRA list (used by modal)
|
||||||
|
const cachedLoraList = ref<LoraItem[]>([])
|
||||||
|
|
||||||
// Get pool config from connected node
|
// Get pool config from connected node
|
||||||
const getPoolConfig = (): LoraPoolConfig | null => {
|
const getPoolConfig = (): LoraPoolConfig | null => {
|
||||||
// Check if getPoolConfig method exists on node (added by main.ts)
|
// Check if getPoolConfig method exists on node (added by main.ts)
|
||||||
@@ -54,27 +111,47 @@ const getPoolConfig = (): LoraPoolConfig | null => {
|
|||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update display from LoRA list and index
|
||||||
|
const updateDisplayFromLoraList = (loraList: LoraItem[], index: number) => {
|
||||||
|
if (loraList.length > 0 && index > 0 && index <= loraList.length) {
|
||||||
|
const currentLora = loraList[index - 1]
|
||||||
|
if (currentLora) {
|
||||||
|
state.currentLoraName.value = currentLora.file_name
|
||||||
|
state.currentLoraFilename.value = currentLora.file_name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Handle index update from user
|
// Handle index update from user
|
||||||
const handleIndexUpdate = async (newIndex: number) => {
|
const handleIndexUpdate = async (newIndex: number) => {
|
||||||
|
// Reset execution state when user manually changes index
|
||||||
|
// This ensures the next execution starts from the user-set index
|
||||||
|
;(props.widget as any)[HAS_EXECUTED] = false
|
||||||
|
state.executionIndex.value = null
|
||||||
|
state.nextIndex.value = null
|
||||||
|
|
||||||
|
// Clear execution queue since user is manually changing state
|
||||||
|
executionQueue.length = 0
|
||||||
|
hasQueuedPrompts.value = false
|
||||||
|
|
||||||
state.setIndex(newIndex)
|
state.setIndex(newIndex)
|
||||||
|
|
||||||
// Refresh list to update current LoRA display
|
// Refresh list to update current LoRA display
|
||||||
try {
|
try {
|
||||||
const poolConfig = getPoolConfig()
|
const poolConfig = getPoolConfig()
|
||||||
const loraList = await state.fetchCyclerList(poolConfig)
|
const loraList = await state.fetchCyclerList(poolConfig)
|
||||||
|
cachedLoraList.value = loraList
|
||||||
if (loraList.length > 0 && newIndex > 0 && newIndex <= loraList.length) {
|
updateDisplayFromLoraList(loraList, newIndex)
|
||||||
const currentLora = loraList[newIndex - 1]
|
|
||||||
if (currentLora) {
|
|
||||||
state.currentLoraName.value = currentLora.file_name
|
|
||||||
state.currentLoraFilename.value = currentLora.file_name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('[LoraCyclerWidget] Error updating index:', error)
|
console.error('[LoraCyclerWidget] Error updating index:', error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle LoRA selection from modal
|
||||||
|
const handleModalSelect = (index: number) => {
|
||||||
|
handleIndexUpdate(index)
|
||||||
|
}
|
||||||
|
|
||||||
// Handle use custom clip range toggle
|
// Handle use custom clip range toggle
|
||||||
const handleUseCustomClipRangeChange = (newValue: boolean) => {
|
const handleUseCustomClipRangeChange = (newValue: boolean) => {
|
||||||
state.useCustomClipRange.value = newValue
|
state.useCustomClipRange.value = newValue
|
||||||
@@ -84,13 +161,41 @@ const handleUseCustomClipRangeChange = (newValue: boolean) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle refresh button click
|
// Handle repeat count change
|
||||||
const handleRefresh = async () => {
|
const handleRepeatCountChange = (newValue: number) => {
|
||||||
|
state.repeatCount.value = newValue
|
||||||
|
// Reset repeatUsed when changing repeat count
|
||||||
|
state.repeatUsed.value = 0
|
||||||
|
state.displayRepeatUsed.value = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle pause toggle
|
||||||
|
const handleTogglePause = () => {
|
||||||
|
state.togglePause()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle reset index
|
||||||
|
const handleResetIndex = async () => {
|
||||||
|
// Reset execution state
|
||||||
|
;(props.widget as any)[HAS_EXECUTED] = false
|
||||||
|
state.executionIndex.value = null
|
||||||
|
state.nextIndex.value = null
|
||||||
|
|
||||||
|
// Clear execution queue since user is resetting state
|
||||||
|
executionQueue.length = 0
|
||||||
|
hasQueuedPrompts.value = false
|
||||||
|
|
||||||
|
// Reset index and repeat state
|
||||||
|
state.resetIndex()
|
||||||
|
|
||||||
|
// Refresh list to update current LoRA display
|
||||||
try {
|
try {
|
||||||
const poolConfig = getPoolConfig()
|
const poolConfig = getPoolConfig()
|
||||||
await state.refreshList(poolConfig)
|
const loraList = await state.fetchCyclerList(poolConfig)
|
||||||
|
cachedLoraList.value = loraList
|
||||||
|
updateDisplayFromLoraList(loraList, 1)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('[LoraCyclerWidget] Error refreshing:', error)
|
console.error('[LoraCyclerWidget] Error resetting index:', error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,6 +211,9 @@ const checkPoolConfigChanges = async () => {
|
|||||||
lastPoolConfigHash.value = newHash
|
lastPoolConfigHash.value = newHash
|
||||||
try {
|
try {
|
||||||
await state.refreshList(poolConfig)
|
await state.refreshList(poolConfig)
|
||||||
|
// Update cached list when pool config changes
|
||||||
|
const loraList = await state.fetchCyclerList(poolConfig)
|
||||||
|
cachedLoraList.value = loraList
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('[LoraCyclerWidget] Error on pool config change:', error)
|
console.error('[LoraCyclerWidget] Error on pool config change:', error)
|
||||||
}
|
}
|
||||||
@@ -129,17 +237,68 @@ onMounted(async () => {
|
|||||||
|
|
||||||
// Add beforeQueued hook to handle index shifting for batch queue synchronization
|
// Add beforeQueued hook to handle index shifting for batch queue synchronization
|
||||||
// This ensures each execution uses a different LoRA in the cycle
|
// This ensures each execution uses a different LoRA in the cycle
|
||||||
|
// Now with support for repeat count and pause features
|
||||||
|
//
|
||||||
|
// IMPORTANT: In batch queue mode, ALL beforeQueued calls happen BEFORE any execution.
|
||||||
|
// We push an "execution context" snapshot to a queue so that onExecuted can use the
|
||||||
|
// correct state values that were captured at queue time (not the live state).
|
||||||
;(props.widget as any).beforeQueued = () => {
|
;(props.widget as any).beforeQueued = () => {
|
||||||
|
if (state.isPaused.value) {
|
||||||
|
// When paused: use current index, don't advance, don't count toward repeat limit
|
||||||
|
// Push context indicating this execution should NOT advance display
|
||||||
|
executionQueue.push({
|
||||||
|
isPaused: true,
|
||||||
|
repeatUsed: state.repeatUsed.value,
|
||||||
|
repeatCount: state.repeatCount.value,
|
||||||
|
shouldAdvanceDisplay: false,
|
||||||
|
displayRepeatUsed: state.displayRepeatUsed.value // Keep current display value when paused
|
||||||
|
})
|
||||||
|
hasQueuedPrompts.value = true
|
||||||
|
// CRITICAL: Clear execution_index when paused to force backend to use current_index
|
||||||
|
// This ensures paused executions use the same LoRA regardless of any
|
||||||
|
// execution_index set by previous non-paused beforeQueued calls
|
||||||
|
const pausedConfig = state.buildConfig()
|
||||||
|
pausedConfig.execution_index = null
|
||||||
|
props.widget.value = pausedConfig
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if ((props.widget as any)[HAS_EXECUTED]) {
|
if ((props.widget as any)[HAS_EXECUTED]) {
|
||||||
// After first execution: shift indices (previous next_index becomes execution_index)
|
// After first execution: check repeat logic
|
||||||
state.generateNextIndex()
|
if (state.repeatUsed.value < state.repeatCount.value) {
|
||||||
|
// Still repeating: increment repeatUsed, use same index
|
||||||
|
state.repeatUsed.value++
|
||||||
|
} else {
|
||||||
|
// Repeat complete: reset repeatUsed to 1, advance to next index
|
||||||
|
state.repeatUsed.value = 1
|
||||||
|
state.generateNextIndex()
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// First execution: just initialize next_index (execution_index stays null)
|
// First execution: initialize
|
||||||
// This means first execution uses current_index from widget
|
state.repeatUsed.value = 1
|
||||||
state.initializeNextIndex()
|
state.initializeNextIndex()
|
||||||
;(props.widget as any)[HAS_EXECUTED] = true
|
;(props.widget as any)[HAS_EXECUTED] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Determine if this execution should advance the display
|
||||||
|
// (only when repeat cycle is complete for this queued item)
|
||||||
|
const shouldAdvanceDisplay = state.repeatUsed.value >= state.repeatCount.value
|
||||||
|
|
||||||
|
// Calculate the display value to show after this execution completes
|
||||||
|
// When advancing to a new LoRA: reset to 0 (fresh start for new LoRA)
|
||||||
|
// When repeating same LoRA: show current repeat step
|
||||||
|
const displayRepeatUsed = shouldAdvanceDisplay ? 0 : state.repeatUsed.value
|
||||||
|
|
||||||
|
// Push execution context snapshot to queue
|
||||||
|
executionQueue.push({
|
||||||
|
isPaused: false,
|
||||||
|
repeatUsed: state.repeatUsed.value,
|
||||||
|
repeatCount: state.repeatCount.value,
|
||||||
|
shouldAdvanceDisplay,
|
||||||
|
displayRepeatUsed
|
||||||
|
})
|
||||||
|
hasQueuedPrompts.value = true
|
||||||
|
|
||||||
// Update the widget value so the indices are included in the serialized config
|
// Update the widget value so the indices are included in the serialized config
|
||||||
props.widget.value = state.buildConfig()
|
props.widget.value = state.buildConfig()
|
||||||
}
|
}
|
||||||
@@ -152,40 +311,71 @@ onMounted(async () => {
|
|||||||
const poolConfig = getPoolConfig()
|
const poolConfig = getPoolConfig()
|
||||||
lastPoolConfigHash.value = state.hashPoolConfig(poolConfig)
|
lastPoolConfigHash.value = state.hashPoolConfig(poolConfig)
|
||||||
await state.refreshList(poolConfig)
|
await state.refreshList(poolConfig)
|
||||||
|
// Cache the initial LoRA list for modal
|
||||||
|
const loraList = await state.fetchCyclerList(poolConfig)
|
||||||
|
cachedLoraList.value = loraList
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('[LoraCyclerWidget] Error on initial load:', error)
|
console.error('[LoraCyclerWidget] Error on initial load:', error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Override onExecuted to handle backend UI updates
|
// Override onExecuted to handle backend UI updates
|
||||||
|
// This defers the UI update until workflow completes (via API events)
|
||||||
const originalOnExecuted = (props.node as any).onExecuted?.bind(props.node)
|
const originalOnExecuted = (props.node as any).onExecuted?.bind(props.node)
|
||||||
|
|
||||||
;(props.node as any).onExecuted = function(output: any) {
|
;(props.node as any).onExecuted = function(output: any) {
|
||||||
console.log("[LoraCyclerWidget] Node executed with output:", output)
|
console.log("[LoraCyclerWidget] Node executed with output:", output)
|
||||||
|
|
||||||
// Update state from backend response (values are wrapped in arrays)
|
// Pop execution context from queue (FIFO order)
|
||||||
if (output?.next_index !== undefined) {
|
const context = executionQueue.shift()
|
||||||
const val = Array.isArray(output.next_index) ? output.next_index[0] : output.next_index
|
hasQueuedPrompts.value = executionQueue.length > 0
|
||||||
state.currentIndex.value = val
|
|
||||||
}
|
// Determine if we should advance the display index
|
||||||
|
const shouldAdvanceDisplay = context
|
||||||
|
? context.shouldAdvanceDisplay
|
||||||
|
: (!state.isPaused.value && state.repeatUsed.value >= state.repeatCount.value)
|
||||||
|
|
||||||
|
// Extract output values
|
||||||
|
const nextIndex = output?.next_index !== undefined
|
||||||
|
? (Array.isArray(output.next_index) ? output.next_index[0] : output.next_index)
|
||||||
|
: state.currentIndex.value
|
||||||
|
const nextLoraName = output?.next_lora_name !== undefined
|
||||||
|
? (Array.isArray(output.next_lora_name) ? output.next_lora_name[0] : output.next_lora_name)
|
||||||
|
: ''
|
||||||
|
const nextLoraFilename = output?.next_lora_filename !== undefined
|
||||||
|
? (Array.isArray(output.next_lora_filename) ? output.next_lora_filename[0] : output.next_lora_filename)
|
||||||
|
: ''
|
||||||
|
const currentLoraName = output?.current_lora_name !== undefined
|
||||||
|
? (Array.isArray(output.current_lora_name) ? output.current_lora_name[0] : output.current_lora_name)
|
||||||
|
: ''
|
||||||
|
const currentLoraFilename = output?.current_lora_filename !== undefined
|
||||||
|
? (Array.isArray(output.current_lora_filename) ? output.current_lora_filename[0] : output.current_lora_filename)
|
||||||
|
: ''
|
||||||
|
|
||||||
|
// Update total count immediately (doesn't need to wait for workflow completion)
|
||||||
if (output?.total_count !== undefined) {
|
if (output?.total_count !== undefined) {
|
||||||
const val = Array.isArray(output.total_count) ? output.total_count[0] : output.total_count
|
const val = Array.isArray(output.total_count) ? output.total_count[0] : output.total_count
|
||||||
state.totalCount.value = val
|
state.totalCount.value = val
|
||||||
}
|
}
|
||||||
if (output?.current_lora_name !== undefined) {
|
|
||||||
const val = Array.isArray(output.current_lora_name) ? output.current_lora_name[0] : output.current_lora_name
|
// Store pending update (will be applied on workflow completion)
|
||||||
state.currentLoraName.value = val
|
if (context) {
|
||||||
}
|
pendingExecutions.push({
|
||||||
if (output?.current_lora_filename !== undefined) {
|
repeatUsed: context.repeatUsed,
|
||||||
const val = Array.isArray(output.current_lora_filename) ? output.current_lora_filename[0] : output.current_lora_filename
|
repeatCount: context.repeatCount,
|
||||||
state.currentLoraFilename.value = val
|
shouldAdvanceDisplay,
|
||||||
}
|
displayRepeatUsed: context.displayRepeatUsed,
|
||||||
if (output?.next_lora_name !== undefined) {
|
output: {
|
||||||
const val = Array.isArray(output.next_lora_name) ? output.next_lora_name[0] : output.next_lora_name
|
nextIndex,
|
||||||
state.currentLoraName.value = val
|
nextLoraName,
|
||||||
}
|
nextLoraFilename,
|
||||||
if (output?.next_lora_filename !== undefined) {
|
currentLoraName,
|
||||||
const val = Array.isArray(output.next_lora_filename) ? output.next_lora_filename[0] : output.next_lora_filename
|
currentLoraFilename
|
||||||
state.currentLoraFilename.value = val
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Update visual feedback state (don't update displayRepeatUsed yet - wait for workflow completion)
|
||||||
|
state.executingRepeatStep.value = context.repeatUsed
|
||||||
|
state.isWorkflowExecuting.value = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call original onExecuted if it exists
|
// Call original onExecuted if it exists
|
||||||
@@ -194,11 +384,69 @@ onMounted(async () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set up execution tracking via API events
|
||||||
|
if (props.api) {
|
||||||
|
// Handle workflow completion events using FIFO order
|
||||||
|
// Note: The 'executing' event doesn't contain prompt_id (only node ID as string),
|
||||||
|
// so we use FIFO order instead of prompt_id matching since executions are processed
|
||||||
|
// in the order they were queued
|
||||||
|
const handleExecutionComplete = () => {
|
||||||
|
// Process the first pending execution (FIFO order)
|
||||||
|
if (pendingExecutions.length === 0) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const pending = pendingExecutions.shift()!
|
||||||
|
|
||||||
|
// Apply UI update now that workflow is complete
|
||||||
|
// Update repeat display (deferred like index updates)
|
||||||
|
state.displayRepeatUsed.value = pending.displayRepeatUsed
|
||||||
|
|
||||||
|
if (pending.output) {
|
||||||
|
if (pending.shouldAdvanceDisplay) {
|
||||||
|
state.currentIndex.value = pending.output.nextIndex
|
||||||
|
state.currentLoraName.value = pending.output.nextLoraName
|
||||||
|
state.currentLoraFilename.value = pending.output.nextLoraFilename
|
||||||
|
} else {
|
||||||
|
// When not advancing, show current LoRA info
|
||||||
|
state.currentLoraName.value = pending.output.currentLoraName
|
||||||
|
state.currentLoraFilename.value = pending.output.currentLoraFilename
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset visual feedback if no more pending
|
||||||
|
if (pendingExecutions.length === 0) {
|
||||||
|
state.isWorkflowExecuting.value = false
|
||||||
|
state.executingRepeatStep.value = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
props.api.addEventListener('execution_success', handleExecutionComplete)
|
||||||
|
props.api.addEventListener('execution_error', handleExecutionComplete)
|
||||||
|
props.api.addEventListener('execution_interrupted', handleExecutionComplete)
|
||||||
|
|
||||||
|
// Store cleanup function for API listeners
|
||||||
|
const apiCleanup = () => {
|
||||||
|
props.api.removeEventListener('execution_success', handleExecutionComplete)
|
||||||
|
props.api.removeEventListener('execution_error', handleExecutionComplete)
|
||||||
|
props.api.removeEventListener('execution_interrupted', handleExecutionComplete)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extend existing cleanup
|
||||||
|
const existingCleanup = (props.widget as any).onRemoveCleanup
|
||||||
|
;(props.widget as any).onRemoveCleanup = () => {
|
||||||
|
existingCleanup?.()
|
||||||
|
apiCleanup()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Watch for connection changes by polling (since ComfyUI doesn't provide connection events)
|
// Watch for connection changes by polling (since ComfyUI doesn't provide connection events)
|
||||||
const checkInterval = setInterval(checkPoolConfigChanges, 1000)
|
const checkInterval = setInterval(checkPoolConfigChanges, 1000)
|
||||||
|
|
||||||
// Cleanup on unmount (handled by Vue's effect scope)
|
// Cleanup on unmount (handled by Vue's effect scope)
|
||||||
|
const existingCleanupForInterval = (props.widget as any).onRemoveCleanup
|
||||||
;(props.widget as any).onRemoveCleanup = () => {
|
;(props.widget as any).onRemoveCleanup = () => {
|
||||||
|
existingCleanupForInterval?.()
|
||||||
clearInterval(checkInterval)
|
clearInterval(checkInterval)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -6,57 +6,111 @@
|
|||||||
|
|
||||||
<!-- Progress Display -->
|
<!-- Progress Display -->
|
||||||
<div class="setting-section progress-section">
|
<div class="setting-section progress-section">
|
||||||
<div class="progress-display">
|
<div class="progress-display" :class="{ executing: isWorkflowExecuting }">
|
||||||
<div class="progress-info">
|
<div
|
||||||
<span class="progress-label">Next LoRA:</span>
|
class="progress-info"
|
||||||
<span class="progress-name" :title="currentLoraFilename">{{ currentLoraName || 'None' }}</span>
|
:class="{ disabled: isPauseDisabled }"
|
||||||
|
@click="handleOpenSelector"
|
||||||
|
>
|
||||||
|
<span class="progress-label">{{ isWorkflowExecuting ? 'Using LoRA:' : 'Next LoRA:' }}</span>
|
||||||
|
<span class="progress-name clickable" :class="{ disabled: isPauseDisabled }" :title="currentLoraFilename">
|
||||||
|
{{ currentLoraName || 'None' }}
|
||||||
|
<svg class="selector-icon" viewBox="0 0 24 24" fill="currentColor">
|
||||||
|
<path d="M7 10l5 5 5-5z"/>
|
||||||
|
</svg>
|
||||||
|
</span>
|
||||||
</div>
|
</div>
|
||||||
<div class="progress-counter">
|
<div class="progress-counter">
|
||||||
<span class="progress-index">{{ currentIndex }}</span>
|
<span class="progress-index">{{ currentIndex }}</span>
|
||||||
<span class="progress-separator">/</span>
|
<span class="progress-separator">/</span>
|
||||||
<span class="progress-total">{{ totalCount }}</span>
|
<span class="progress-total">{{ totalCount }}</span>
|
||||||
<button
|
|
||||||
class="refresh-button"
|
<!-- Repeat progress indicator (only shown when repeatCount > 1) -->
|
||||||
:disabled="isLoading"
|
<div v-if="repeatCount > 1" class="repeat-progress">
|
||||||
@click="$emit('refresh')"
|
<div class="repeat-progress-track">
|
||||||
title="Refresh list"
|
<div
|
||||||
>
|
class="repeat-progress-fill"
|
||||||
<svg
|
:style="{ width: `${(repeatUsed / repeatCount) * 100}%` }"
|
||||||
class="refresh-icon"
|
:class="{ 'is-complete': repeatUsed >= repeatCount }"
|
||||||
:class="{ spinning: isLoading }"
|
></div>
|
||||||
viewBox="0 0 24 24"
|
</div>
|
||||||
fill="none"
|
<span class="repeat-progress-text">{{ repeatUsed }}/{{ repeatCount }}</span>
|
||||||
stroke="currentColor"
|
</div>
|
||||||
stroke-width="2"
|
|
||||||
stroke-linecap="round"
|
|
||||||
stroke-linejoin="round"
|
|
||||||
>
|
|
||||||
<path d="M21 12a9 9 0 1 1-6.219-8.56"/>
|
|
||||||
<path d="M21 3v5h-5"/>
|
|
||||||
</svg>
|
|
||||||
</button>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Starting Index -->
|
<!-- Starting Index with Advanced Controls -->
|
||||||
<div class="setting-section">
|
<div class="setting-section">
|
||||||
<label class="setting-label">Starting Index</label>
|
<div class="index-controls-row">
|
||||||
<div class="index-input-container">
|
<!-- Left: Index group -->
|
||||||
<input
|
<div class="control-group">
|
||||||
type="number"
|
<label class="control-group-label">Starting Index</label>
|
||||||
class="index-input"
|
<div class="control-group-content">
|
||||||
:min="1"
|
<input
|
||||||
:max="totalCount || 1"
|
type="number"
|
||||||
:value="currentIndex"
|
class="index-input"
|
||||||
:disabled="totalCount === 0"
|
:min="1"
|
||||||
@input="onIndexInput"
|
:max="totalCount || 1"
|
||||||
@blur="onIndexBlur"
|
:value="currentIndex"
|
||||||
@pointerdown.stop
|
:disabled="totalCount === 0"
|
||||||
@pointermove.stop
|
@input="onIndexInput"
|
||||||
@pointerup.stop
|
@blur="onIndexBlur"
|
||||||
/>
|
@pointerdown.stop
|
||||||
<span class="index-hint">1 - {{ totalCount || 1 }}</span>
|
@pointermove.stop
|
||||||
|
@pointerup.stop
|
||||||
|
/>
|
||||||
|
<span class="index-hint">/ {{ totalCount || 1 }}</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Right: Repeat group -->
|
||||||
|
<div class="control-group">
|
||||||
|
<label class="control-group-label">Repeat</label>
|
||||||
|
<div class="control-group-content">
|
||||||
|
<input
|
||||||
|
type="number"
|
||||||
|
class="repeat-input"
|
||||||
|
min="1"
|
||||||
|
max="99"
|
||||||
|
:value="repeatCount"
|
||||||
|
@input="onRepeatInput"
|
||||||
|
@blur="onRepeatBlur"
|
||||||
|
@pointerdown.stop
|
||||||
|
@pointermove.stop
|
||||||
|
@pointerup.stop
|
||||||
|
title="Each LoRA will be used this many times before moving to the next"
|
||||||
|
/>
|
||||||
|
<span class="repeat-suffix">×</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Action buttons -->
|
||||||
|
<div class="action-buttons">
|
||||||
|
<button
|
||||||
|
class="control-btn"
|
||||||
|
:class="{ active: isPaused }"
|
||||||
|
:disabled="isPauseDisabled"
|
||||||
|
@click="$emit('toggle-pause')"
|
||||||
|
:title="isPauseDisabled ? 'Cannot pause while prompts are queued' : (isPaused ? 'Continue iteration' : 'Pause iteration')"
|
||||||
|
>
|
||||||
|
<svg v-if="isPaused" viewBox="0 0 24 24" fill="currentColor" class="control-icon">
|
||||||
|
<path d="M8 5v14l11-7z"/>
|
||||||
|
</svg>
|
||||||
|
<svg v-else viewBox="0 0 24 24" fill="currentColor" class="control-icon">
|
||||||
|
<path d="M6 4h4v16H6zm8 0h4v16h-4z"/>
|
||||||
|
</svg>
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
class="control-btn"
|
||||||
|
@click="$emit('reset-index')"
|
||||||
|
title="Reset to index 1"
|
||||||
|
>
|
||||||
|
<svg viewBox="0 0 24 24" fill="currentColor" class="control-icon">
|
||||||
|
<path d="M12 5V1L7 6l5 5V7c3.31 0 6 2.69 6 6s-2.69 6-6 6-6-2.69-6-6H4c0 4.42 3.58 8 8 8s8-3.58 8-8-3.58-8-8-8z"/>
|
||||||
|
</svg>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -122,7 +176,12 @@ const props = defineProps<{
|
|||||||
clipStrength: number
|
clipStrength: number
|
||||||
useCustomClipRange: boolean
|
useCustomClipRange: boolean
|
||||||
isClipStrengthDisabled: boolean
|
isClipStrengthDisabled: boolean
|
||||||
isLoading: boolean
|
repeatCount: number
|
||||||
|
repeatUsed: number
|
||||||
|
isPaused: boolean
|
||||||
|
isPauseDisabled: boolean
|
||||||
|
isWorkflowExecuting: boolean
|
||||||
|
executingRepeatStep: number
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
const emit = defineEmits<{
|
const emit = defineEmits<{
|
||||||
@@ -130,11 +189,22 @@ const emit = defineEmits<{
|
|||||||
'update:modelStrength': [value: number]
|
'update:modelStrength': [value: number]
|
||||||
'update:clipStrength': [value: number]
|
'update:clipStrength': [value: number]
|
||||||
'update:useCustomClipRange': [value: boolean]
|
'update:useCustomClipRange': [value: boolean]
|
||||||
'refresh': []
|
'update:repeatCount': [value: number]
|
||||||
|
'toggle-pause': []
|
||||||
|
'reset-index': []
|
||||||
|
'open-lora-selector': []
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
// Temporary value for input while typing
|
// Temporary value for input while typing
|
||||||
const tempIndex = ref<string>('')
|
const tempIndex = ref<string>('')
|
||||||
|
const tempRepeat = ref<string>('')
|
||||||
|
|
||||||
|
const handleOpenSelector = () => {
|
||||||
|
if (props.isPauseDisabled) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
emit('open-lora-selector')
|
||||||
|
}
|
||||||
|
|
||||||
const onIndexInput = (event: Event) => {
|
const onIndexInput = (event: Event) => {
|
||||||
const input = event.target as HTMLInputElement
|
const input = event.target as HTMLInputElement
|
||||||
@@ -154,6 +224,25 @@ const onIndexBlur = (event: Event) => {
|
|||||||
}
|
}
|
||||||
tempIndex.value = ''
|
tempIndex.value = ''
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const onRepeatInput = (event: Event) => {
|
||||||
|
const input = event.target as HTMLInputElement
|
||||||
|
tempRepeat.value = input.value
|
||||||
|
}
|
||||||
|
|
||||||
|
const onRepeatBlur = (event: Event) => {
|
||||||
|
const input = event.target as HTMLInputElement
|
||||||
|
const value = parseInt(input.value, 10)
|
||||||
|
|
||||||
|
if (!isNaN(value)) {
|
||||||
|
const clampedValue = Math.max(1, Math.min(value, 99))
|
||||||
|
emit('update:repeatCount', clampedValue)
|
||||||
|
input.value = clampedValue.toString()
|
||||||
|
} else {
|
||||||
|
input.value = props.repeatCount.toString()
|
||||||
|
}
|
||||||
|
tempRepeat.value = ''
|
||||||
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<style scoped>
|
<style scoped>
|
||||||
@@ -203,6 +292,17 @@ const onIndexBlur = (event: Event) => {
|
|||||||
display: flex;
|
display: flex;
|
||||||
justify-content: space-between;
|
justify-content: space-between;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
|
transition: border-color 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.progress-display.executing {
|
||||||
|
border-color: rgba(66, 153, 225, 0.5);
|
||||||
|
animation: pulse 2s ease-in-out infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes pulse {
|
||||||
|
0%, 100% { border-color: rgba(66, 153, 225, 0.3); }
|
||||||
|
50% { border-color: rgba(66, 153, 225, 0.7); }
|
||||||
}
|
}
|
||||||
|
|
||||||
.progress-info {
|
.progress-info {
|
||||||
@@ -230,6 +330,42 @@ const onIndexBlur = (event: Event) => {
|
|||||||
white-space: nowrap;
|
white-space: nowrap;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.progress-name.clickable {
|
||||||
|
cursor: pointer;
|
||||||
|
padding: 2px 6px;
|
||||||
|
margin: -2px -6px;
|
||||||
|
border-radius: 4px;
|
||||||
|
transition: all 0.2s;
|
||||||
|
display: inline-flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.progress-name.clickable:hover:not(.disabled) {
|
||||||
|
background: rgba(66, 153, 225, 0.2);
|
||||||
|
color: rgba(191, 219, 254, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.progress-name.clickable.disabled {
|
||||||
|
cursor: not-allowed;
|
||||||
|
opacity: 0.5;
|
||||||
|
}
|
||||||
|
|
||||||
|
.progress-info.disabled {
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.selector-icon {
|
||||||
|
width: 16px;
|
||||||
|
height: 16px;
|
||||||
|
opacity: 0.5;
|
||||||
|
flex-shrink: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.progress-name.clickable:hover .selector-icon {
|
||||||
|
opacity: 0.8;
|
||||||
|
}
|
||||||
|
|
||||||
.progress-counter {
|
.progress-counter {
|
||||||
display: flex;
|
display: flex;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
@@ -243,6 +379,9 @@ const onIndexBlur = (event: Event) => {
|
|||||||
font-weight: 600;
|
font-weight: 600;
|
||||||
color: rgba(66, 153, 225, 1);
|
color: rgba(66, 153, 225, 1);
|
||||||
font-family: 'SF Mono', 'Roboto Mono', monospace;
|
font-family: 'SF Mono', 'Roboto Mono', monospace;
|
||||||
|
min-width: 4ch;
|
||||||
|
text-align: right;
|
||||||
|
font-variant-numeric: tabular-nums;
|
||||||
}
|
}
|
||||||
|
|
||||||
.progress-separator {
|
.progress-separator {
|
||||||
@@ -256,69 +395,92 @@ const onIndexBlur = (event: Event) => {
|
|||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
color: rgba(226, 232, 240, 0.6);
|
color: rgba(226, 232, 240, 0.6);
|
||||||
font-family: 'SF Mono', 'Roboto Mono', monospace;
|
font-family: 'SF Mono', 'Roboto Mono', monospace;
|
||||||
|
min-width: 4ch;
|
||||||
|
text-align: left;
|
||||||
|
font-variant-numeric: tabular-nums;
|
||||||
}
|
}
|
||||||
|
|
||||||
.refresh-button {
|
/* Repeat Progress */
|
||||||
|
.repeat-progress {
|
||||||
display: flex;
|
display: flex;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
justify-content: center;
|
gap: 6px;
|
||||||
width: 24px;
|
|
||||||
height: 24px;
|
|
||||||
margin-left: 8px;
|
margin-left: 8px;
|
||||||
padding: 0;
|
padding: 2px 6px;
|
||||||
background: transparent;
|
background: rgba(26, 32, 44, 0.6);
|
||||||
border: 1px solid rgba(255, 255, 255, 0.1);
|
border: 1px solid rgba(226, 232, 240, 0.1);
|
||||||
border-radius: 4px;
|
border-radius: 4px;
|
||||||
color: rgba(226, 232, 240, 0.6);
|
|
||||||
cursor: pointer;
|
|
||||||
transition: all 0.2s;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.refresh-button:hover:not(:disabled) {
|
.repeat-progress-track {
|
||||||
background: rgba(66, 153, 225, 0.2);
|
width: 32px;
|
||||||
border-color: rgba(66, 153, 225, 0.4);
|
height: 4px;
|
||||||
color: rgba(191, 219, 254, 1);
|
background: rgba(226, 232, 240, 0.15);
|
||||||
|
border-radius: 2px;
|
||||||
|
overflow: hidden;
|
||||||
}
|
}
|
||||||
|
|
||||||
.refresh-button:disabled {
|
.repeat-progress-fill {
|
||||||
opacity: 0.4;
|
height: 100%;
|
||||||
cursor: not-allowed;
|
background: linear-gradient(90deg, #f59e0b, #fbbf24);
|
||||||
|
border-radius: 2px;
|
||||||
|
transition: width 0.3s ease;
|
||||||
}
|
}
|
||||||
|
|
||||||
.refresh-icon {
|
.repeat-progress-fill.is-complete {
|
||||||
width: 14px;
|
background: linear-gradient(90deg, #10b981, #34d399);
|
||||||
height: 14px;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.refresh-icon.spinning {
|
.repeat-progress-text {
|
||||||
animation: spin 1s linear infinite;
|
font-size: 10px;
|
||||||
|
font-family: 'SF Mono', 'Roboto Mono', monospace;
|
||||||
|
color: rgba(253, 230, 138, 0.9);
|
||||||
|
min-width: 3ch;
|
||||||
|
font-variant-numeric: tabular-nums;
|
||||||
}
|
}
|
||||||
|
|
||||||
@keyframes spin {
|
/* Index Controls Row - Grouped Layout */
|
||||||
from {
|
.index-controls-row {
|
||||||
transform: rotate(0deg);
|
|
||||||
}
|
|
||||||
to {
|
|
||||||
transform: rotate(360deg);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Index Input */
|
|
||||||
.index-input-container {
|
|
||||||
display: flex;
|
display: flex;
|
||||||
align-items: center;
|
align-items: flex-end;
|
||||||
gap: 8px;
|
gap: 16px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Control Group */
|
||||||
|
.control-group {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 6px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-group-label {
|
||||||
|
font-size: 11px;
|
||||||
|
font-weight: 500;
|
||||||
|
color: rgba(226, 232, 240, 0.5);
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 0.03em;
|
||||||
|
line-height: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-group-content {
|
||||||
|
display: flex;
|
||||||
|
align-items: baseline;
|
||||||
|
gap: 4px;
|
||||||
|
height: 32px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.index-input {
|
.index-input {
|
||||||
width: 80px;
|
width: 50px;
|
||||||
padding: 6px 10px;
|
height: 32px;
|
||||||
|
padding: 0 8px;
|
||||||
background: rgba(26, 32, 44, 0.9);
|
background: rgba(26, 32, 44, 0.9);
|
||||||
border: 1px solid rgba(226, 232, 240, 0.2);
|
border: 1px solid rgba(226, 232, 240, 0.2);
|
||||||
border-radius: 6px;
|
border-radius: 6px;
|
||||||
color: #e4e4e7;
|
color: #e4e4e7;
|
||||||
font-size: 13px;
|
font-size: 13px;
|
||||||
font-family: 'SF Mono', 'Roboto Mono', monospace;
|
font-family: 'SF Mono', 'Roboto Mono', monospace;
|
||||||
|
line-height: 32px;
|
||||||
|
box-sizing: border-box;
|
||||||
}
|
}
|
||||||
|
|
||||||
.index-input:focus {
|
.index-input:focus {
|
||||||
@@ -332,8 +494,89 @@ const onIndexBlur = (event: Event) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
.index-hint {
|
.index-hint {
|
||||||
font-size: 11px;
|
font-size: 12px;
|
||||||
color: rgba(226, 232, 240, 0.4);
|
color: rgba(226, 232, 240, 0.4);
|
||||||
|
font-variant-numeric: tabular-nums;
|
||||||
|
line-height: 32px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Repeat Controls */
|
||||||
|
.repeat-input {
|
||||||
|
width: 40px;
|
||||||
|
height: 32px;
|
||||||
|
padding: 0 6px;
|
||||||
|
background: rgba(26, 32, 44, 0.9);
|
||||||
|
border: 1px solid rgba(226, 232, 240, 0.2);
|
||||||
|
border-radius: 6px;
|
||||||
|
color: #e4e4e7;
|
||||||
|
font-size: 13px;
|
||||||
|
font-family: 'SF Mono', 'Roboto Mono', monospace;
|
||||||
|
text-align: center;
|
||||||
|
line-height: 32px;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
|
||||||
|
.repeat-input:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: rgba(66, 153, 225, 0.6);
|
||||||
|
}
|
||||||
|
|
||||||
|
.repeat-suffix {
|
||||||
|
font-size: 13px;
|
||||||
|
color: rgba(226, 232, 240, 0.4);
|
||||||
|
font-weight: 500;
|
||||||
|
line-height: 32px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Action Buttons */
|
||||||
|
.action-buttons {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 6px;
|
||||||
|
margin-left: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Control Buttons */
|
||||||
|
.control-btn {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
width: 24px;
|
||||||
|
height: 24px;
|
||||||
|
padding: 0;
|
||||||
|
background: transparent;
|
||||||
|
border: 1px solid rgba(255, 255, 255, 0.1);
|
||||||
|
border-radius: 4px;
|
||||||
|
color: rgba(226, 232, 240, 0.6);
|
||||||
|
cursor: pointer;
|
||||||
|
transition: all 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-btn:hover:not(:disabled) {
|
||||||
|
background: rgba(66, 153, 225, 0.2);
|
||||||
|
border-color: rgba(66, 153, 225, 0.4);
|
||||||
|
color: rgba(191, 219, 254, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-btn:disabled {
|
||||||
|
opacity: 0.4;
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-btn.active {
|
||||||
|
background: rgba(245, 158, 11, 0.2);
|
||||||
|
border-color: rgba(245, 158, 11, 0.5);
|
||||||
|
color: rgba(253, 230, 138, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-btn.active:hover {
|
||||||
|
background: rgba(245, 158, 11, 0.3);
|
||||||
|
border-color: rgba(245, 158, 11, 0.6);
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-icon {
|
||||||
|
width: 14px;
|
||||||
|
height: 14px;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Slider Container */
|
/* Slider Container */
|
||||||
|
|||||||
313
vue-widgets/src/components/lora-cycler/LoraListModal.vue
Normal file
313
vue-widgets/src/components/lora-cycler/LoraListModal.vue
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
<template>
|
||||||
|
<ModalWrapper
|
||||||
|
:visible="visible"
|
||||||
|
title="Select LoRA"
|
||||||
|
:subtitle="subtitleText"
|
||||||
|
@close="$emit('close')"
|
||||||
|
>
|
||||||
|
<template #search>
|
||||||
|
<div class="search-container">
|
||||||
|
<svg class="search-icon" viewBox="0 0 16 16" fill="currentColor">
|
||||||
|
<path d="M11.742 10.344a6.5 6.5 0 1 0-1.397 1.398h-.001c.03.04.062.078.098.115l3.85 3.85a1 1 0 0 0 1.415-1.414l-3.85-3.85a1.007 1.007 0 0 0-.115-.1zM12 6.5a5.5 5.5 0 1 1-11 0 5.5 5.5 0 0 1 11 0z"/>
|
||||||
|
</svg>
|
||||||
|
<input
|
||||||
|
ref="searchInputRef"
|
||||||
|
v-model="searchQuery"
|
||||||
|
type="text"
|
||||||
|
class="search-input"
|
||||||
|
placeholder="Search LoRAs..."
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
v-if="searchQuery"
|
||||||
|
type="button"
|
||||||
|
class="clear-button"
|
||||||
|
@click="clearSearch"
|
||||||
|
>
|
||||||
|
<svg viewBox="0 0 16 16" fill="currentColor">
|
||||||
|
<path d="M4.646 4.646a.5.5 0 0 1 .708 0L8 7.293l2.646-2.647a.5.5 0 0 1 .708.708L8.707 8l2.647 2.646a.5.5 0 0 1-.708.708L8 8.707l-2.646 2.647a.5.5 0 0 1-.708-.708L7.293 8 4.646 5.354a.5.5 0 0 1 0-.708z"/>
|
||||||
|
</svg>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<div class="lora-list">
|
||||||
|
<div
|
||||||
|
v-for="item in filteredList"
|
||||||
|
:key="item.index"
|
||||||
|
class="lora-item"
|
||||||
|
:class="{ active: currentIndex === item.index }"
|
||||||
|
@mouseenter="showPreview(item.lora.file_name, $event)"
|
||||||
|
@mouseleave="hidePreview"
|
||||||
|
@click="selectLora(item.index)"
|
||||||
|
>
|
||||||
|
<span class="lora-index">{{ item.index }}</span>
|
||||||
|
<span class="lora-name" :title="item.lora.file_name">{{ item.lora.file_name }}</span>
|
||||||
|
<span v-if="currentIndex === item.index" class="current-badge">Current</span>
|
||||||
|
</div>
|
||||||
|
<div v-if="filteredList.length === 0" class="no-results">
|
||||||
|
No LoRAs found
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</ModalWrapper>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { ref, computed, watch, nextTick, onUnmounted } from 'vue'
|
||||||
|
import ModalWrapper from '../lora-pool/modals/ModalWrapper.vue'
|
||||||
|
import type { LoraItem } from '../../composables/types'
|
||||||
|
|
||||||
|
interface LoraListItem {
|
||||||
|
index: number
|
||||||
|
lora: LoraItem
|
||||||
|
}
|
||||||
|
|
||||||
|
const props = defineProps<{
|
||||||
|
visible: boolean
|
||||||
|
loraList: LoraItem[]
|
||||||
|
currentIndex: number
|
||||||
|
}>()
|
||||||
|
|
||||||
|
const emit = defineEmits<{
|
||||||
|
close: []
|
||||||
|
select: [index: number]
|
||||||
|
}>()
|
||||||
|
|
||||||
|
const searchQuery = ref('')
|
||||||
|
const searchInputRef = ref<HTMLInputElement | null>(null)
|
||||||
|
|
||||||
|
// Preview tooltip instance (lazy init)
|
||||||
|
let previewTooltip: any = null
|
||||||
|
|
||||||
|
const subtitleText = computed(() => {
|
||||||
|
const total = props.loraList.length
|
||||||
|
const filtered = filteredList.value.length
|
||||||
|
if (filtered === total) {
|
||||||
|
return `Total: ${total} LoRA${total !== 1 ? 's' : ''}`
|
||||||
|
}
|
||||||
|
return `Showing ${filtered} of ${total} LoRA${total !== 1 ? 's' : ''}`
|
||||||
|
})
|
||||||
|
|
||||||
|
const filteredList = computed<LoraListItem[]>(() => {
|
||||||
|
const list = props.loraList.map((lora, idx) => ({
|
||||||
|
index: idx + 1,
|
||||||
|
lora
|
||||||
|
}))
|
||||||
|
|
||||||
|
if (!searchQuery.value.trim()) {
|
||||||
|
return list
|
||||||
|
}
|
||||||
|
|
||||||
|
const query = searchQuery.value.toLowerCase()
|
||||||
|
return list.filter(item =>
|
||||||
|
item.lora.file_name.toLowerCase().includes(query)
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
const clearSearch = () => {
|
||||||
|
searchQuery.value = ''
|
||||||
|
searchInputRef.value?.focus()
|
||||||
|
}
|
||||||
|
|
||||||
|
const selectLora = (index: number) => {
|
||||||
|
emit('select', index)
|
||||||
|
emit('close')
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom preview URL resolver for Vue widgets environment
|
||||||
|
// The default preview_tooltip.js uses api.fetchApi which is mocked as native fetch
|
||||||
|
// in the Vue widgets build, so we need to use the full path with /api prefix
|
||||||
|
const customPreviewUrlResolver = async (modelName: string) => {
|
||||||
|
const response = await fetch(
|
||||||
|
`/api/lm/loras/preview-url?name=${encodeURIComponent(modelName)}&license_flags=true`
|
||||||
|
)
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error('Failed to fetch preview URL')
|
||||||
|
}
|
||||||
|
const data = await response.json()
|
||||||
|
if (!data.success || !data.preview_url) {
|
||||||
|
throw new Error('No preview available')
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
previewUrl: data.preview_url,
|
||||||
|
displayName: data.display_name ?? modelName,
|
||||||
|
licenseFlags: data.license_flags
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lazy load PreviewTooltip to avoid loading it unnecessarily
|
||||||
|
const getPreviewTooltip = async () => {
|
||||||
|
if (!previewTooltip) {
|
||||||
|
const { PreviewTooltip } = await import(/* @vite-ignore */ `${'../preview_tooltip.js'}`)
|
||||||
|
previewTooltip = new PreviewTooltip({
|
||||||
|
modelType: 'loras',
|
||||||
|
displayNameFormatter: (name: string) => name,
|
||||||
|
previewUrlResolver: customPreviewUrlResolver
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return previewTooltip
|
||||||
|
}
|
||||||
|
|
||||||
|
const showPreview = async (loraName: string, event: MouseEvent) => {
|
||||||
|
const tooltip = await getPreviewTooltip()
|
||||||
|
const rect = (event.target as HTMLElement).getBoundingClientRect()
|
||||||
|
// Position to the right of the item, centered vertically
|
||||||
|
tooltip.show(loraName, rect.right + 10, rect.top + rect.height / 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
const hidePreview = async () => {
|
||||||
|
if (previewTooltip) {
|
||||||
|
previewTooltip.hide()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Focus search input when modal opens
|
||||||
|
watch(() => props.visible, (isVisible) => {
|
||||||
|
if (isVisible) {
|
||||||
|
searchQuery.value = ''
|
||||||
|
nextTick(() => {
|
||||||
|
searchInputRef.value?.focus()
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// Hide preview when modal closes
|
||||||
|
hidePreview()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Cleanup on unmount
|
||||||
|
onUnmounted(() => {
|
||||||
|
if (previewTooltip) {
|
||||||
|
previewTooltip.cleanup()
|
||||||
|
previewTooltip = null
|
||||||
|
}
|
||||||
|
})
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
.search-container {
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
.search-icon {
|
||||||
|
position: absolute;
|
||||||
|
left: 10px;
|
||||||
|
top: 50%;
|
||||||
|
transform: translateY(-50%);
|
||||||
|
width: 14px;
|
||||||
|
height: 14px;
|
||||||
|
color: var(--fg-color, #fff);
|
||||||
|
opacity: 0.5;
|
||||||
|
}
|
||||||
|
|
||||||
|
.search-input {
|
||||||
|
width: 100%;
|
||||||
|
padding: 8px 32px;
|
||||||
|
background: var(--comfy-input-bg, #333);
|
||||||
|
border: 1px solid var(--border-color, #444);
|
||||||
|
border-radius: 6px;
|
||||||
|
color: var(--fg-color, #fff);
|
||||||
|
font-size: 13px;
|
||||||
|
outline: none;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
|
||||||
|
.search-input:focus {
|
||||||
|
border-color: rgba(66, 153, 225, 0.6);
|
||||||
|
}
|
||||||
|
|
||||||
|
.search-input::placeholder {
|
||||||
|
color: var(--fg-color, #fff);
|
||||||
|
opacity: 0.4;
|
||||||
|
}
|
||||||
|
|
||||||
|
.clear-button {
|
||||||
|
position: absolute;
|
||||||
|
right: 8px;
|
||||||
|
top: 50%;
|
||||||
|
transform: translateY(-50%);
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
width: 20px;
|
||||||
|
height: 20px;
|
||||||
|
background: transparent;
|
||||||
|
border: none;
|
||||||
|
cursor: pointer;
|
||||||
|
padding: 0;
|
||||||
|
opacity: 0.5;
|
||||||
|
transition: opacity 0.15s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.clear-button:hover {
|
||||||
|
opacity: 0.8;
|
||||||
|
}
|
||||||
|
|
||||||
|
.clear-button svg {
|
||||||
|
width: 12px;
|
||||||
|
height: 12px;
|
||||||
|
color: var(--fg-color, #fff);
|
||||||
|
}
|
||||||
|
|
||||||
|
.lora-list {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 2px;
|
||||||
|
max-height: 400px;
|
||||||
|
overflow-y: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.lora-item {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 12px;
|
||||||
|
padding: 10px 12px;
|
||||||
|
border-radius: 6px;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: all 0.15s;
|
||||||
|
border-left: 3px solid transparent;
|
||||||
|
}
|
||||||
|
|
||||||
|
.lora-item:hover {
|
||||||
|
background: rgba(66, 153, 225, 0.15);
|
||||||
|
}
|
||||||
|
|
||||||
|
.lora-item.active {
|
||||||
|
background: rgba(66, 153, 225, 0.25);
|
||||||
|
border-left-color: rgba(66, 153, 225, 0.8);
|
||||||
|
}
|
||||||
|
|
||||||
|
.lora-index {
|
||||||
|
font-family: 'SF Mono', 'Roboto Mono', monospace;
|
||||||
|
font-size: 12px;
|
||||||
|
color: rgba(226, 232, 240, 0.5);
|
||||||
|
min-width: 3ch;
|
||||||
|
text-align: right;
|
||||||
|
font-variant-numeric: tabular-nums;
|
||||||
|
}
|
||||||
|
|
||||||
|
.lora-name {
|
||||||
|
flex: 1;
|
||||||
|
font-size: 13px;
|
||||||
|
color: var(--fg-color, #fff);
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.current-badge {
|
||||||
|
font-size: 11px;
|
||||||
|
padding: 2px 8px;
|
||||||
|
background: rgba(66, 153, 225, 0.3);
|
||||||
|
border: 1px solid rgba(66, 153, 225, 0.5);
|
||||||
|
border-radius: 4px;
|
||||||
|
color: rgba(191, 219, 254, 1);
|
||||||
|
font-weight: 500;
|
||||||
|
}
|
||||||
|
|
||||||
|
.no-results {
|
||||||
|
padding: 32px 20px;
|
||||||
|
text-align: center;
|
||||||
|
color: var(--fg-color, #fff);
|
||||||
|
opacity: 0.5;
|
||||||
|
font-size: 13px;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
@@ -81,7 +81,7 @@ watch(() => props.visible, (isVisible) => {
|
|||||||
.lora-pool-modal-backdrop {
|
.lora-pool-modal-backdrop {
|
||||||
position: fixed;
|
position: fixed;
|
||||||
inset: 0;
|
inset: 0;
|
||||||
z-index: 10000;
|
z-index: 9998;
|
||||||
background: rgba(0, 0, 0, 0.6);
|
background: rgba(0, 0, 0, 0.6);
|
||||||
display: flex;
|
display: flex;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
|
|||||||
@@ -206,7 +206,9 @@ const stepToDecimals = (step: number): number => {
|
|||||||
const snapToStep = (value: number, segmentMultiplier?: number): number => {
|
const snapToStep = (value: number, segmentMultiplier?: number): number => {
|
||||||
const effectiveStep = segmentMultiplier ? props.step * segmentMultiplier : props.step
|
const effectiveStep = segmentMultiplier ? props.step * segmentMultiplier : props.step
|
||||||
const steps = Math.round((value - props.min) / effectiveStep)
|
const steps = Math.round((value - props.min) / effectiveStep)
|
||||||
return Math.max(props.min, Math.min(props.max, props.min + steps * effectiveStep))
|
const rawValue = Math.max(props.min, Math.min(props.max, props.min + steps * effectiveStep))
|
||||||
|
// Fix floating point precision issues, limit to 2 decimal places
|
||||||
|
return Math.round(rawValue * 100) / 100
|
||||||
}
|
}
|
||||||
|
|
||||||
const startDrag = (handle: 'min' | 'max', event: PointerEvent) => {
|
const startDrag = (handle: 'min' | 'max', event: PointerEvent) => {
|
||||||
|
|||||||
@@ -82,7 +82,9 @@ const stepToDecimals = (step: number): number => {
|
|||||||
|
|
||||||
const snapToStep = (value: number): number => {
|
const snapToStep = (value: number): number => {
|
||||||
const steps = Math.round((value - props.min) / props.step)
|
const steps = Math.round((value - props.min) / props.step)
|
||||||
return Math.max(props.min, Math.min(props.max, props.min + steps * props.step))
|
const rawValue = Math.max(props.min, Math.min(props.max, props.min + steps * props.step))
|
||||||
|
// Fix floating point precision issues, limit to 2 decimal places
|
||||||
|
return Math.round(rawValue * 100) / 100
|
||||||
}
|
}
|
||||||
|
|
||||||
const startDrag = (event: PointerEvent) => {
|
const startDrag = (event: PointerEvent) => {
|
||||||
|
|||||||
@@ -80,6 +80,10 @@ export interface CyclerConfig {
|
|||||||
// Dual-index mechanism for batch queue synchronization
|
// Dual-index mechanism for batch queue synchronization
|
||||||
execution_index?: number | null // Index to use for current execution
|
execution_index?: number | null // Index to use for current execution
|
||||||
next_index?: number | null // Index for display after execution
|
next_index?: number | null // Index for display after execution
|
||||||
|
// Advanced index control features
|
||||||
|
repeat_count: number // How many times each LoRA should repeat (default: 1)
|
||||||
|
repeat_used: number // How many times current index has been used
|
||||||
|
is_paused: boolean // Whether iteration is paused
|
||||||
}
|
}
|
||||||
|
|
||||||
// Widget config union type
|
// Widget config union type
|
||||||
|
|||||||
@@ -29,6 +29,16 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
|||||||
const executionIndex = ref<number | null>(null)
|
const executionIndex = ref<number | null>(null)
|
||||||
const nextIndex = ref<number | null>(null)
|
const nextIndex = ref<number | null>(null)
|
||||||
|
|
||||||
|
// Advanced index control features
|
||||||
|
const repeatCount = ref(1) // How many times each LoRA should repeat
|
||||||
|
const repeatUsed = ref(0) // How many times current index has been used (internal tracking)
|
||||||
|
const displayRepeatUsed = ref(0) // For UI display, deferred updates like currentIndex
|
||||||
|
const isPaused = ref(false) // Whether iteration is paused
|
||||||
|
|
||||||
|
// Execution progress tracking (visual feedback)
|
||||||
|
const isWorkflowExecuting = ref(false) // Workflow is currently running
|
||||||
|
const executingRepeatStep = ref(0) // Which repeat step (1-based, 0 = not executing)
|
||||||
|
|
||||||
// Build config object from current state
|
// Build config object from current state
|
||||||
const buildConfig = (): CyclerConfig => {
|
const buildConfig = (): CyclerConfig => {
|
||||||
// Skip updating widget.value during restoration to prevent infinite loops
|
// Skip updating widget.value during restoration to prevent infinite loops
|
||||||
@@ -45,6 +55,9 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
|||||||
current_lora_filename: currentLoraFilename.value,
|
current_lora_filename: currentLoraFilename.value,
|
||||||
execution_index: executionIndex.value,
|
execution_index: executionIndex.value,
|
||||||
next_index: nextIndex.value,
|
next_index: nextIndex.value,
|
||||||
|
repeat_count: repeatCount.value,
|
||||||
|
repeat_used: repeatUsed.value,
|
||||||
|
is_paused: isPaused.value,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return {
|
return {
|
||||||
@@ -59,6 +72,9 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
|||||||
current_lora_filename: currentLoraFilename.value,
|
current_lora_filename: currentLoraFilename.value,
|
||||||
execution_index: executionIndex.value,
|
execution_index: executionIndex.value,
|
||||||
next_index: nextIndex.value,
|
next_index: nextIndex.value,
|
||||||
|
repeat_count: repeatCount.value,
|
||||||
|
repeat_used: repeatUsed.value,
|
||||||
|
is_paused: isPaused.value,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,6 +93,10 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
|||||||
sortBy.value = config.sort_by || 'filename'
|
sortBy.value = config.sort_by || 'filename'
|
||||||
currentLoraName.value = config.current_lora_name || ''
|
currentLoraName.value = config.current_lora_name || ''
|
||||||
currentLoraFilename.value = config.current_lora_filename || ''
|
currentLoraFilename.value = config.current_lora_filename || ''
|
||||||
|
// Advanced index control features
|
||||||
|
repeatCount.value = config.repeat_count ?? 1
|
||||||
|
repeatUsed.value = config.repeat_used ?? 0
|
||||||
|
isPaused.value = config.is_paused ?? false
|
||||||
// Note: execution_index and next_index are not restored from config
|
// Note: execution_index and next_index are not restored from config
|
||||||
// as they are transient values used only during batch execution
|
// as they are transient values used only during batch execution
|
||||||
} finally {
|
} finally {
|
||||||
@@ -215,6 +235,19 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reset index to 1 and clear repeat state
|
||||||
|
const resetIndex = () => {
|
||||||
|
currentIndex.value = 1
|
||||||
|
repeatUsed.value = 0
|
||||||
|
displayRepeatUsed.value = 0
|
||||||
|
// Note: isPaused is intentionally not reset - user may want to stay paused after reset
|
||||||
|
}
|
||||||
|
|
||||||
|
// Toggle pause state
|
||||||
|
const togglePause = () => {
|
||||||
|
isPaused.value = !isPaused.value
|
||||||
|
}
|
||||||
|
|
||||||
// Computed property to check if clip strength is disabled
|
// Computed property to check if clip strength is disabled
|
||||||
const isClipStrengthDisabled = computed(() => !useCustomClipRange.value)
|
const isClipStrengthDisabled = computed(() => !useCustomClipRange.value)
|
||||||
|
|
||||||
@@ -236,6 +269,9 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
|||||||
sortBy,
|
sortBy,
|
||||||
currentLoraName,
|
currentLoraName,
|
||||||
currentLoraFilename,
|
currentLoraFilename,
|
||||||
|
repeatCount,
|
||||||
|
repeatUsed,
|
||||||
|
isPaused,
|
||||||
], () => {
|
], () => {
|
||||||
widget.value = buildConfig()
|
widget.value = buildConfig()
|
||||||
}, { deep: true })
|
}, { deep: true })
|
||||||
@@ -254,6 +290,12 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
|||||||
isLoading,
|
isLoading,
|
||||||
executionIndex,
|
executionIndex,
|
||||||
nextIndex,
|
nextIndex,
|
||||||
|
repeatCount,
|
||||||
|
repeatUsed,
|
||||||
|
displayRepeatUsed,
|
||||||
|
isPaused,
|
||||||
|
isWorkflowExecuting,
|
||||||
|
executingRepeatStep,
|
||||||
|
|
||||||
// Computed
|
// Computed
|
||||||
isClipStrengthDisabled,
|
isClipStrengthDisabled,
|
||||||
@@ -267,5 +309,7 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
|||||||
setIndex,
|
setIndex,
|
||||||
generateNextIndex,
|
generateNextIndex,
|
||||||
initializeNextIndex,
|
initializeNextIndex,
|
||||||
|
resetIndex,
|
||||||
|
togglePause,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ const AUTOCOMPLETE_TEXT_WIDGET_MAX_HEIGHT = 100
|
|||||||
|
|
||||||
// @ts-ignore - ComfyUI external module
|
// @ts-ignore - ComfyUI external module
|
||||||
import { app } from '../../../scripts/app.js'
|
import { app } from '../../../scripts/app.js'
|
||||||
|
// @ts-ignore - ComfyUI external module
|
||||||
|
import { api } from '../../../scripts/api.js'
|
||||||
// @ts-ignore
|
// @ts-ignore
|
||||||
import { getPoolConfigFromConnectedNode, getActiveLorasFromNode, updateConnectedTriggerWords, updateDownstreamLoaders } from '../../web/comfyui/utils.js'
|
import { getPoolConfigFromConnectedNode, getActiveLorasFromNode, updateConnectedTriggerWords, updateDownstreamLoaders } from '../../web/comfyui/utils.js'
|
||||||
|
|
||||||
@@ -255,7 +257,8 @@ function createLoraCyclerWidget(node) {
|
|||||||
|
|
||||||
const vueApp = createApp(LoraCyclerWidget, {
|
const vueApp = createApp(LoraCyclerWidget, {
|
||||||
widget,
|
widget,
|
||||||
node
|
node,
|
||||||
|
api
|
||||||
})
|
})
|
||||||
|
|
||||||
vueApp.use(PrimeVue, {
|
vueApp.use(PrimeVue, {
|
||||||
|
|||||||
634
vue-widgets/tests/composables/useLoraCyclerState.test.ts
Normal file
634
vue-widgets/tests/composables/useLoraCyclerState.test.ts
Normal file
@@ -0,0 +1,634 @@
|
|||||||
|
/**
|
||||||
|
* Unit tests for useLoraCyclerState composable
|
||||||
|
*
|
||||||
|
* Tests pure state transitions and index calculations in isolation.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect, beforeEach, vi } from 'vitest'
|
||||||
|
import { useLoraCyclerState } from '@/composables/useLoraCyclerState'
|
||||||
|
import {
|
||||||
|
createMockWidget,
|
||||||
|
createMockCyclerConfig,
|
||||||
|
createMockPoolConfig
|
||||||
|
} from '../fixtures/mockConfigs'
|
||||||
|
import { setupFetchMock, resetFetchMock } from '../setup'
|
||||||
|
|
||||||
|
describe('useLoraCyclerState', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
resetFetchMock()
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Initial State', () => {
|
||||||
|
it('should initialize with default values', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
expect(state.currentIndex.value).toBe(1)
|
||||||
|
expect(state.totalCount.value).toBe(0)
|
||||||
|
expect(state.poolConfigHash.value).toBe('')
|
||||||
|
expect(state.modelStrength.value).toBe(1.0)
|
||||||
|
expect(state.clipStrength.value).toBe(1.0)
|
||||||
|
expect(state.useCustomClipRange.value).toBe(false)
|
||||||
|
expect(state.sortBy.value).toBe('filename')
|
||||||
|
expect(state.executionIndex.value).toBeNull()
|
||||||
|
expect(state.nextIndex.value).toBeNull()
|
||||||
|
expect(state.repeatCount.value).toBe(1)
|
||||||
|
expect(state.repeatUsed.value).toBe(0)
|
||||||
|
expect(state.displayRepeatUsed.value).toBe(0)
|
||||||
|
expect(state.isPaused.value).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('restoreFromConfig', () => {
|
||||||
|
it('should restore state from config object', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
const config = createMockCyclerConfig({
|
||||||
|
current_index: 3,
|
||||||
|
total_count: 10,
|
||||||
|
model_strength: 0.8,
|
||||||
|
clip_strength: 0.6,
|
||||||
|
use_same_clip_strength: false,
|
||||||
|
repeat_count: 2,
|
||||||
|
repeat_used: 1,
|
||||||
|
is_paused: true
|
||||||
|
})
|
||||||
|
|
||||||
|
state.restoreFromConfig(config)
|
||||||
|
|
||||||
|
expect(state.currentIndex.value).toBe(3)
|
||||||
|
expect(state.totalCount.value).toBe(10)
|
||||||
|
expect(state.modelStrength.value).toBe(0.8)
|
||||||
|
expect(state.clipStrength.value).toBe(0.6)
|
||||||
|
expect(state.useCustomClipRange.value).toBe(true) // inverted from use_same_clip_strength
|
||||||
|
expect(state.repeatCount.value).toBe(2)
|
||||||
|
expect(state.repeatUsed.value).toBe(1)
|
||||||
|
expect(state.isPaused.value).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle missing optional fields with defaults', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
// Minimal config
|
||||||
|
state.restoreFromConfig({
|
||||||
|
current_index: 5,
|
||||||
|
total_count: 10,
|
||||||
|
pool_config_hash: '',
|
||||||
|
model_strength: 1.0,
|
||||||
|
clip_strength: 1.0,
|
||||||
|
use_same_clip_strength: true,
|
||||||
|
sort_by: 'filename',
|
||||||
|
current_lora_name: '',
|
||||||
|
current_lora_filename: '',
|
||||||
|
repeat_count: 1,
|
||||||
|
repeat_used: 0,
|
||||||
|
is_paused: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(state.currentIndex.value).toBe(5)
|
||||||
|
expect(state.repeatCount.value).toBe(1)
|
||||||
|
expect(state.isPaused.value).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not restore execution_index and next_index (transient values)', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
// Set execution indices
|
||||||
|
state.executionIndex.value = 2
|
||||||
|
state.nextIndex.value = 3
|
||||||
|
|
||||||
|
// Restore from config (these fields in config should be ignored)
|
||||||
|
state.restoreFromConfig(createMockCyclerConfig({
|
||||||
|
execution_index: 5,
|
||||||
|
next_index: 6
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Execution indices should remain unchanged
|
||||||
|
expect(state.executionIndex.value).toBe(2)
|
||||||
|
expect(state.nextIndex.value).toBe(3)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('buildConfig', () => {
|
||||||
|
it('should build config object from current state', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
state.currentIndex.value = 3
|
||||||
|
state.totalCount.value = 10
|
||||||
|
state.modelStrength.value = 0.8
|
||||||
|
state.repeatCount.value = 2
|
||||||
|
state.repeatUsed.value = 1
|
||||||
|
state.isPaused.value = true
|
||||||
|
|
||||||
|
const config = state.buildConfig()
|
||||||
|
|
||||||
|
expect(config.current_index).toBe(3)
|
||||||
|
expect(config.total_count).toBe(10)
|
||||||
|
expect(config.model_strength).toBe(0.8)
|
||||||
|
expect(config.repeat_count).toBe(2)
|
||||||
|
expect(config.repeat_used).toBe(1)
|
||||||
|
expect(config.is_paused).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('setIndex', () => {
|
||||||
|
it('should set index within valid range', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
state.totalCount.value = 10
|
||||||
|
|
||||||
|
state.setIndex(5)
|
||||||
|
expect(state.currentIndex.value).toBe(5)
|
||||||
|
|
||||||
|
state.setIndex(1)
|
||||||
|
expect(state.currentIndex.value).toBe(1)
|
||||||
|
|
||||||
|
state.setIndex(10)
|
||||||
|
expect(state.currentIndex.value).toBe(10)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not set index outside valid range', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
state.totalCount.value = 10
|
||||||
|
state.currentIndex.value = 5
|
||||||
|
|
||||||
|
state.setIndex(0)
|
||||||
|
expect(state.currentIndex.value).toBe(5) // unchanged
|
||||||
|
|
||||||
|
state.setIndex(11)
|
||||||
|
expect(state.currentIndex.value).toBe(5) // unchanged
|
||||||
|
|
||||||
|
state.setIndex(-1)
|
||||||
|
expect(state.currentIndex.value).toBe(5) // unchanged
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('resetIndex', () => {
|
||||||
|
it('should reset index to 1 and clear repeatUsed and displayRepeatUsed', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
state.currentIndex.value = 5
|
||||||
|
state.repeatUsed.value = 2
|
||||||
|
state.displayRepeatUsed.value = 2
|
||||||
|
state.isPaused.value = true
|
||||||
|
|
||||||
|
state.resetIndex()
|
||||||
|
|
||||||
|
expect(state.currentIndex.value).toBe(1)
|
||||||
|
expect(state.repeatUsed.value).toBe(0)
|
||||||
|
expect(state.displayRepeatUsed.value).toBe(0)
|
||||||
|
expect(state.isPaused.value).toBe(true) // isPaused should NOT be reset
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('togglePause', () => {
|
||||||
|
it('should toggle pause state', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
expect(state.isPaused.value).toBe(false)
|
||||||
|
|
||||||
|
state.togglePause()
|
||||||
|
expect(state.isPaused.value).toBe(true)
|
||||||
|
|
||||||
|
state.togglePause()
|
||||||
|
expect(state.isPaused.value).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('generateNextIndex', () => {
|
||||||
|
it('should shift indices correctly', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
state.totalCount.value = 5
|
||||||
|
state.currentIndex.value = 1
|
||||||
|
state.nextIndex.value = 2
|
||||||
|
|
||||||
|
// First call: executionIndex becomes 2 (previous nextIndex), nextIndex becomes 3
|
||||||
|
state.generateNextIndex()
|
||||||
|
|
||||||
|
expect(state.executionIndex.value).toBe(2)
|
||||||
|
expect(state.nextIndex.value).toBe(3)
|
||||||
|
|
||||||
|
// Second call: executionIndex becomes 3, nextIndex becomes 4
|
||||||
|
state.generateNextIndex()
|
||||||
|
|
||||||
|
expect(state.executionIndex.value).toBe(3)
|
||||||
|
expect(state.nextIndex.value).toBe(4)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should wrap index from totalCount to 1', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
state.totalCount.value = 5
|
||||||
|
state.nextIndex.value = 5 // At the last index
|
||||||
|
|
||||||
|
state.generateNextIndex()
|
||||||
|
|
||||||
|
expect(state.executionIndex.value).toBe(5)
|
||||||
|
expect(state.nextIndex.value).toBe(1) // Wrapped to 1
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should use currentIndex when nextIndex is null', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
state.totalCount.value = 5
|
||||||
|
state.currentIndex.value = 3
|
||||||
|
state.nextIndex.value = null
|
||||||
|
|
||||||
|
state.generateNextIndex()
|
||||||
|
|
||||||
|
// executionIndex becomes previous nextIndex (null)
|
||||||
|
expect(state.executionIndex.value).toBeNull()
|
||||||
|
// nextIndex is calculated from currentIndex (3) -> 4
|
||||||
|
expect(state.nextIndex.value).toBe(4)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('initializeNextIndex', () => {
|
||||||
|
it('should initialize nextIndex to currentIndex + 1 when null', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
state.totalCount.value = 5
|
||||||
|
state.currentIndex.value = 1
|
||||||
|
state.nextIndex.value = null
|
||||||
|
|
||||||
|
state.initializeNextIndex()
|
||||||
|
|
||||||
|
expect(state.nextIndex.value).toBe(2)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should wrap nextIndex when currentIndex is at totalCount', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
state.totalCount.value = 5
|
||||||
|
state.currentIndex.value = 5
|
||||||
|
state.nextIndex.value = null
|
||||||
|
|
||||||
|
state.initializeNextIndex()
|
||||||
|
|
||||||
|
expect(state.nextIndex.value).toBe(1) // Wrapped
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not change nextIndex if already set', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
state.totalCount.value = 5
|
||||||
|
state.currentIndex.value = 1
|
||||||
|
state.nextIndex.value = 4
|
||||||
|
|
||||||
|
state.initializeNextIndex()
|
||||||
|
|
||||||
|
expect(state.nextIndex.value).toBe(4) // Unchanged
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Index Wrapping Edge Cases', () => {
|
||||||
|
it('should handle single item pool', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
state.totalCount.value = 1
|
||||||
|
state.currentIndex.value = 1
|
||||||
|
state.nextIndex.value = null
|
||||||
|
|
||||||
|
state.initializeNextIndex()
|
||||||
|
|
||||||
|
expect(state.nextIndex.value).toBe(1) // Wraps back to 1
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle zero total count gracefully', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
state.totalCount.value = 0
|
||||||
|
state.currentIndex.value = 1
|
||||||
|
state.nextIndex.value = null
|
||||||
|
|
||||||
|
state.initializeNextIndex()
|
||||||
|
|
||||||
|
// Should still calculate, even if totalCount is 0
|
||||||
|
expect(state.nextIndex.value).toBe(2) // No wrapping since totalCount <= 0
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('hashPoolConfig', () => {
|
||||||
|
it('should generate consistent hash for same config', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
const config1 = createMockPoolConfig()
|
||||||
|
const config2 = createMockPoolConfig()
|
||||||
|
|
||||||
|
const hash1 = state.hashPoolConfig(config1)
|
||||||
|
const hash2 = state.hashPoolConfig(config2)
|
||||||
|
|
||||||
|
expect(hash1).toBe(hash2)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should generate different hash for different configs', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
const config1 = createMockPoolConfig({
|
||||||
|
filters: {
|
||||||
|
baseModels: ['SD 1.5'],
|
||||||
|
tags: { include: [], exclude: [] },
|
||||||
|
folders: { include: [], exclude: [] },
|
||||||
|
license: { noCreditRequired: false, allowSelling: false }
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const config2 = createMockPoolConfig({
|
||||||
|
filters: {
|
||||||
|
baseModels: ['SDXL'],
|
||||||
|
tags: { include: [], exclude: [] },
|
||||||
|
folders: { include: [], exclude: [] },
|
||||||
|
license: { noCreditRequired: false, allowSelling: false }
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const hash1 = state.hashPoolConfig(config1)
|
||||||
|
const hash2 = state.hashPoolConfig(config2)
|
||||||
|
|
||||||
|
expect(hash1).not.toBe(hash2)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return empty string for null config', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
expect(state.hashPoolConfig(null)).toBe('')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return empty string for config without filters', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
const config = { version: 1, preview: { matchCount: 0, lastUpdated: 0 } } as any
|
||||||
|
|
||||||
|
expect(state.hashPoolConfig(config)).toBe('')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Clip Strength Synchronization', () => {
|
||||||
|
it('should sync clipStrength with modelStrength when useCustomClipRange is false', async () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
state.useCustomClipRange.value = false
|
||||||
|
state.modelStrength.value = 0.5
|
||||||
|
|
||||||
|
// Wait for Vue reactivity
|
||||||
|
await vi.waitFor(() => {
|
||||||
|
expect(state.clipStrength.value).toBe(0.5)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not sync clipStrength when useCustomClipRange is true', async () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
state.useCustomClipRange.value = true
|
||||||
|
state.clipStrength.value = 0.7
|
||||||
|
state.modelStrength.value = 0.5
|
||||||
|
|
||||||
|
// clipStrength should remain unchanged
|
||||||
|
await vi.waitFor(() => {
|
||||||
|
expect(state.clipStrength.value).toBe(0.7)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Widget Value Synchronization', () => {
|
||||||
|
it('should update widget.value when state changes', async () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
state.currentIndex.value = 3
|
||||||
|
state.repeatCount.value = 2
|
||||||
|
|
||||||
|
// Wait for Vue reactivity
|
||||||
|
await vi.waitFor(() => {
|
||||||
|
expect(widget.value?.current_index).toBe(3)
|
||||||
|
expect(widget.value?.repeat_count).toBe(2)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Repeat Logic State', () => {
|
||||||
|
it('should track repeatUsed correctly', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
state.repeatCount.value = 3
|
||||||
|
expect(state.repeatUsed.value).toBe(0)
|
||||||
|
|
||||||
|
state.repeatUsed.value = 1
|
||||||
|
expect(state.repeatUsed.value).toBe(1)
|
||||||
|
|
||||||
|
state.repeatUsed.value = 3
|
||||||
|
expect(state.repeatUsed.value).toBe(3)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('fetchCyclerList', () => {
|
||||||
|
it('should call API and return lora list', async () => {
|
||||||
|
const mockLoras = [
|
||||||
|
{ file_name: 'lora1.safetensors', model_name: 'LoRA 1' },
|
||||||
|
{ file_name: 'lora2.safetensors', model_name: 'LoRA 2' }
|
||||||
|
]
|
||||||
|
|
||||||
|
setupFetchMock({ success: true, loras: mockLoras })
|
||||||
|
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
const result = await state.fetchCyclerList(null)
|
||||||
|
|
||||||
|
expect(result).toEqual(mockLoras)
|
||||||
|
expect(state.isLoading.value).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include pool config filters in request', async () => {
|
||||||
|
const mockFetch = setupFetchMock({ success: true, loras: [] })
|
||||||
|
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
const poolConfig = createMockPoolConfig()
|
||||||
|
await state.fetchCyclerList(poolConfig)
|
||||||
|
|
||||||
|
expect(mockFetch).toHaveBeenCalledWith(
|
||||||
|
'/api/lm/loras/cycler-list',
|
||||||
|
expect.objectContaining({
|
||||||
|
method: 'POST',
|
||||||
|
body: expect.stringContaining('pool_config')
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should set isLoading during fetch', async () => {
|
||||||
|
let resolvePromise: (value: unknown) => void
|
||||||
|
const pendingPromise = new Promise(resolve => {
|
||||||
|
resolvePromise = resolve
|
||||||
|
})
|
||||||
|
|
||||||
|
// Use mockFetch from setup instead of overriding global
|
||||||
|
const { mockFetch } = await import('../setup')
|
||||||
|
mockFetch.mockReset()
|
||||||
|
mockFetch.mockReturnValue(pendingPromise)
|
||||||
|
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
const fetchPromise = state.fetchCyclerList(null)
|
||||||
|
|
||||||
|
expect(state.isLoading.value).toBe(true)
|
||||||
|
|
||||||
|
// Resolve the fetch
|
||||||
|
resolvePromise!({
|
||||||
|
ok: true,
|
||||||
|
json: () => Promise.resolve({ success: true, loras: [] })
|
||||||
|
})
|
||||||
|
|
||||||
|
await fetchPromise
|
||||||
|
|
||||||
|
expect(state.isLoading.value).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('refreshList', () => {
|
||||||
|
it('should update totalCount from API response', async () => {
|
||||||
|
const mockLoras = [
|
||||||
|
{ file_name: 'lora1.safetensors', model_name: 'LoRA 1' },
|
||||||
|
{ file_name: 'lora2.safetensors', model_name: 'LoRA 2' },
|
||||||
|
{ file_name: 'lora3.safetensors', model_name: 'LoRA 3' }
|
||||||
|
]
|
||||||
|
|
||||||
|
// Reset and setup fresh mock
|
||||||
|
resetFetchMock()
|
||||||
|
setupFetchMock({ success: true, loras: mockLoras })
|
||||||
|
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
await state.refreshList(null)
|
||||||
|
|
||||||
|
expect(state.totalCount.value).toBe(3)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should reset index to 1 when pool config hash changes', async () => {
|
||||||
|
resetFetchMock()
|
||||||
|
setupFetchMock({ success: true, loras: [{ file_name: 'lora1.safetensors', model_name: 'LoRA 1' }] })
|
||||||
|
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
// Set initial state
|
||||||
|
state.currentIndex.value = 5
|
||||||
|
state.poolConfigHash.value = 'old-hash'
|
||||||
|
|
||||||
|
// Refresh with new config (different hash)
|
||||||
|
const newConfig = createMockPoolConfig({
|
||||||
|
filters: {
|
||||||
|
baseModels: ['SDXL'],
|
||||||
|
tags: { include: [], exclude: [] },
|
||||||
|
folders: { include: [], exclude: [] },
|
||||||
|
license: { noCreditRequired: false, allowSelling: false }
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
await state.refreshList(newConfig)
|
||||||
|
|
||||||
|
expect(state.currentIndex.value).toBe(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should clamp index when totalCount decreases', async () => {
|
||||||
|
// Setup mock first, then create state
|
||||||
|
resetFetchMock()
|
||||||
|
setupFetchMock({
|
||||||
|
success: true,
|
||||||
|
loras: [
|
||||||
|
{ file_name: 'lora1.safetensors', model_name: 'LoRA 1' },
|
||||||
|
{ file_name: 'lora2.safetensors', model_name: 'LoRA 2' },
|
||||||
|
{ file_name: 'lora3.safetensors', model_name: 'LoRA 3' }
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
// Set initial state with high index
|
||||||
|
state.currentIndex.value = 10
|
||||||
|
state.totalCount.value = 10
|
||||||
|
|
||||||
|
await state.refreshList(null)
|
||||||
|
|
||||||
|
expect(state.totalCount.value).toBe(3)
|
||||||
|
expect(state.currentIndex.value).toBe(3) // Clamped to max
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should update currentLoraName and currentLoraFilename', async () => {
|
||||||
|
resetFetchMock()
|
||||||
|
setupFetchMock({
|
||||||
|
success: true,
|
||||||
|
loras: [
|
||||||
|
{ file_name: 'lora1.safetensors', model_name: 'LoRA 1' },
|
||||||
|
{ file_name: 'lora2.safetensors', model_name: 'LoRA 2' }
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
// Set totalCount first so setIndex works, then set index
|
||||||
|
state.totalCount.value = 2
|
||||||
|
state.currentIndex.value = 2
|
||||||
|
|
||||||
|
await state.refreshList(null)
|
||||||
|
|
||||||
|
expect(state.currentLoraFilename.value).toBe('lora2.safetensors')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle empty list gracefully', async () => {
|
||||||
|
resetFetchMock()
|
||||||
|
setupFetchMock({ success: true, loras: [] })
|
||||||
|
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
state.currentIndex.value = 5
|
||||||
|
state.totalCount.value = 5
|
||||||
|
|
||||||
|
await state.refreshList(null)
|
||||||
|
|
||||||
|
expect(state.totalCount.value).toBe(0)
|
||||||
|
// When totalCount is 0, Math.max(1, 0) = 1, but if currentIndex > totalCount it gets clamped to max(1, totalCount)
|
||||||
|
// Looking at the actual code: Math.max(1, totalCount) where totalCount=0 gives 1
|
||||||
|
expect(state.currentIndex.value).toBe(1)
|
||||||
|
expect(state.currentLoraName.value).toBe('')
|
||||||
|
expect(state.currentLoraFilename.value).toBe('')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('isClipStrengthDisabled computed', () => {
|
||||||
|
it('should return true when useCustomClipRange is false', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
state.useCustomClipRange.value = false
|
||||||
|
expect(state.isClipStrengthDisabled.value).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return false when useCustomClipRange is true', () => {
|
||||||
|
const widget = createMockWidget()
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
state.useCustomClipRange.value = true
|
||||||
|
expect(state.isClipStrengthDisabled.value).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
175
vue-widgets/tests/fixtures/mockConfigs.ts
vendored
Normal file
175
vue-widgets/tests/fixtures/mockConfigs.ts
vendored
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
/**
|
||||||
|
* Test fixtures for LoRA Cycler testing
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { CyclerConfig, LoraPoolConfig } from '@/composables/types'
|
||||||
|
import type { CyclerLoraItem } from '@/composables/useLoraCyclerState'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a default CyclerConfig for testing
|
||||||
|
*/
|
||||||
|
export function createMockCyclerConfig(overrides: Partial<CyclerConfig> = {}): CyclerConfig {
|
||||||
|
return {
|
||||||
|
current_index: 1,
|
||||||
|
total_count: 5,
|
||||||
|
pool_config_hash: '',
|
||||||
|
model_strength: 1.0,
|
||||||
|
clip_strength: 1.0,
|
||||||
|
use_same_clip_strength: true,
|
||||||
|
sort_by: 'filename',
|
||||||
|
current_lora_name: 'lora1.safetensors',
|
||||||
|
current_lora_filename: 'lora1.safetensors',
|
||||||
|
execution_index: null,
|
||||||
|
next_index: null,
|
||||||
|
repeat_count: 1,
|
||||||
|
repeat_used: 0,
|
||||||
|
is_paused: false,
|
||||||
|
...overrides
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a mock LoraPoolConfig for testing
|
||||||
|
*/
|
||||||
|
export function createMockPoolConfig(overrides: Partial<LoraPoolConfig> = {}): LoraPoolConfig {
|
||||||
|
return {
|
||||||
|
version: 1,
|
||||||
|
filters: {
|
||||||
|
baseModels: ['SD 1.5'],
|
||||||
|
tags: { include: [], exclude: [] },
|
||||||
|
folders: { include: [], exclude: [] },
|
||||||
|
license: {
|
||||||
|
noCreditRequired: false,
|
||||||
|
allowSelling: false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
preview: { matchCount: 10, lastUpdated: Date.now() },
|
||||||
|
...overrides
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a list of mock LoRA items for testing
|
||||||
|
*/
|
||||||
|
export function createMockLoraList(count: number = 5): CyclerLoraItem[] {
|
||||||
|
return Array.from({ length: count }, (_, i) => ({
|
||||||
|
file_name: `lora${i + 1}.safetensors`,
|
||||||
|
model_name: `LoRA Model ${i + 1}`
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a mock widget object for testing useLoraCyclerState
|
||||||
|
*/
|
||||||
|
export function createMockWidget(initialValue?: CyclerConfig) {
|
||||||
|
return {
|
||||||
|
value: initialValue,
|
||||||
|
callback: undefined as ((v: CyclerConfig) => void) | undefined
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a mock node object for testing component integration
|
||||||
|
*/
|
||||||
|
export function createMockNode(options: {
|
||||||
|
id?: number
|
||||||
|
poolConfig?: LoraPoolConfig | null
|
||||||
|
} = {}) {
|
||||||
|
const { id = 1, poolConfig = null } = options
|
||||||
|
return {
|
||||||
|
id,
|
||||||
|
inputs: [],
|
||||||
|
widgets: [],
|
||||||
|
graph: null,
|
||||||
|
getPoolConfig: () => poolConfig,
|
||||||
|
onExecuted: undefined as ((output: unknown) => void) | undefined
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates mock execution output from the backend
|
||||||
|
*/
|
||||||
|
export function createMockExecutionOutput(options: {
|
||||||
|
nextIndex?: number
|
||||||
|
totalCount?: number
|
||||||
|
nextLoraName?: string
|
||||||
|
nextLoraFilename?: string
|
||||||
|
currentLoraName?: string
|
||||||
|
currentLoraFilename?: string
|
||||||
|
} = {}) {
|
||||||
|
const {
|
||||||
|
nextIndex = 2,
|
||||||
|
totalCount = 5,
|
||||||
|
nextLoraName = 'lora2.safetensors',
|
||||||
|
nextLoraFilename = 'lora2.safetensors',
|
||||||
|
currentLoraName = 'lora1.safetensors',
|
||||||
|
currentLoraFilename = 'lora1.safetensors'
|
||||||
|
} = options
|
||||||
|
|
||||||
|
return {
|
||||||
|
next_index: [nextIndex],
|
||||||
|
total_count: [totalCount],
|
||||||
|
next_lora_name: [nextLoraName],
|
||||||
|
next_lora_filename: [nextLoraFilename],
|
||||||
|
current_lora_name: [currentLoraName],
|
||||||
|
current_lora_filename: [currentLoraFilename]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sample LoRA lists for specific test scenarios
|
||||||
|
*/
|
||||||
|
export const SAMPLE_LORA_LISTS = {
|
||||||
|
// 3 LoRAs for simple cycling tests
|
||||||
|
small: createMockLoraList(3),
|
||||||
|
|
||||||
|
// 5 LoRAs for standard tests
|
||||||
|
medium: createMockLoraList(5),
|
||||||
|
|
||||||
|
// 10 LoRAs for larger tests
|
||||||
|
large: createMockLoraList(10),
|
||||||
|
|
||||||
|
// Empty list for edge case testing
|
||||||
|
empty: [] as CyclerLoraItem[],
|
||||||
|
|
||||||
|
// Single LoRA for edge case testing
|
||||||
|
single: createMockLoraList(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sample pool configs for testing
|
||||||
|
*/
|
||||||
|
export const SAMPLE_POOL_CONFIGS = {
|
||||||
|
// Default SD 1.5 filter
|
||||||
|
sd15: createMockPoolConfig({
|
||||||
|
filters: {
|
||||||
|
baseModels: ['SD 1.5'],
|
||||||
|
tags: { include: [], exclude: [] },
|
||||||
|
folders: { include: [], exclude: [] },
|
||||||
|
license: { noCreditRequired: false, allowSelling: false }
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
|
||||||
|
// SDXL filter
|
||||||
|
sdxl: createMockPoolConfig({
|
||||||
|
filters: {
|
||||||
|
baseModels: ['SDXL'],
|
||||||
|
tags: { include: [], exclude: [] },
|
||||||
|
folders: { include: [], exclude: [] },
|
||||||
|
license: { noCreditRequired: false, allowSelling: false }
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
|
||||||
|
// Filter with tags
|
||||||
|
withTags: createMockPoolConfig({
|
||||||
|
filters: {
|
||||||
|
baseModels: ['SD 1.5'],
|
||||||
|
tags: { include: ['anime', 'style'], exclude: ['realistic'] },
|
||||||
|
folders: { include: [], exclude: [] },
|
||||||
|
license: { noCreditRequired: false, allowSelling: false }
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
|
||||||
|
// Empty/null config
|
||||||
|
empty: null as LoraPoolConfig | null
|
||||||
|
}
|
||||||
885
vue-widgets/tests/integration/batchQueue.test.ts
Normal file
885
vue-widgets/tests/integration/batchQueue.test.ts
Normal file
@@ -0,0 +1,885 @@
|
|||||||
|
/**
|
||||||
|
* Integration tests for batch queue execution scenarios
|
||||||
|
*
|
||||||
|
* These tests simulate ComfyUI's execution modes to verify correct LoRA cycling behavior.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect, beforeEach, vi } from 'vitest'
|
||||||
|
import { useLoraCyclerState } from '@/composables/useLoraCyclerState'
|
||||||
|
import type { CyclerConfig } from '@/composables/types'
|
||||||
|
import {
|
||||||
|
createMockWidget,
|
||||||
|
createMockCyclerConfig,
|
||||||
|
createMockLoraList,
|
||||||
|
createMockPoolConfig
|
||||||
|
} from '../fixtures/mockConfigs'
|
||||||
|
import { setupFetchMock, resetFetchMock } from '../setup'
|
||||||
|
import { BatchQueueSimulator, IndexTracker } from '../utils/BatchQueueSimulator'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a test harness that mimics the LoraCyclerWidget's behavior
|
||||||
|
*/
|
||||||
|
function createTestHarness(options: {
|
||||||
|
totalCount?: number
|
||||||
|
initialIndex?: number
|
||||||
|
repeatCount?: number
|
||||||
|
isPaused?: boolean
|
||||||
|
} = {}) {
|
||||||
|
const {
|
||||||
|
totalCount = 5,
|
||||||
|
initialIndex = 1,
|
||||||
|
repeatCount = 1,
|
||||||
|
isPaused = false
|
||||||
|
} = options
|
||||||
|
|
||||||
|
const widget = createMockWidget() as any
|
||||||
|
const state = useLoraCyclerState(widget)
|
||||||
|
|
||||||
|
// Initialize state
|
||||||
|
state.totalCount.value = totalCount
|
||||||
|
state.currentIndex.value = initialIndex
|
||||||
|
state.repeatCount.value = repeatCount
|
||||||
|
state.isPaused.value = isPaused
|
||||||
|
|
||||||
|
// Track if first execution
|
||||||
|
const HAS_EXECUTED = Symbol('HAS_EXECUTED')
|
||||||
|
widget[HAS_EXECUTED] = false
|
||||||
|
|
||||||
|
// Execution queue for batch synchronization
|
||||||
|
interface ExecutionContext {
|
||||||
|
isPaused: boolean
|
||||||
|
repeatUsed: number
|
||||||
|
repeatCount: number
|
||||||
|
shouldAdvanceDisplay: boolean
|
||||||
|
displayRepeatUsed: number // Value to show in UI after completion
|
||||||
|
}
|
||||||
|
const executionQueue: ExecutionContext[] = []
|
||||||
|
|
||||||
|
// beforeQueued hook (mirrors LoraCyclerWidget.vue logic)
|
||||||
|
widget.beforeQueued = () => {
|
||||||
|
if (state.isPaused.value) {
|
||||||
|
executionQueue.push({
|
||||||
|
isPaused: true,
|
||||||
|
repeatUsed: state.repeatUsed.value,
|
||||||
|
repeatCount: state.repeatCount.value,
|
||||||
|
shouldAdvanceDisplay: false,
|
||||||
|
displayRepeatUsed: state.displayRepeatUsed.value // Keep current display value when paused
|
||||||
|
})
|
||||||
|
// CRITICAL: Clear execution_index when paused to force backend to use current_index
|
||||||
|
const pausedConfig = state.buildConfig()
|
||||||
|
pausedConfig.execution_index = null
|
||||||
|
widget.value = pausedConfig
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (widget[HAS_EXECUTED]) {
|
||||||
|
if (state.repeatUsed.value < state.repeatCount.value) {
|
||||||
|
state.repeatUsed.value++
|
||||||
|
} else {
|
||||||
|
state.repeatUsed.value = 1
|
||||||
|
state.generateNextIndex()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
state.repeatUsed.value = 1
|
||||||
|
state.initializeNextIndex()
|
||||||
|
widget[HAS_EXECUTED] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
const shouldAdvanceDisplay = state.repeatUsed.value >= state.repeatCount.value
|
||||||
|
// Calculate the display value to show after this execution completes
|
||||||
|
// When advancing to a new LoRA: reset to 0 (fresh start for new LoRA)
|
||||||
|
// When repeating same LoRA: show current repeat step
|
||||||
|
const displayRepeatUsed = shouldAdvanceDisplay ? 0 : state.repeatUsed.value
|
||||||
|
|
||||||
|
executionQueue.push({
|
||||||
|
isPaused: false,
|
||||||
|
repeatUsed: state.repeatUsed.value,
|
||||||
|
repeatCount: state.repeatCount.value,
|
||||||
|
shouldAdvanceDisplay,
|
||||||
|
displayRepeatUsed
|
||||||
|
})
|
||||||
|
|
||||||
|
widget.value = state.buildConfig()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock node with onExecuted
|
||||||
|
const node = {
|
||||||
|
id: 1,
|
||||||
|
onExecuted: (output: any) => {
|
||||||
|
const context = executionQueue.shift()
|
||||||
|
|
||||||
|
const shouldAdvanceDisplay = context
|
||||||
|
? context.shouldAdvanceDisplay
|
||||||
|
: (!state.isPaused.value && state.repeatUsed.value >= state.repeatCount.value)
|
||||||
|
|
||||||
|
// Update displayRepeatUsed (deferred like index updates)
|
||||||
|
if (context) {
|
||||||
|
state.displayRepeatUsed.value = context.displayRepeatUsed
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shouldAdvanceDisplay && output?.next_index !== undefined) {
|
||||||
|
const val = Array.isArray(output.next_index) ? output.next_index[0] : output.next_index
|
||||||
|
state.currentIndex.value = val
|
||||||
|
}
|
||||||
|
if (output?.total_count !== undefined) {
|
||||||
|
const val = Array.isArray(output.total_count) ? output.total_count[0] : output.total_count
|
||||||
|
state.totalCount.value = val
|
||||||
|
}
|
||||||
|
if (shouldAdvanceDisplay) {
|
||||||
|
if (output?.next_lora_name !== undefined) {
|
||||||
|
const val = Array.isArray(output.next_lora_name) ? output.next_lora_name[0] : output.next_lora_name
|
||||||
|
state.currentLoraName.value = val
|
||||||
|
}
|
||||||
|
if (output?.next_lora_filename !== undefined) {
|
||||||
|
const val = Array.isArray(output.next_lora_filename) ? output.next_lora_filename[0] : output.next_lora_filename
|
||||||
|
state.currentLoraFilename.value = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset execution state (mimics manual index change)
|
||||||
|
const resetExecutionState = () => {
|
||||||
|
widget[HAS_EXECUTED] = false
|
||||||
|
state.executionIndex.value = null
|
||||||
|
state.nextIndex.value = null
|
||||||
|
executionQueue.length = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
widget,
|
||||||
|
state,
|
||||||
|
node,
|
||||||
|
executionQueue,
|
||||||
|
resetExecutionState,
|
||||||
|
getConfig: () => state.buildConfig(),
|
||||||
|
HAS_EXECUTED
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('Batch Queue Integration Tests', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
resetFetchMock()
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Basic Cycling', () => {
|
||||||
|
it('should cycle through N LoRAs in batch of N (batch queue mode)', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 3 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||||
|
|
||||||
|
// Simulate batch queue of 3 prompts
|
||||||
|
await simulator.runBatchQueue(
|
||||||
|
3,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
// After cycling through all 3, currentIndex should wrap back to 1
|
||||||
|
// First execution: index 1, next becomes 2
|
||||||
|
// Second execution: index 2, next becomes 3
|
||||||
|
// Third execution: index 3, next becomes 1
|
||||||
|
expect(harness.state.currentIndex.value).toBe(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should cycle through N LoRAs in batch of N (sequential mode)', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 3 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||||
|
|
||||||
|
// Simulate sequential execution of 3 prompts
|
||||||
|
await simulator.runSequential(
|
||||||
|
3,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
// Same result as batch mode
|
||||||
|
expect(harness.state.currentIndex.value).toBe(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle partial cycle (batch of 2 in pool of 5)', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 5, initialIndex: 1 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||||
|
|
||||||
|
await simulator.runBatchQueue(
|
||||||
|
2,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
// After 2 executions starting from 1: 1 -> 2 -> 3
|
||||||
|
expect(harness.state.currentIndex.value).toBe(3)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Repeat Functionality', () => {
|
||||||
|
it('should repeat each LoRA repeatCount times', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 3, repeatCount: 2 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||||
|
|
||||||
|
// With repeatCount=2, need 6 executions to cycle through 3 LoRAs
|
||||||
|
await simulator.runBatchQueue(
|
||||||
|
6,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
// Should have cycled back to beginning
|
||||||
|
expect(harness.state.currentIndex.value).toBe(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should track repeatUsed correctly during batch', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 3, repeatCount: 3 })
|
||||||
|
|
||||||
|
// First beforeQueued: repeatUsed = 1
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
expect(harness.state.repeatUsed.value).toBe(1)
|
||||||
|
|
||||||
|
// Second beforeQueued: repeatUsed = 2
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
expect(harness.state.repeatUsed.value).toBe(2)
|
||||||
|
|
||||||
|
// Third beforeQueued: repeatUsed = 3 (will advance on next)
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
expect(harness.state.repeatUsed.value).toBe(3)
|
||||||
|
|
||||||
|
// Fourth beforeQueued: repeatUsed resets to 1, index advances
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
expect(harness.state.repeatUsed.value).toBe(1)
|
||||||
|
expect(harness.state.nextIndex.value).toBe(3) // Advanced from 2 to 3
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not advance display until repeat cycle completes', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 5, repeatCount: 2 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||||
|
|
||||||
|
// First execution: repeatUsed=1 < repeatCount=2, shouldAdvanceDisplay=false
|
||||||
|
// Second execution: repeatUsed=2 >= repeatCount=2, shouldAdvanceDisplay=true
|
||||||
|
|
||||||
|
const indexHistory: number[] = []
|
||||||
|
|
||||||
|
// Override onExecuted to track index changes
|
||||||
|
const originalOnExecuted = harness.node.onExecuted
|
||||||
|
harness.node.onExecuted = (output: any) => {
|
||||||
|
originalOnExecuted(output)
|
||||||
|
indexHistory.push(harness.state.currentIndex.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
await simulator.runBatchQueue(
|
||||||
|
4,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
// Index should only change on 2nd and 4th execution
|
||||||
|
// Starting at 1: stay 1, advance to 2, stay 2, advance to 3
|
||||||
|
expect(indexHistory).toEqual([1, 2, 2, 3])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should defer displayRepeatUsed updates until workflow completion', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 3, repeatCount: 3 })
|
||||||
|
|
||||||
|
// Initial state
|
||||||
|
expect(harness.state.displayRepeatUsed.value).toBe(0)
|
||||||
|
|
||||||
|
// Queue 3 executions in batch mode (all beforeQueued before any onExecuted)
|
||||||
|
harness.widget.beforeQueued() // repeatUsed = 1
|
||||||
|
harness.widget.beforeQueued() // repeatUsed = 2
|
||||||
|
harness.widget.beforeQueued() // repeatUsed = 3
|
||||||
|
|
||||||
|
// displayRepeatUsed should NOT have changed yet (still 0)
|
||||||
|
// because no onExecuted has been called
|
||||||
|
expect(harness.state.displayRepeatUsed.value).toBe(0)
|
||||||
|
|
||||||
|
// Now simulate workflow completions
|
||||||
|
harness.node.onExecuted({ next_index: 1 })
|
||||||
|
expect(harness.state.displayRepeatUsed.value).toBe(1)
|
||||||
|
|
||||||
|
harness.node.onExecuted({ next_index: 1 })
|
||||||
|
expect(harness.state.displayRepeatUsed.value).toBe(2)
|
||||||
|
|
||||||
|
harness.node.onExecuted({ next_index: 2 })
|
||||||
|
// After completing repeat cycle, displayRepeatUsed resets to 0
|
||||||
|
expect(harness.state.displayRepeatUsed.value).toBe(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should reset displayRepeatUsed to 0 when advancing to new LoRA', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 3, repeatCount: 2 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||||
|
|
||||||
|
const displayHistory: number[] = []
|
||||||
|
|
||||||
|
const originalOnExecuted = harness.node.onExecuted
|
||||||
|
harness.node.onExecuted = (output: any) => {
|
||||||
|
originalOnExecuted(output)
|
||||||
|
displayHistory.push(harness.state.displayRepeatUsed.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run 4 executions: 2 repeats of LoRA 1, 2 repeats of LoRA 2
|
||||||
|
await simulator.runBatchQueue(
|
||||||
|
4,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
// displayRepeatUsed should show:
|
||||||
|
// 1st exec: 1 (first repeat of LoRA 1)
|
||||||
|
// 2nd exec: 0 (complete, reset for next LoRA)
|
||||||
|
// 3rd exec: 1 (first repeat of LoRA 2)
|
||||||
|
// 4th exec: 0 (complete, reset for next LoRA)
|
||||||
|
expect(displayHistory).toEqual([1, 0, 1, 0])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should show current repeat step when not advancing', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 3, repeatCount: 4 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||||
|
|
||||||
|
const displayHistory: number[] = []
|
||||||
|
|
||||||
|
const originalOnExecuted = harness.node.onExecuted
|
||||||
|
harness.node.onExecuted = (output: any) => {
|
||||||
|
originalOnExecuted(output)
|
||||||
|
displayHistory.push(harness.state.displayRepeatUsed.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run 4 executions: all 4 repeats of the same LoRA
|
||||||
|
await simulator.runBatchQueue(
|
||||||
|
4,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
// displayRepeatUsed should show:
|
||||||
|
// 1st exec: 1 (repeat 1/4, not advancing)
|
||||||
|
// 2nd exec: 2 (repeat 2/4, not advancing)
|
||||||
|
// 3rd exec: 3 (repeat 3/4, not advancing)
|
||||||
|
// 4th exec: 0 (repeat 4/4, complete, reset for next LoRA)
|
||||||
|
expect(displayHistory).toEqual([1, 2, 3, 0])
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Pause Functionality', () => {
|
||||||
|
it('should maintain index when paused', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 5, isPaused: true })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||||
|
|
||||||
|
await simulator.runBatchQueue(
|
||||||
|
3,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
// Index should not advance when paused
|
||||||
|
expect(harness.state.currentIndex.value).toBe(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not count paused executions toward repeat limit', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 5, repeatCount: 2 })
|
||||||
|
|
||||||
|
// Run 2 executions while paused
|
||||||
|
harness.state.isPaused.value = true
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
|
||||||
|
// repeatUsed should still be 0 (paused executions don't count)
|
||||||
|
expect(harness.state.repeatUsed.value).toBe(0)
|
||||||
|
|
||||||
|
// Unpause and run
|
||||||
|
harness.state.isPaused.value = false
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
expect(harness.state.repeatUsed.value).toBe(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should preserve displayRepeatUsed when paused', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 5, repeatCount: 3 })
|
||||||
|
|
||||||
|
// Run one execution to set displayRepeatUsed
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
harness.node.onExecuted({ next_index: 1 })
|
||||||
|
expect(harness.state.displayRepeatUsed.value).toBe(1)
|
||||||
|
|
||||||
|
// Pause
|
||||||
|
harness.state.isPaused.value = true
|
||||||
|
|
||||||
|
// Queue and execute while paused
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
harness.node.onExecuted({ next_index: 1 })
|
||||||
|
|
||||||
|
// displayRepeatUsed should remain at 1 (paused executions don't change it)
|
||||||
|
expect(harness.state.displayRepeatUsed.value).toBe(1)
|
||||||
|
|
||||||
|
// Queue another paused execution
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
harness.node.onExecuted({ next_index: 1 })
|
||||||
|
|
||||||
|
// Still should be 1
|
||||||
|
expect(harness.state.displayRepeatUsed.value).toBe(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should use same LoRA when pause is toggled mid-batch', async () => {
|
||||||
|
// This tests the critical bug scenario:
|
||||||
|
// 1. User queues multiple prompts (not paused)
|
||||||
|
// 2. All beforeQueued calls complete, each advancing execution_index
|
||||||
|
// 3. User clicks pause
|
||||||
|
// 4. onExecuted starts firing - paused executions should use current_index, not execution_index
|
||||||
|
const harness = createTestHarness({ totalCount: 5 })
|
||||||
|
|
||||||
|
// Queue first prompt (not paused) - this sets up execution_index
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
const config1 = harness.getConfig()
|
||||||
|
expect(config1.execution_index).toBeNull() // First execution uses current_index
|
||||||
|
|
||||||
|
// User clicks pause mid-batch
|
||||||
|
harness.state.isPaused.value = true
|
||||||
|
|
||||||
|
// Queue subsequent prompts while paused
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
const config2 = harness.getConfig()
|
||||||
|
// CRITICAL: execution_index should be null when paused to force backend to use current_index
|
||||||
|
expect(config2.execution_index).toBeNull()
|
||||||
|
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
const config3 = harness.getConfig()
|
||||||
|
expect(config3.execution_index).toBeNull()
|
||||||
|
|
||||||
|
// Verify execution queue has correct context
|
||||||
|
expect(harness.executionQueue.length).toBe(3)
|
||||||
|
expect(harness.executionQueue[0].isPaused).toBe(false)
|
||||||
|
expect(harness.executionQueue[1].isPaused).toBe(true)
|
||||||
|
expect(harness.executionQueue[2].isPaused).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should have null execution_index in widget.value when paused even after non-paused queues', async () => {
|
||||||
|
// More detailed test for the execution_index clearing behavior
|
||||||
|
// This tests that widget.value (what backend receives) has null execution_index
|
||||||
|
const harness = createTestHarness({ totalCount: 5 })
|
||||||
|
|
||||||
|
// Queue 3 prompts while not paused
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
|
||||||
|
// Verify execution_index was set by non-paused queues in widget.value
|
||||||
|
expect(harness.widget.value.execution_index).not.toBeNull()
|
||||||
|
|
||||||
|
// User pauses
|
||||||
|
harness.state.isPaused.value = true
|
||||||
|
|
||||||
|
// Queue while paused - should clear execution_index in widget.value
|
||||||
|
// This is the value that gets sent to the backend
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
expect(harness.widget.value.execution_index).toBeNull()
|
||||||
|
|
||||||
|
// State's executionIndex may still have the old value (that's fine)
|
||||||
|
// What matters is widget.value which is what the backend uses
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should have hasQueuedPrompts true when execution queue has items', async () => {
|
||||||
|
// This tests the pause button disabled state
|
||||||
|
const harness = createTestHarness({ totalCount: 5 })
|
||||||
|
|
||||||
|
// Initially no queued prompts
|
||||||
|
expect(harness.executionQueue.length).toBe(0)
|
||||||
|
|
||||||
|
// Queue some prompts
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
|
||||||
|
// Execution queue should have items
|
||||||
|
expect(harness.executionQueue.length).toBe(3)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should have empty execution queue after all executions complete', async () => {
|
||||||
|
// This tests that pause button becomes enabled after executions complete
|
||||||
|
const harness = createTestHarness({ totalCount: 5 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||||
|
|
||||||
|
// Run batch queue execution
|
||||||
|
await simulator.runBatchQueue(
|
||||||
|
3,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
// After all executions, queue should be empty
|
||||||
|
expect(harness.executionQueue.length).toBe(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should resume cycling after unpause', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 3, initialIndex: 2 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||||
|
|
||||||
|
// Execute once while not paused
|
||||||
|
await simulator.runSingle(
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
// Pause
|
||||||
|
harness.state.isPaused.value = true
|
||||||
|
|
||||||
|
// Execute twice while paused
|
||||||
|
await simulator.runBatchQueue(
|
||||||
|
2,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
// Unpause and execute
|
||||||
|
harness.state.isPaused.value = false
|
||||||
|
|
||||||
|
await simulator.runSingle(
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
// Should continue from where it left off (index 3 -> 1)
|
||||||
|
expect(harness.state.currentIndex.value).toBe(1)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Manual Index Change', () => {
|
||||||
|
it('should reset execution state on manual index change', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 5 })
|
||||||
|
|
||||||
|
// Execute a few times
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
|
||||||
|
expect(harness.widget[harness.HAS_EXECUTED]).toBe(true)
|
||||||
|
expect(harness.executionQueue.length).toBe(2)
|
||||||
|
|
||||||
|
// User manually changes index (mimics handleIndexUpdate)
|
||||||
|
harness.resetExecutionState()
|
||||||
|
harness.state.setIndex(4)
|
||||||
|
|
||||||
|
expect(harness.widget[harness.HAS_EXECUTED]).toBe(false)
|
||||||
|
expect(harness.state.executionIndex.value).toBeNull()
|
||||||
|
expect(harness.state.nextIndex.value).toBeNull()
|
||||||
|
expect(harness.executionQueue.length).toBe(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should start fresh cycle from manual index', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 5 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||||
|
|
||||||
|
// Execute twice starting from 1
|
||||||
|
await simulator.runBatchQueue(
|
||||||
|
2,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(harness.state.currentIndex.value).toBe(3)
|
||||||
|
|
||||||
|
// User manually sets index to 1
|
||||||
|
harness.resetExecutionState()
|
||||||
|
harness.state.setIndex(1)
|
||||||
|
|
||||||
|
// Execute again - should start fresh from 1
|
||||||
|
await simulator.runBatchQueue(
|
||||||
|
2,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(harness.state.currentIndex.value).toBe(3)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Execution Queue Mismatch', () => {
|
||||||
|
it('should handle interrupted execution (queue > executed)', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 5 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||||
|
|
||||||
|
// Queue 5 but only execute 2 (simulates cancel)
|
||||||
|
await simulator.runInterrupted(
|
||||||
|
5, // queued
|
||||||
|
2, // executed
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
// 3 contexts remain in queue
|
||||||
|
expect(harness.executionQueue.length).toBe(3)
|
||||||
|
|
||||||
|
// Index should reflect only the 2 executions that completed
|
||||||
|
expect(harness.state.currentIndex.value).toBe(3)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should recover from mismatch on next manual index change', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 5 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||||
|
|
||||||
|
// Create mismatch
|
||||||
|
await simulator.runInterrupted(
|
||||||
|
5,
|
||||||
|
2,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(harness.executionQueue.length).toBe(3)
|
||||||
|
|
||||||
|
// Manual index change clears queue
|
||||||
|
harness.resetExecutionState()
|
||||||
|
harness.state.setIndex(1)
|
||||||
|
|
||||||
|
expect(harness.executionQueue.length).toBe(0)
|
||||||
|
|
||||||
|
// Can execute normally again
|
||||||
|
await simulator.runSingle(
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(harness.state.currentIndex.value).toBe(2)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Edge Cases', () => {
|
||||||
|
it('should handle single item pool', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 1 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 1 })
|
||||||
|
|
||||||
|
await simulator.runBatchQueue(
|
||||||
|
3,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
// Should always stay at index 1
|
||||||
|
expect(harness.state.currentIndex.value).toBe(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle empty pool gracefully', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 0 })
|
||||||
|
|
||||||
|
// beforeQueued should still work without errors
|
||||||
|
expect(() => harness.widget.beforeQueued()).not.toThrow()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle rapid sequential executions', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 5 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||||
|
|
||||||
|
// Run 20 sequential executions
|
||||||
|
await simulator.runSequential(
|
||||||
|
20,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
// 20 % 5 = 0, so should wrap back to 1
|
||||||
|
// But first execution uses index 1, so after 20 executions we're at 21 % 5 = 1
|
||||||
|
expect(harness.state.currentIndex.value).toBe(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should preserve state consistency across many cycles', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 3, repeatCount: 2 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||||
|
|
||||||
|
// Run 100 executions in batches
|
||||||
|
for (let batch = 0; batch < 10; batch++) {
|
||||||
|
await simulator.runBatchQueue(
|
||||||
|
10,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify state is still valid
|
||||||
|
expect(harness.state.currentIndex.value).toBeGreaterThanOrEqual(1)
|
||||||
|
expect(harness.state.currentIndex.value).toBeLessThanOrEqual(3)
|
||||||
|
expect(harness.state.repeatUsed.value).toBeGreaterThanOrEqual(1)
|
||||||
|
expect(harness.state.repeatUsed.value).toBeLessThanOrEqual(2)
|
||||||
|
expect(harness.executionQueue.length).toBe(0)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Invariant Assertions', () => {
|
||||||
|
it('should always have valid index (1 <= currentIndex <= totalCount)', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 5 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||||
|
|
||||||
|
const checkInvariant = () => {
|
||||||
|
const { currentIndex, totalCount } = harness.state
|
||||||
|
if (totalCount.value > 0) {
|
||||||
|
expect(currentIndex.value).toBeGreaterThanOrEqual(1)
|
||||||
|
expect(currentIndex.value).toBeLessThanOrEqual(totalCount.value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override onExecuted to check invariant after each execution
|
||||||
|
const originalOnExecuted = harness.node.onExecuted
|
||||||
|
harness.node.onExecuted = (output: any) => {
|
||||||
|
originalOnExecuted(output)
|
||||||
|
checkInvariant()
|
||||||
|
}
|
||||||
|
|
||||||
|
await simulator.runBatchQueue(
|
||||||
|
20,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should always have repeatUsed <= repeatCount', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 5, repeatCount: 3 })
|
||||||
|
|
||||||
|
const checkInvariant = () => {
|
||||||
|
expect(harness.state.repeatUsed.value).toBeLessThanOrEqual(harness.state.repeatCount.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check after each beforeQueued
|
||||||
|
for (let i = 0; i < 20; i++) {
|
||||||
|
harness.widget.beforeQueued()
|
||||||
|
checkInvariant()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should consume all execution contexts (queue empty after matching executions)', async () => {
|
||||||
|
const harness = createTestHarness({ totalCount: 5 })
|
||||||
|
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||||
|
|
||||||
|
await simulator.runBatchQueue(
|
||||||
|
7,
|
||||||
|
{
|
||||||
|
beforeQueued: () => harness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => harness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => harness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(harness.executionQueue.length).toBe(0)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Batch vs Sequential Mode Equivalence', () => {
|
||||||
|
it('should produce same final state in both modes (basic cycle)', async () => {
|
||||||
|
// Create two identical harnesses
|
||||||
|
const batchHarness = createTestHarness({ totalCount: 5 })
|
||||||
|
const seqHarness = createTestHarness({ totalCount: 5 })
|
||||||
|
|
||||||
|
const batchSimulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||||
|
const seqSimulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||||
|
|
||||||
|
// Run same number of executions in different modes
|
||||||
|
await batchSimulator.runBatchQueue(
|
||||||
|
7,
|
||||||
|
{
|
||||||
|
beforeQueued: () => batchHarness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => batchHarness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => batchHarness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
await seqSimulator.runSequential(
|
||||||
|
7,
|
||||||
|
{
|
||||||
|
beforeQueued: () => seqHarness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => seqHarness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => seqHarness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
// Final state should be identical
|
||||||
|
expect(batchHarness.state.currentIndex.value).toBe(seqHarness.state.currentIndex.value)
|
||||||
|
expect(batchHarness.state.repeatUsed.value).toBe(seqHarness.state.repeatUsed.value)
|
||||||
|
expect(batchHarness.state.displayRepeatUsed.value).toBe(seqHarness.state.displayRepeatUsed.value)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should produce same final state in both modes (with repeat)', async () => {
|
||||||
|
const batchHarness = createTestHarness({ totalCount: 3, repeatCount: 2 })
|
||||||
|
const seqHarness = createTestHarness({ totalCount: 3, repeatCount: 2 })
|
||||||
|
|
||||||
|
const batchSimulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||||
|
const seqSimulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||||
|
|
||||||
|
await batchSimulator.runBatchQueue(
|
||||||
|
10,
|
||||||
|
{
|
||||||
|
beforeQueued: () => batchHarness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => batchHarness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => batchHarness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
await seqSimulator.runSequential(
|
||||||
|
10,
|
||||||
|
{
|
||||||
|
beforeQueued: () => seqHarness.widget.beforeQueued(),
|
||||||
|
onExecuted: (output) => seqHarness.node.onExecuted(output)
|
||||||
|
},
|
||||||
|
() => seqHarness.getConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(batchHarness.state.currentIndex.value).toBe(seqHarness.state.currentIndex.value)
|
||||||
|
expect(batchHarness.state.repeatUsed.value).toBe(seqHarness.state.repeatUsed.value)
|
||||||
|
expect(batchHarness.state.displayRepeatUsed.value).toBe(seqHarness.state.displayRepeatUsed.value)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
75
vue-widgets/tests/setup.ts
Normal file
75
vue-widgets/tests/setup.ts
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
/**
|
||||||
|
* Vitest test setup file
|
||||||
|
* Configures global mocks for ComfyUI modules and browser APIs
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { vi } from 'vitest'
|
||||||
|
|
||||||
|
// Mock ComfyUI app module
|
||||||
|
vi.mock('../../../scripts/app.js', () => ({
|
||||||
|
app: {
|
||||||
|
graph: {
|
||||||
|
_nodes: []
|
||||||
|
},
|
||||||
|
registerExtension: vi.fn()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Mock ComfyUI loras_widget module
|
||||||
|
vi.mock('../loras_widget.js', () => ({
|
||||||
|
addLoraCard: vi.fn(),
|
||||||
|
removeLoraCard: vi.fn()
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Mock ComfyUI autocomplete module
|
||||||
|
vi.mock('../autocomplete.js', () => ({
|
||||||
|
setupAutocomplete: vi.fn()
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Global fetch mock - exported so tests can access it directly
|
||||||
|
export const mockFetch = vi.fn()
|
||||||
|
vi.stubGlobal('fetch', mockFetch)
|
||||||
|
|
||||||
|
// Helper to reset fetch mock between tests
|
||||||
|
export function resetFetchMock() {
|
||||||
|
mockFetch.mockReset()
|
||||||
|
// Re-stub global to ensure it's the same mock
|
||||||
|
vi.stubGlobal('fetch', mockFetch)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to setup fetch mock with default success response
|
||||||
|
export function setupFetchMock(response: unknown = { success: true, loras: [] }) {
|
||||||
|
// Ensure we're using the same mock
|
||||||
|
mockFetch.mockReset()
|
||||||
|
mockFetch.mockResolvedValue({
|
||||||
|
ok: true,
|
||||||
|
json: () => Promise.resolve(response)
|
||||||
|
})
|
||||||
|
vi.stubGlobal('fetch', mockFetch)
|
||||||
|
return mockFetch
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to setup fetch mock with error response
|
||||||
|
export function setupFetchErrorMock(error: string = 'Network error') {
|
||||||
|
mockFetch.mockReset()
|
||||||
|
mockFetch.mockRejectedValue(new Error(error))
|
||||||
|
vi.stubGlobal('fetch', mockFetch)
|
||||||
|
return mockFetch
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock btoa for hashing (jsdom should have this, but just in case)
|
||||||
|
if (typeof global.btoa === 'undefined') {
|
||||||
|
vi.stubGlobal('btoa', (str: string) => Buffer.from(str).toString('base64'))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock console methods to reduce noise in tests
|
||||||
|
vi.spyOn(console, 'log').mockImplementation(() => {})
|
||||||
|
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||||
|
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||||
|
|
||||||
|
// Re-enable console for debugging when needed
|
||||||
|
export function enableConsole() {
|
||||||
|
vi.spyOn(console, 'log').mockRestore()
|
||||||
|
vi.spyOn(console, 'error').mockRestore()
|
||||||
|
vi.spyOn(console, 'warn').mockRestore()
|
||||||
|
}
|
||||||
230
vue-widgets/tests/utils/BatchQueueSimulator.ts
Normal file
230
vue-widgets/tests/utils/BatchQueueSimulator.ts
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
/**
|
||||||
|
* BatchQueueSimulator - Simulates ComfyUI's two execution modes
|
||||||
|
*
|
||||||
|
* ComfyUI has two distinct execution patterns:
|
||||||
|
* 1. Batch Queue Mode: ALL beforeQueued calls happen BEFORE any onExecuted calls
|
||||||
|
* 2. Sequential Mode: beforeQueued and onExecuted interleave for each prompt
|
||||||
|
*
|
||||||
|
* This simulator helps test how the widget behaves in both modes.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { CyclerConfig } from '@/composables/types'
|
||||||
|
|
||||||
|
export interface ExecutionHooks {
|
||||||
|
/** Called when a prompt is queued (before execution) */
|
||||||
|
beforeQueued: () => void
|
||||||
|
/** Called when execution completes with output */
|
||||||
|
onExecuted: (output: unknown) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SimulatorOptions {
|
||||||
|
/** Total number of LoRAs in the pool */
|
||||||
|
totalCount: number
|
||||||
|
/** Function to generate output for each execution */
|
||||||
|
generateOutput?: (executionIndex: number, config: CyclerConfig) => unknown
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates execution output based on the current state
|
||||||
|
*/
|
||||||
|
function defaultGenerateOutput(executionIndex: number, config: CyclerConfig) {
|
||||||
|
// Calculate what the next index would be after this execution
|
||||||
|
let nextIdx = (config.execution_index ?? config.current_index) + 1
|
||||||
|
if (nextIdx > config.total_count) {
|
||||||
|
nextIdx = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
next_index: [nextIdx],
|
||||||
|
total_count: [config.total_count],
|
||||||
|
next_lora_name: [`lora${nextIdx}.safetensors`],
|
||||||
|
next_lora_filename: [`lora${nextIdx}.safetensors`],
|
||||||
|
current_lora_name: [`lora${config.execution_index ?? config.current_index}.safetensors`],
|
||||||
|
current_lora_filename: [`lora${config.execution_index ?? config.current_index}.safetensors`]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export class BatchQueueSimulator {
|
||||||
|
private executionCount = 0
|
||||||
|
private options: Required<SimulatorOptions>
|
||||||
|
|
||||||
|
constructor(options: SimulatorOptions) {
|
||||||
|
this.options = {
|
||||||
|
totalCount: options.totalCount,
|
||||||
|
generateOutput: options.generateOutput ?? defaultGenerateOutput
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reset the simulator state
|
||||||
|
*/
|
||||||
|
reset() {
|
||||||
|
this.executionCount = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Simulates Batch Queue Mode execution
|
||||||
|
*
|
||||||
|
* In this mode, ComfyUI queues multiple prompts at once:
|
||||||
|
* - ALL beforeQueued() calls happen first (for all prompts in the batch)
|
||||||
|
* - THEN all onExecuted() calls happen (as each prompt completes)
|
||||||
|
*
|
||||||
|
* This is the mode used when queueing multiple prompts from the UI.
|
||||||
|
*
|
||||||
|
* @param count Number of prompts to simulate
|
||||||
|
* @param hooks The widget's execution hooks
|
||||||
|
* @param getConfig Function to get current widget config state
|
||||||
|
*/
|
||||||
|
async runBatchQueue(
|
||||||
|
count: number,
|
||||||
|
hooks: ExecutionHooks,
|
||||||
|
getConfig: () => CyclerConfig
|
||||||
|
): Promise<void> {
|
||||||
|
// Phase 1: All beforeQueued calls (snapshot configs)
|
||||||
|
const snapshotConfigs: CyclerConfig[] = []
|
||||||
|
|
||||||
|
for (let i = 0; i < count; i++) {
|
||||||
|
hooks.beforeQueued()
|
||||||
|
// Snapshot the config after beforeQueued updates it
|
||||||
|
snapshotConfigs.push({ ...getConfig() })
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 2: All onExecuted calls (in order)
|
||||||
|
for (let i = 0; i < count; i++) {
|
||||||
|
const config = snapshotConfigs[i]
|
||||||
|
const output = this.options.generateOutput(this.executionCount, config)
|
||||||
|
hooks.onExecuted(output)
|
||||||
|
this.executionCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Simulates Sequential Mode execution
|
||||||
|
*
|
||||||
|
* In this mode, execution is one-at-a-time:
|
||||||
|
* - beforeQueued() is called
|
||||||
|
* - onExecuted() is called
|
||||||
|
* - Then the next prompt's beforeQueued() is called
|
||||||
|
* - And so on...
|
||||||
|
*
|
||||||
|
* This is the mode used in API-driven execution or single prompt queuing.
|
||||||
|
*
|
||||||
|
* @param count Number of prompts to simulate
|
||||||
|
* @param hooks The widget's execution hooks
|
||||||
|
* @param getConfig Function to get current widget config state
|
||||||
|
*/
|
||||||
|
async runSequential(
|
||||||
|
count: number,
|
||||||
|
hooks: ExecutionHooks,
|
||||||
|
getConfig: () => CyclerConfig
|
||||||
|
): Promise<void> {
|
||||||
|
for (let i = 0; i < count; i++) {
|
||||||
|
// Queue the prompt
|
||||||
|
hooks.beforeQueued()
|
||||||
|
const config = { ...getConfig() }
|
||||||
|
|
||||||
|
// Execute it immediately
|
||||||
|
const output = this.options.generateOutput(this.executionCount, config)
|
||||||
|
hooks.onExecuted(output)
|
||||||
|
this.executionCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Simulates a single execution (queue + execute)
|
||||||
|
*/
|
||||||
|
async runSingle(
|
||||||
|
hooks: ExecutionHooks,
|
||||||
|
getConfig: () => CyclerConfig
|
||||||
|
): Promise<void> {
|
||||||
|
return this.runSequential(1, hooks, getConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Simulates interrupted execution (some beforeQueued calls without matching onExecuted)
|
||||||
|
*
|
||||||
|
* This can happen if the user cancels execution mid-batch.
|
||||||
|
*
|
||||||
|
* @param queuedCount Number of prompts queued (beforeQueued called)
|
||||||
|
* @param executedCount Number of prompts that actually executed
|
||||||
|
* @param hooks The widget's execution hooks
|
||||||
|
* @param getConfig Function to get current widget config state
|
||||||
|
*/
|
||||||
|
async runInterrupted(
|
||||||
|
queuedCount: number,
|
||||||
|
executedCount: number,
|
||||||
|
hooks: ExecutionHooks,
|
||||||
|
getConfig: () => CyclerConfig
|
||||||
|
): Promise<void> {
|
||||||
|
if (executedCount > queuedCount) {
|
||||||
|
throw new Error('executedCount cannot be greater than queuedCount')
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 1: All beforeQueued calls
|
||||||
|
const snapshotConfigs: CyclerConfig[] = []
|
||||||
|
for (let i = 0; i < queuedCount; i++) {
|
||||||
|
hooks.beforeQueued()
|
||||||
|
snapshotConfigs.push({ ...getConfig() })
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 2: Only some onExecuted calls
|
||||||
|
for (let i = 0; i < executedCount; i++) {
|
||||||
|
const config = snapshotConfigs[i]
|
||||||
|
const output = this.options.generateOutput(this.executionCount, config)
|
||||||
|
hooks.onExecuted(output)
|
||||||
|
this.executionCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Helper to create execution hooks from a widget-like object
|
||||||
|
*/
|
||||||
|
export function createHooksFromWidget(widget: {
|
||||||
|
beforeQueued?: () => void
|
||||||
|
}, node: {
|
||||||
|
onExecuted?: (output: unknown) => void
|
||||||
|
}): ExecutionHooks {
|
||||||
|
return {
|
||||||
|
beforeQueued: () => widget.beforeQueued?.(),
|
||||||
|
onExecuted: (output) => node.onExecuted?.(output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tracks index history during simulation for assertions
|
||||||
|
*/
|
||||||
|
export class IndexTracker {
|
||||||
|
public indexHistory: number[] = []
|
||||||
|
public repeatHistory: number[] = []
|
||||||
|
public pauseHistory: boolean[] = []
|
||||||
|
|
||||||
|
reset() {
|
||||||
|
this.indexHistory = []
|
||||||
|
this.repeatHistory = []
|
||||||
|
this.pauseHistory = []
|
||||||
|
}
|
||||||
|
|
||||||
|
record(config: CyclerConfig) {
|
||||||
|
this.indexHistory.push(config.current_index)
|
||||||
|
this.repeatHistory.push(config.repeat_used)
|
||||||
|
this.pauseHistory.push(config.is_paused)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the sequence of indices that were actually used for execution
|
||||||
|
*/
|
||||||
|
getExecutionIndices(): number[] {
|
||||||
|
return this.indexHistory
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Verify that indices cycle correctly through totalCount
|
||||||
|
*/
|
||||||
|
verifyCyclePattern(expectedPattern: number[]): boolean {
|
||||||
|
if (this.indexHistory.length !== expectedPattern.length) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return this.indexHistory.every((idx, i) => idx === expectedPattern[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,6 +19,6 @@
|
|||||||
"@/*": ["./src/*"]
|
"@/*": ["./src/*"]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"include": ["src/**/*.ts", "src/**/*.tsx", "src/**/*.vue"],
|
"include": ["src/**/*.ts", "src/**/*.tsx", "src/**/*.vue", "tests/**/*.ts"],
|
||||||
"references": [{ "path": "./tsconfig.node.json" }]
|
"references": [{ "path": "./tsconfig.node.json" }]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,8 +22,10 @@ export default defineConfig({
|
|||||||
rollupOptions: {
|
rollupOptions: {
|
||||||
external: [
|
external: [
|
||||||
'../../../scripts/app.js',
|
'../../../scripts/app.js',
|
||||||
|
'../../../scripts/api.js',
|
||||||
'../loras_widget.js',
|
'../loras_widget.js',
|
||||||
'../autocomplete.js'
|
'../autocomplete.js',
|
||||||
|
'../preview_tooltip.js'
|
||||||
],
|
],
|
||||||
output: {
|
output: {
|
||||||
dir: '../web/comfyui/vue-widgets',
|
dir: '../web/comfyui/vue-widgets',
|
||||||
|
|||||||
25
vue-widgets/vitest.config.ts
Normal file
25
vue-widgets/vitest.config.ts
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import { defineConfig } from 'vitest/config'
|
||||||
|
import vue from '@vitejs/plugin-vue'
|
||||||
|
import { resolve } from 'path'
|
||||||
|
|
||||||
|
export default defineConfig({
|
||||||
|
plugins: [vue()],
|
||||||
|
resolve: {
|
||||||
|
alias: {
|
||||||
|
'@': resolve(__dirname, './src')
|
||||||
|
}
|
||||||
|
},
|
||||||
|
test: {
|
||||||
|
globals: true,
|
||||||
|
environment: 'jsdom',
|
||||||
|
setupFiles: ['./tests/setup.ts'],
|
||||||
|
include: ['tests/**/*.test.ts'],
|
||||||
|
coverage: {
|
||||||
|
provider: 'v8',
|
||||||
|
reporter: ['text', 'html', 'json'],
|
||||||
|
reportsDirectory: './coverage',
|
||||||
|
include: ['src/**/*.ts', 'src/**/*.vue'],
|
||||||
|
exclude: ['src/main.ts', 'src/vite-env.d.ts']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
@@ -2,6 +2,7 @@ import { api } from "../../scripts/api.js";
|
|||||||
import { app } from "../../scripts/app.js";
|
import { app } from "../../scripts/app.js";
|
||||||
import { TextAreaCaretHelper } from "./textarea_caret_helper.js";
|
import { TextAreaCaretHelper } from "./textarea_caret_helper.js";
|
||||||
import { getPromptTagAutocompletePreference, getTagSpaceReplacementPreference } from "./settings.js";
|
import { getPromptTagAutocompletePreference, getTagSpaceReplacementPreference } from "./settings.js";
|
||||||
|
import { showToast } from "./utils.js";
|
||||||
|
|
||||||
// Command definitions for category filtering
|
// Command definitions for category filtering
|
||||||
const TAG_COMMANDS = {
|
const TAG_COMMANDS = {
|
||||||
@@ -15,6 +16,21 @@ const TAG_COMMANDS = {
|
|||||||
'/lore': { categories: [15], label: 'Lore' },
|
'/lore': { categories: [15], label: 'Lore' },
|
||||||
'/emb': { type: 'embedding', label: 'Embeddings' },
|
'/emb': { type: 'embedding', label: 'Embeddings' },
|
||||||
'/embedding': { type: 'embedding', label: 'Embeddings' },
|
'/embedding': { type: 'embedding', label: 'Embeddings' },
|
||||||
|
// Autocomplete toggle commands - only show one based on current state
|
||||||
|
'/ac': {
|
||||||
|
type: 'toggle_setting',
|
||||||
|
settingId: 'loramanager.prompt_tag_autocomplete',
|
||||||
|
value: true,
|
||||||
|
label: 'Autocomplete: ON',
|
||||||
|
condition: () => !getPromptTagAutocompletePreference()
|
||||||
|
},
|
||||||
|
'/noac': {
|
||||||
|
type: 'toggle_setting',
|
||||||
|
settingId: 'loramanager.prompt_tag_autocomplete',
|
||||||
|
value: false,
|
||||||
|
label: 'Autocomplete: OFF',
|
||||||
|
condition: () => getPromptTagAutocompletePreference()
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
// Category display information
|
// Category display information
|
||||||
@@ -488,6 +504,10 @@ class AutoComplete {
|
|||||||
this.searchType = 'commands';
|
this.searchType = 'commands';
|
||||||
this._showCommandList(commandResult.commandFilter);
|
this._showCommandList(commandResult.commandFilter);
|
||||||
return;
|
return;
|
||||||
|
} else if (commandResult.command?.type === 'toggle_setting') {
|
||||||
|
// Handle toggle setting command (/ac, /noac)
|
||||||
|
this._handleToggleSettingCommand(commandResult.command);
|
||||||
|
return;
|
||||||
} else if (commandResult.command) {
|
} else if (commandResult.command) {
|
||||||
// Command is active, use filtered search
|
// Command is active, use filtered search
|
||||||
this.showingCommands = false;
|
this.showingCommands = false;
|
||||||
@@ -509,7 +529,10 @@ class AutoComplete {
|
|||||||
this.showingCommands = false;
|
this.showingCommands = false;
|
||||||
this.activeCommand = null;
|
this.activeCommand = null;
|
||||||
endpoint = '/lm/custom-words/search?enriched=true';
|
endpoint = '/lm/custom-words/search?enriched=true';
|
||||||
searchTerm = rawSearchTerm;
|
// Extract last space-separated token for search
|
||||||
|
// Tag names don't contain spaces, so we only need the last token
|
||||||
|
// This allows "hello 1gi" to search for "1gi" and find "1girl"
|
||||||
|
searchTerm = this._getLastSpaceToken(rawSearchTerm);
|
||||||
this.searchType = 'custom_words';
|
this.searchType = 'custom_words';
|
||||||
} else {
|
} else {
|
||||||
// No command and setting disabled - no autocomplete for direct typing
|
// No command and setting disabled - no autocomplete for direct typing
|
||||||
@@ -545,6 +568,17 @@ class AutoComplete {
|
|||||||
return lastSegment.trim();
|
return lastSegment.trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extract the last space-separated token from a search term
|
||||||
|
* Tag names don't contain spaces, so for tag autocomplete we only need the last token
|
||||||
|
* @param {string} term - The full search term (e.g., "hello 1gi")
|
||||||
|
* @returns {string} - The last token (e.g., "1gi"), or the original term if no spaces
|
||||||
|
*/
|
||||||
|
_getLastSpaceToken(term) {
|
||||||
|
const tokens = term.trim().split(/\s+/);
|
||||||
|
return tokens[tokens.length - 1] || term;
|
||||||
|
}
|
||||||
|
|
||||||
async search(term = '', endpoint = null) {
|
async search(term = '', endpoint = null) {
|
||||||
try {
|
try {
|
||||||
this.currentSearchTerm = term;
|
this.currentSearchTerm = term;
|
||||||
@@ -606,9 +640,14 @@ class AutoComplete {
|
|||||||
|
|
||||||
// Check for exact command match
|
// Check for exact command match
|
||||||
if (TAG_COMMANDS[partialCommand]) {
|
if (TAG_COMMANDS[partialCommand]) {
|
||||||
|
const cmd = TAG_COMMANDS[partialCommand];
|
||||||
|
// Filter out toggle commands that don't meet their condition
|
||||||
|
if (cmd.type === 'toggle_setting' && cmd.condition && !cmd.condition()) {
|
||||||
|
return { showCommands: false, command: null, searchTerm: '' };
|
||||||
|
}
|
||||||
return {
|
return {
|
||||||
showCommands: false,
|
showCommands: false,
|
||||||
command: TAG_COMMANDS[partialCommand],
|
command: cmd,
|
||||||
searchTerm: '',
|
searchTerm: '',
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -627,9 +666,14 @@ class AutoComplete {
|
|||||||
const searchPart = trimmed.slice(spaceIndex + 1).trim();
|
const searchPart = trimmed.slice(spaceIndex + 1).trim();
|
||||||
|
|
||||||
if (TAG_COMMANDS[commandPart]) {
|
if (TAG_COMMANDS[commandPart]) {
|
||||||
|
const cmd = TAG_COMMANDS[commandPart];
|
||||||
|
// Filter out toggle commands that don't meet their condition
|
||||||
|
if (cmd.type === 'toggle_setting' && cmd.condition && !cmd.condition()) {
|
||||||
|
return { showCommands: false, command: null, searchTerm: trimmed };
|
||||||
|
}
|
||||||
return {
|
return {
|
||||||
showCommands: false,
|
showCommands: false,
|
||||||
command: TAG_COMMANDS[commandPart],
|
command: cmd,
|
||||||
searchTerm: searchPart,
|
searchTerm: searchPart,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -652,6 +696,11 @@ class AutoComplete {
|
|||||||
for (const [cmd, info] of Object.entries(TAG_COMMANDS)) {
|
for (const [cmd, info] of Object.entries(TAG_COMMANDS)) {
|
||||||
if (seenLabels.has(info.label)) continue;
|
if (seenLabels.has(info.label)) continue;
|
||||||
|
|
||||||
|
// Filter out toggle commands that don't meet their condition
|
||||||
|
if (info.type === 'toggle_setting' && info.condition) {
|
||||||
|
if (!info.condition()) continue;
|
||||||
|
}
|
||||||
|
|
||||||
if (!filter || cmd.slice(1).startsWith(filterLower)) {
|
if (!filter || cmd.slice(1).startsWith(filterLower)) {
|
||||||
seenLabels.add(info.label);
|
seenLabels.add(info.label);
|
||||||
commands.push({ command: cmd, ...info });
|
commands.push({ command: cmd, ...info });
|
||||||
@@ -1117,7 +1166,16 @@ class AutoComplete {
|
|||||||
|
|
||||||
// Use getSearchTerm to get the current search term before cursor
|
// Use getSearchTerm to get the current search term before cursor
|
||||||
const beforeCursor = currentValue.substring(0, caretPos);
|
const beforeCursor = currentValue.substring(0, caretPos);
|
||||||
const searchTerm = this.getSearchTerm(beforeCursor);
|
const fullSearchTerm = this.getSearchTerm(beforeCursor);
|
||||||
|
|
||||||
|
// For regular tag autocomplete (no command), only replace the last space-separated token
|
||||||
|
// This allows "hello 1gi" + selecting "1girl" to become "hello 1girl, "
|
||||||
|
// Command mode (e.g., "/char miku") should replace the entire command+search
|
||||||
|
let searchTerm = fullSearchTerm;
|
||||||
|
if (this.modelType === 'prompt' && this.searchType === 'custom_words' && !this.activeCommand) {
|
||||||
|
searchTerm = this._getLastSpaceToken(fullSearchTerm);
|
||||||
|
}
|
||||||
|
|
||||||
const searchStartPos = caretPos - searchTerm.length;
|
const searchStartPos = caretPos - searchTerm.length;
|
||||||
|
|
||||||
// Only replace the search term, not everything after the last comma
|
// Only replace the search term, not everything after the last comma
|
||||||
@@ -1175,6 +1233,119 @@ class AutoComplete {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle toggle setting command (/ac, /noac)
|
||||||
|
* @param {Object} command - The toggle command with settingId and value
|
||||||
|
*/
|
||||||
|
async _handleToggleSettingCommand(command) {
|
||||||
|
const { settingId, value } = command;
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Use ComfyUI's setting API to update global setting
|
||||||
|
const settingManager = app?.extensionManager?.setting;
|
||||||
|
if (settingManager && typeof settingManager.set === 'function') {
|
||||||
|
await settingManager.set(settingId, value);
|
||||||
|
this._showToggleFeedback(value);
|
||||||
|
this._clearCurrentToken();
|
||||||
|
} else {
|
||||||
|
// Fallback: use legacy settings API
|
||||||
|
const setting = app.ui.settings.settingsById?.[settingId];
|
||||||
|
if (setting) {
|
||||||
|
app.ui.settings.setSettingValue(settingId, value);
|
||||||
|
this._showToggleFeedback(value);
|
||||||
|
this._clearCurrentToken();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('[Lora Manager] Failed to toggle setting:', error);
|
||||||
|
showToast({
|
||||||
|
severity: 'error',
|
||||||
|
summary: 'Error',
|
||||||
|
detail: 'Failed to toggle autocomplete setting',
|
||||||
|
life: 3000
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
this.hide();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Show visual feedback for toggle action using toast
|
||||||
|
* @param {boolean} enabled - New autocomplete state
|
||||||
|
*/
|
||||||
|
_showToggleFeedback(enabled) {
|
||||||
|
showToast({
|
||||||
|
severity: enabled ? 'success' : 'secondary',
|
||||||
|
summary: enabled ? 'Autocomplete Enabled' : 'Autocomplete Disabled',
|
||||||
|
detail: enabled
|
||||||
|
? 'Tag autocomplete is now ON. Type to see suggestions.'
|
||||||
|
: 'Tag autocomplete is now OFF. Use /ac to re-enable.',
|
||||||
|
life: 3000
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clear the current command token from input
|
||||||
|
* Preserves leading spaces after delimiters (e.g., "1girl, /ac" -> "1girl, ")
|
||||||
|
*/
|
||||||
|
_clearCurrentToken() {
|
||||||
|
const currentValue = this.inputElement.value;
|
||||||
|
const caretPos = this.inputElement.selectionStart;
|
||||||
|
|
||||||
|
// Find the command text before cursor
|
||||||
|
const beforeCursor = currentValue.substring(0, caretPos);
|
||||||
|
const segments = beforeCursor.split(/[,\>]+/);
|
||||||
|
const lastSegment = segments[segments.length - 1] || '';
|
||||||
|
|
||||||
|
// Find the command start position, preserving leading spaces
|
||||||
|
// lastSegment includes leading spaces (e.g., " /ac"), find where command actually starts
|
||||||
|
const commandMatch = lastSegment.match(/^(\s*)(\/\w+)/);
|
||||||
|
if (commandMatch) {
|
||||||
|
// commandMatch[1] is leading spaces, commandMatch[2] is the command
|
||||||
|
const leadingSpaces = commandMatch[1].length;
|
||||||
|
// Keep the spaces by starting after them
|
||||||
|
const commandStartPos = caretPos - lastSegment.length + leadingSpaces;
|
||||||
|
|
||||||
|
// Skip trailing spaces when deleting
|
||||||
|
let endPos = caretPos;
|
||||||
|
while (endPos < currentValue.length && currentValue[endPos] === ' ') {
|
||||||
|
endPos++;
|
||||||
|
}
|
||||||
|
|
||||||
|
const newValue = currentValue.substring(0, commandStartPos) + currentValue.substring(endPos);
|
||||||
|
const newCaretPos = commandStartPos;
|
||||||
|
|
||||||
|
this.inputElement.value = newValue;
|
||||||
|
|
||||||
|
// Trigger input event to notify about the change
|
||||||
|
const event = new Event('input', { bubbles: true });
|
||||||
|
this.inputElement.dispatchEvent(event);
|
||||||
|
|
||||||
|
// Focus back to input and position cursor
|
||||||
|
this.inputElement.focus();
|
||||||
|
this.inputElement.setSelectionRange(newCaretPos, newCaretPos);
|
||||||
|
} else {
|
||||||
|
// Fallback: delete the whole last segment (original behavior)
|
||||||
|
const commandStartPos = caretPos - lastSegment.length;
|
||||||
|
|
||||||
|
let endPos = caretPos;
|
||||||
|
while (endPos < currentValue.length && currentValue[endPos] === ' ') {
|
||||||
|
endPos++;
|
||||||
|
}
|
||||||
|
|
||||||
|
const newValue = currentValue.substring(0, commandStartPos) + currentValue.substring(endPos);
|
||||||
|
const newCaretPos = commandStartPos;
|
||||||
|
|
||||||
|
this.inputElement.value = newValue;
|
||||||
|
|
||||||
|
const event = new Event('input', { bubbles: true });
|
||||||
|
this.inputElement.dispatchEvent(event);
|
||||||
|
|
||||||
|
this.inputElement.focus();
|
||||||
|
this.inputElement.setSelectionRange(newCaretPos, newCaretPos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
destroy() {
|
destroy() {
|
||||||
if (this.debounceTimer) {
|
if (this.debounceTimer) {
|
||||||
clearTimeout(this.debounceTimer);
|
clearTimeout(this.debounceTimer);
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
/* Shared styling for the LoRA Manager frontend widgets */
|
/* Shared styling for the LoRA Manager frontend widgets */
|
||||||
.lm-tooltip {
|
.lm-tooltip {
|
||||||
position: fixed;
|
position: fixed;
|
||||||
z-index: 9999;
|
z-index: 10001;
|
||||||
background: rgba(0, 0, 0, 0.85);
|
background: rgba(0, 0, 0, 0.85);
|
||||||
border-radius: 6px;
|
border-radius: 6px;
|
||||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
|
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user